如何仅从PyTorch的FashionMNIST数据集中获取特定的类?

问题描述

FashionMNIST数据集具有10个不同的输出类别。如何仅使用特定类来获取此数据集的子集?就我而言,我只想要运动鞋,套头衫,凉鞋和衬衫类的图片(它们的类别分别为7,2,5和6)。

这就是我加载数据集的方式。

train_dataset_full = torchvision.datasets.FashionMNIST(data_folder,train = True,download = True,transform = transforms.ToTensor())

我遵循的方法如下。 依次遍历数据集,然后将返回的元组中的第一个元素(即类)与我所需的类进行比较。我被困在这里。如果返回的值为true,如何将这个观察值追加/添加到空数据集?

sneaker = 0
pullover = 0
sandal = 0
shirt = 0
for i in range(60000):
    if train_dataset_full[i][1] == 7:
        sneaker += 1
    elif train_dataset_full[i][1] == 2:
        pullover += 1
    elif train_dataset_full[i][1] == 5:
        sandal += 1
    elif train_dataset_full[i][1] == 6:
        shirt += 1

现在,我想代替sneaker += 1pullover += 1sandal += 1shirt += 1做类似empty_dataset.append(train_dataset_full[i])或类似的事情。

如果上述方法不正确,请提出另一种方法

解决方法

最后找到了答案。

dataset_full = torchvision.datasets.FashionMNIST(data_folder,train = True,download = True,transform = transforms.ToTensor())
# Selecting classes 7,2,5 and 6
idx = (dataset_full.targets==7) | (dataset_full.targets==2) | (dataset_full.targets==5) | (dataset_full.targets==6)
dataset_full.targets = dataset_full.targets[idx]
dataset_full.data = dataset_full.data[idx]
,

您可以使用列表理解来匹配标签。例如

idx = dataset.train_labels == 1
dataset.train_labels = dataset.train_labels[idx]

那只会选择您想要的标签。