问题描述
我是 pytorch 的新手。 我的大数据集由两个 txt 文件组成,一个用于数据,另一个用于目标数据。 在训练文件中每行是长度为 340 的列表,在目标中每行是长度为 136 的列表。
我想问一下如何定义我的数据集,以便我可以使用 DataLoader 加载我的数据来训练 pytorch 模型?
我希望你回答
解决方法
Dataset
中的 torch.utils.data
是表示数据集的抽象类。您的自定义数据集应继承 Dataset 并覆盖以下方法:
__len__()
使 len(dataset) 返回数据集的大小。__getitem__()
支持索引,使得 dataset[i] 可用于获取第 i 个样本
例如编写自定义数据集
我已经为您编写了一个通用的自定义数据加载器作为您的问题陈述。
这里 data.txt 有数据,label.txt 有标签。
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self):
with open('data.txt','r') as f:
self.data_info = f.readlines()
with open('label.txt','r') as f:
self.label_info = f.readlines()
def __getitem__(self,index):
single_data = self.data_info[index].rstrip('\n')
single_label = self.label_info[index].rstrip('\n')
return ( single_data,single_label)
def __len__(self):
return len(self.data_info)
# Testing
d = CustomDataset()
print(d[1]) # should output data along with label
这将是您案例的基础,但必须进行一些与您的案例相匹配的更改。
注意:您必须根据数据集进行必要的更改