如何使用来自Torchvision MNIST数据集中的一个原始批次和一个扩充批次,并进行混洗,“对齐”批次样本,但批次大小不同?

问题描述

我想为Torchvision MNIST数据集实现这种情况,并使用DataLoader加载数据:

batch A (unaugmented images): 5,4,...
batch B (augmented images): 5*,5+,5-,0*,0+,0-,4*,4+,4-,...

...其中,对于A的每个图像,批次B中都有3个增强。len(B)= 3 * len(A)对应。这些批次应在一次迭代中使用,以将批次A的原始图像与批次B中增强的图像进行比较,以建立损失。

class MyMNIST(Dataset):

def __init__(self,mnist_dir,train,augmented,transform=None,repeat=1):

    self.mnist_dir = mnist_dir
    self.train = train
    self.augmented = augmented
    self.repeat = repeat
    self.transform = transform
    self.dataset = None

    if augmented and train:
        self.dataset = datasets.MNIST(self.mnist_dir,train=train,download=True,transform=transform)
        self.dataset.data = torch.repeat_interleave(self.dataset.data,repeats=self.repeat,dim=0)
        self.dataset.targets = torch.repeat_interleave(self.dataset.targets,dim=0)
    elif augmented and not train:
        raise Exception("Test set should not be augmented.")
    else:
        self.dataset = datasets.MNIST(MNIST_DIR,transform=transform)

使用此类,我想初始化两个不同的数据加载器:

orig_train = MyMNIST(MNIST_DIR,train=True,augmented=False,transform=orig_transforms)
orig_train_loader = torch.utils.data.DataLoader(orig_train.dataset,batch_size=100,shuffle=True)

aug_train = MyMNIST(MNIST_DIR,augmented=True,transform=aug_transforms,repeat=3)
aug_train_loader = torch.utils.data.DataLoader(aug_train.dataset,batch_size=300,shuffle=True)

我现在的问题是,我还需要在每次迭代中都进行洗牌,同时保持A和B之间的顺序相关。对于上述代码,这是不可能的,因为两个DataLoader都会产生不同的顺序。因此,我尝试使用单个DataLoader并手动复制重复的批次:

for batch_no,(images,labels) in enumerate(orig_train_loader):
    repeat_images = torch.repeat_interleave(images,3,dim=0)

这样,我正确地获得了批次B(repeat_images)的订单,但是现在我缺少了需要在批次/迭代中应用的转换。这似乎不是Pytorch的范式,至少我没有找到做到这一点的方法

如果有人可以帮助我,我会很高兴-我对Pytorch(以及Stackoverflow)还很陌生,所以也欢迎批评我的整个方法,可能出现的性能问题等。

非常感谢!

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)