具有可迭代数据集的Pytorch数据加载器在多处理模式下经过一个时期后停止

问题描述

我有一个使用可迭代数据集初始化的数据加载器。我发现当我在数据加载器中使用多重处理(即DataLoader中的num_workers> 0)时,一旦数据加载器在一个时期后用尽,当我在第二个时期再次对其进行迭代时,它不会自动重置。下面是一个可重现的小示例。

根据documentation,我知道“一旦迭代结束,工作人员将被关闭”。但是,我想知道如何实现“自动重置”的预期行为。感谢您的任何事先帮助!

import torch
class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self,start,end):
        super().__init__()
        self.start = start
        self.end = end
        
    def __iter__(self):
        return iter(range(self.start,self.end))

    
dataset = MyIterableDataset(0,4)
DataLoader = torch.utils.data.DataLoader(dataset,batch_size=2,shuffle=False,num_workers=1,drop_last=False)


for epoch in range(2):
    for i,data in enumerate(DataLoader):
        print(i,data)

"""
stdout:
0 tensor([0,1])
1 tensor([2,3])
2 _IterableDatasetstopiteration(worker_id=0)
"""

我对stdout的期望是

"""
0 tensor([0,3])
0 tensor([0,3])
"""

我正在使用最新的pytorch版本(1.6.0)

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)