创建一个 pyTorch 测试数据集不带标签

问题描述

我为我的训练数据创建了一个 pyTorch 数据集,其中包含特征和标签,以便能够使用 this 教程使用 pyTorch DataLoader。这对我的训练数据很有效,但在加载测试 csv 文件时出现错误 (KeyError: "['label'] not found in axis"),该文件除了没有“标签”列之外是相同的。

如果有帮助,预期的输入 csv 文件是 csv 文件中的 MNIST 数据,其中包含 28*28 个特征列。

import torch

class mnist(torch.utils.data.Dataset):
    
    def __init__(self,csv_file):
        self.train = pd.read_csv(csv_file)
        self.train_x = self.train.drop("label",axis=1)
    
    def __len__(self):
        return len(self.train)
    
    def __getitem__(self,idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if isinstance(idx,list):
            idx_len = len(idx)
        else:
            idx_len = 1
        
        X = np.asarray(self.train_x.iloc[idx],dtype=np.float32)
        X = np.reshape(X,(1,28,28))
        y = np.asarray(self.train.iloc[idx]['label'])
        
        sample = {'X': X,'y':y}
        
        return torch.from_numpy(sample['X']),torch.from_numpy(sample['y'])

解决方法

您应该能够使用这两种数据:

import torch

class mnist(torch.utils.data.Dataset):
    
    def __init__(self,csv_file):
        self.train = pd.read_csv(csv_file)

        self.training = "label" in self.train.columns
        self.train_x = self.train if not self.training else self.train.drop("label",axis=1)
    
    def __len__(self):
        return len(self.train)
    
    def __getitem__(self,idx):
        ...
        
        X = np.asarray(self.train_x.iloc[idx],dtype=np.float32)
        X = np.reshape(X,(1,28,28))
        if not self.training:
            return torch.from_numpy(X])

        y = np.asarray(self.train.iloc[idx]['label'])

        sample = {'X': X,'y':y}
        return torch.from_numpy(sample['X']),torch.from_numpy(sample['y'])

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...