问题描述
我正在从事一个项目,该项目要求我在进行 PyTorch 培训之前从 csv 文件中获取数据和标签。 csv文件的结构为:
形象,动物,灭绝
blabla.png,1
................
等等..
from torch.utils.data import Dataset
class CallDataset(Dataset):
# Todo implement the Dataset class according to the description
def __init__(self,data,mode):
#self.img,self.lbl = data
self.data = data
self.mode = mode
self._transform = transforms.Compose([
transforms.ToPILImage(),transforms.ToTensor(),transforms.normalize(mean=train_mean,std=train_std)
])
# self.tab = pd.read_csv(self.data,sep=';')
self.image = data.iloc[0:,0]
self.label = data.iloc[0:,1:]
#self.image = self.img
#self.label = self.lbl
def __len__(self):
return len(self.image)
def __getitem__(self,index):
if self.mode == 'train':
x = imread(self.image[index])
sample = gray2rgb(x)
transformed_img = self._transform(sample)
transformed_label = self.label.to_numpy()
transformed_label = transformed_label[index]
transformed_label = torch.from_numpy(transformed_label)
return (transformed_img,transformed_label)
if self.mode == 'val':
x = imread(self.image[index])
sample = gray2rgb(x)
transformed_img = self._transform(sample)
transformed_label = self.label.to_numpy()
transformed_label = transformed_label[index]
transformed_label = torch.from_numpy(transformed_label)
return (transformed_img,transformed_label)
所以这个想法是返回一个包含 2 个火炬张量的元组。第一个是形状(1、3、高度、宽度)的RGB变换图像张量,第二个是包含标签(是否是动物以及是否已灭绝)的火炬张量。它的形状如果 (1,2)。
csv 文件和对应的图片文件夹分别包含 1000 张图片和 1000 组标签。
对于培训,我试图这样称呼它们:
def train_epoch(self):
self._model = self._model.train()
l = 0
c = 0
for x,y_true in self._train_dl:
c +=1
if self._cuda:
x = x.to('cuda')
y_true = y_true.to('cuda')
l += self.train_step(x,y_true)
avg_loss = l/c
return avg_loss
self.train_dl 定义如下:
def train_dataset():
train_data = CallDataset(train,'train')
return train_data
train_dl = t.utils.data.DataLoader(train_dataset(),batch_size=200,shuffle=True)
所以这个想法是在每次迭代时获取图像和相应的标签,将它们发送到 cuda,训练它们(train_step 是我调用模型的函数,向前跑,丢失,向后跑等)。初始损失 l = 0 是我定义 l 的地方。 c 基本上存储了我需要找到平均损失的交互次数(此时 l,c 不打扰我)。
我的问题是我是否正确使用了 DataLoader。另外,请让我知道使用它们的最佳方式并提供解释。
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)