问题描述
我是PyTorch的新手。现在,我有两个名为A和B的数据集(例如:MNIST)。我想将A和B混合在一起以形成一个新的数据集。我想改组这个新的数据集。在培训期间,我需要确定批次中的样品是否属于A。如何做到这一点?
两个问题如下: 1)如何混合两个数据集并对其进行随机播放? 2)如何确定新数据集中的样本是否属于原始数据集A?
解决方法
通过定义自定义数据集和一些标志标签,您可以实现这一目标。 这是示例代码:
import torch
from torch.utils.data import DataLoader,Dataset
class to_dataset(Dataset):
def __init__(self,data_A,data_B):
self.lena = data_A.shape[0]
self.len = data_A.shape[0] + data_B.shape[0]
self.A = torch.from_numpy(data_A).float()
self.B = torch.from_numpy(data_B).float()
# returns dataset A data with flag label 0
# dataset B data with flag label 1
def __getitem__(self,index):
if index < self.lena:
return self.A[index],0
retutn self.B[index-self.lena],1
def __len__(self):
return self.len
#reading sample numpy dataset
data_a = np.load(pathofA)
data_b = np.load(pathofB)
# loading custom dataset
dataset = to_dataset(data_a,data_b)
#loading dataloader with training data
train_loader = DataLoader(dataset=dataset,batch_size=bsize,shuffle=True)
#sample train loop
for epoch in range(epochs):
for data,label in train_loader:
for d,l in zip(data,label):
if l == 0:
print('from A')
else:
print('from B')