如何使用 Torch Dataloader 获取具有相同类的图片?

问题描述

在我的数据集中,有 6 个类,每个类 23 张图片
我使用 torchvision.dataset 制作 ImageFolder 并且效果很好。

dataset = vision_dataset.ImageFolder(root = DATA_ROOT,transform = vision_trans.Compose([
                                                    vision_trans.Resize(256),vision_trans.CenterCrop(256),vision_trans.ToTensor()
                                     ]))

DataLoader = torch.utils.data.DataLoader(dataset = dataset,batch_size = SHOT_K,shuffle = False,num_workers = 2,)

但我想获得具有相同类的批量图像。

...
tensor([2,2,2])
tensor([2,2])
tensor([3,3,3])
...

这就是我想要的标签(批处理数据的类)形式
但实际上 DataLoader 会这样工作

...
tensor([2,3])
tensor([3,3])
...

如何获取每个标签的批次数据?

解决方法

使用 ImageFolder 无法方便地做到这一点。您应该为每个类创建一个数据集,并从您需要的数据集中加载您的批次。

更具体地说,假设您的文件夹结构是 ImageFolder 所需的结构,您需要创建一个小型数据集类:

class ImageSubFolder(torch.utils.data.Dataset):
    def __init__(self,root_dir,label):
        # Path toward the label-sorted subfolders of your dataset
        # Assuming images are named smthg like /path/to/label/xxxx.npy
        self._path = root_dir + label+ "{:04d}"

    def __len__(self):
        return count_files_in_directory(self._path)

    def __getitem__(self,index):
        return (np.load(self._path.format(index),label)

这只是为了展示类的逻辑,我相信您仍然需要实现一些功能(您可以遵循this tutorial)。 “要实现的其余功能留给读者作为练习”。无论如何,对于这个类,您只需要创建它的 6 个实例(每个类一个):

loaders = {}
for label in ("dog","cat","plane","tree","mug","car"):
    dataset = SubFolderDataset(DATA_ROOT,label)
    loaders[label] = torch.utils.data.DataLoader(dataset = dataset,batch_size = SHOT_K,shuffle = False,num_workers = 2,)

现在你有一个包含数据加载器的字典,它只加载给定类的样本。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...