如何判断新数据集中的样本是否属于PyTorch中的原始数据集?

问题描述

我是PyTorch的新手。现在,我有两个名为A和B的数据集(例如:MNIST)。我想将A和B混合在一起以形成一个新的数据集。我想改组这个新的数据集。在培训期间,我需要确定批次中的样品是否属于A。如何做到这一点?

两个问题如下: 1)如何混合两个数据集并对其进行随机播放? 2)如何确定新数据集中的样本是否属于原始数据集A?

解决方法

通过定义自定义数据集和一些标志标签,您可以实现这一目标。 这是示例代码:

import torch
from torch.utils.data import DataLoader,Dataset 

class to_dataset(Dataset):

    def __init__(self,data_A,data_B):
        self.lena   = data_A.shape[0]
        self.len    = data_A.shape[0] + data_B.shape[0]
        self.A      = torch.from_numpy(data_A).float()
        self.B      = torch.from_numpy(data_B).float()
     
    # returns dataset A data with flag label 0
    # dataset B data with flag label 1
    def __getitem__(self,index):
        if index < self.lena:            
            return self.A[index],0
        retutn self.B[index-self.lena],1

    def __len__(self):
        return self.len

   
#reading sample numpy dataset
data_a  = np.load(pathofA)
data_b  = np.load(pathofB)
 
# loading custom dataset
dataset = to_dataset(data_a,data_b)

#loading dataloader with training data
train_loader = DataLoader(dataset=dataset,batch_size=bsize,shuffle=True)

#sample train loop
for epoch in range(epochs):
    for data,label in train_loader:
        for d,l in zip(data,label):  
            if l == 0:
                print('from A')
            else:
                print('from B')

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...