问题描述
FashionMNIST数据集具有10个不同的输出类别。如何仅使用特定类来获取此数据集的子集?就我而言,我只想要运动鞋,套头衫,凉鞋和衬衫类的图片(它们的类别分别为7,2,5和6)。
这就是我加载数据集的方式。
train_dataset_full = torchvision.datasets.FashionMNIST(data_folder,train = True,download = True,transform = transforms.ToTensor())
我遵循的方法如下。 依次遍历数据集,然后将返回的元组中的第一个元素(即类)与我所需的类进行比较。我被困在这里。如果返回的值为true,如何将这个观察值追加/添加到空数据集?
sneaker = 0
pullover = 0
sandal = 0
shirt = 0
for i in range(60000):
if train_dataset_full[i][1] == 7:
sneaker += 1
elif train_dataset_full[i][1] == 2:
pullover += 1
elif train_dataset_full[i][1] == 5:
sandal += 1
elif train_dataset_full[i][1] == 6:
shirt += 1
现在,我想代替sneaker += 1
,pullover += 1
,sandal += 1
和shirt += 1
做类似empty_dataset.append(train_dataset_full[i])
或类似的事情。
解决方法
最后找到了答案。
dataset_full = torchvision.datasets.FashionMNIST(data_folder,train = True,download = True,transform = transforms.ToTensor())
# Selecting classes 7,2,5 and 6
idx = (dataset_full.targets==7) | (dataset_full.targets==2) | (dataset_full.targets==5) | (dataset_full.targets==6)
dataset_full.targets = dataset_full.targets[idx]
dataset_full.data = dataset_full.data[idx]
,
您可以使用列表理解来匹配标签。例如
idx = dataset.train_labels == 1
dataset.train_labels = dataset.train_labels[idx]
那只会选择您想要的标签。