问题描述
在我的数据集中,有 6 个类,每个类 23 张图片
我使用 torchvision.dataset
制作 ImageFolder
并且效果很好。
dataset = vision_dataset.ImageFolder(root = DATA_ROOT,transform = vision_trans.Compose([
vision_trans.Resize(256),vision_trans.CenterCrop(256),vision_trans.ToTensor()
]))
DataLoader = torch.utils.data.DataLoader(dataset = dataset,batch_size = SHOT_K,shuffle = False,num_workers = 2,)
但我想获得具有相同类的批量图像。
...
tensor([2,2,2])
tensor([2,2])
tensor([3,3,3])
...
这就是我想要的标签(批处理数据的类)形式
但实际上 DataLoader 会这样工作
...
tensor([2,3])
tensor([3,3])
...
解决方法
使用 ImageFolder
无法方便地做到这一点。您应该为每个类创建一个数据集,并从您需要的数据集中加载您的批次。
更具体地说,假设您的文件夹结构是 ImageFolder 所需的结构,您需要创建一个小型数据集类:
class ImageSubFolder(torch.utils.data.Dataset):
def __init__(self,root_dir,label):
# Path toward the label-sorted subfolders of your dataset
# Assuming images are named smthg like /path/to/label/xxxx.npy
self._path = root_dir + label+ "{:04d}"
def __len__(self):
return count_files_in_directory(self._path)
def __getitem__(self,index):
return (np.load(self._path.format(index),label)
这只是为了展示类的逻辑,我相信您仍然需要实现一些功能(您可以遵循this tutorial)。 “要实现的其余功能留给读者作为练习”。无论如何,对于这个类,您只需要创建它的 6 个实例(每个类一个):
loaders = {}
for label in ("dog","cat","plane","tree","mug","car"):
dataset = SubFolderDataset(DATA_ROOT,label)
loaders[label] = torch.utils.data.DataLoader(dataset = dataset,batch_size = SHOT_K,shuffle = False,num_workers = 2,)
现在你有一个包含数据加载器的字典,它只加载给定类的样本。