如何修复数据集以返回所需的输出pytorch

问题描述

@H_404_0@我正在尝试使用外部函数提供的信息来确定要返回的数据。在这里,我添加一个简化的代码来演示该问题。当我使用num_workers = 0时,我得到了想要的行为(3个纪元后的输出为18)。但是,当我增加num_workers的值时,每个时期之后的输出是相同的。并且全局变量保持不变。

from torch.utils.data import Dataset,DataLoader

x = 6
def getx():
    global x
    x+=1
    print("x: ",x)
    return x

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self,index):
        global x
        x = getx()
        return x
    
    def __len__(self):
        return 3

dataset = MyDataset()
loader = DataLoader(
    dataset,num_workers=0,shuffle=False
)

for epoch in range(4):
    for idx,data in enumerate(loader):
        print('Epoch {},idx {},val: {}'.format(epoch,idx,data))

@H_404_0@ num_workers=0为预期的18时的最终输出。但是当num_workers>0时,x保持不变(最终输出为6)。

@H_404_0@如何使用num_workers=0获得与num_workers>0类似的行为(即如何确保DataLoader__getitem__函数更改全局变量x的值)?

解决方法

这样做的原因是python中多处理的基本性质。设置num_workers意味着您的DataLoader创建该数目的子流程。每个子进程实际上是一个具有自己全局状态的单独的python实例,并且不知道其他进程中正在发生什么。

在python的多处理中,典型的解决方案是使用Manager。但是,由于通过DataLoader提供了多处理功能,因此您无法使用它。

幸运的是,还有其他事情可以做。 DataLoader实际上依赖于torch.multiprocessing,这反过来又允许进程在共享内存中共享张量。

因此,您可以做的就是简单地将x用作共享张量。

from torch.utils.data import Dataset,DataLoader
import torch 

x = torch.tensor([6])
x.share_memory_()

def getx():
    global x
    x+=1
    print("x: ",x.item())
    return x

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self,index):
        global x
        x = getx()
        return x
    
    def __len__(self):
        return 3

dataset = MyDataset()
loader = DataLoader(
    dataset,num_workers=2,shuffle=False
)

for epoch in range(4):
    for idx,data in enumerate(loader):
        print('Epoch {},idx {},val: {}'.format(epoch,idx,data))

出局:

x:  7
x:  8
x:  9
Epoch 0,idx 0,val: tensor([[7]])
Epoch 0,idx 1,val: tensor([[8]])
Epoch 0,idx 2,val: tensor([[9]])
x:  10
x:  11
x:  12
Epoch 1,val: tensor([[10]])
Epoch 1,val: tensor([[12]])
Epoch 1,val: tensor([[12]])
x:  13
x:  14
x:  15
Epoch 2,val: tensor([[13]])
Epoch 2,val: tensor([[15]])
Epoch 2,val: tensor([[14]])
x:  16
x:  17
x:  18
Epoch 3,val: tensor([[16]])
Epoch 3,val: tensor([[18]])
Epoch 3,val: tensor([[17]])

虽然可行,但并不完美。查看时期1,注意有2个12,而不是11和12。这意味着两个单独的进程在执行打印之前已经执行了行x+=1。这是不可避免的,因为并行进程正在共享内存上工作。

如果您熟悉操作系统的概念,则可以通过附加变量进一步实现某种semaphore,以根据需要控制对x的访问-但超出了本章的范围问题,我不再赘述。