PyTorch:在使用 Dataloader 加载批量数据时,如何自动将数据传输到 GPU

问题描述

如果我们结合使用 DatasetDataLoader 类(如下所示),我必须使用 {{1} 将数据显式加载到 GPU } 或 .to()。有没有办法指示数据加载器自动/隐式执行此操作?

理解/重现场景的代码

.cuda()

输出以下内容;注意 - 没有明确的设备传输指令,数据加载到cpu

from torch.utils.data import Dataset,DataLoader
import numpy as np

class DemoData(Dataset):
    def __init__(self,limit):
        super(DemoData,self).__init__()
        self.data = np.arange(limit)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self,idx):
        return (self.data[idx],self.data[idx]*100)

demo = DemoData(100)

loader = DataLoader(demo,batch_size=50,shuffle=True)

for i,(i1,i2) in enumerate(loader):
    print('Batch Index: {}'.format(i))
    print('Shape of data item 1: {}; shape of data item 2: {}'.format(i1.shape,i2.shape))
    # i1,i2 = i1.to('cuda:0'),i2.to('cuda:0')
    print('Device of data item 1: {}; device of data item 2: {}\n'.format(i1.device,i2.device))

一个可能的解决方案是在 this PyTorch GitHub repo. Issue在发布此问题时仍处于打开状态),但是,当数据加载器必须返回多个数据时,我无法使其工作-物品!

解决方法

您可以修改 collate_fn 以同时处理多个项目:

from torch.utils.data.dataloader import default_collate

device = torch.device('cuda:0')  # or whatever device/cpu you like

# the new collate function is quite generic
loader = DataLoader(demo,batch_size=50,shuffle=True,collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))

请注意,如果您想为数据加载器设置多个工作线程,则需要添加

torch.multiprocessing.set_start_method('spawn')

在您的 if __name__ == '__main__' 之后(参见 this issue)。

话虽如此,在您的 pin_memory=True 中使用 DataLoader 似乎效率更高。你试过这个选项吗?
有关详细信息,请参阅 memory pinning


更新(2021 年 2 月 8 日)
这篇文章让我了解了我在训练期间所花费的“数据到模型”时间。 我比较了三种选择:

  1. DataLoader 在 CPU 上运行,只有在检索到批次后数据才会移动到 GPU。
  2. 与 (1) 相同,但在 pin_memory=True 中带有 DataLoader
  3. 建议的使用 collate_fn 将数据移动到 GPU 的方法。

从我有限的实验来看,似乎第二个选项效果最好(但差距不大)。
第三个选项需要对数据加载器进程的 start_method 大惊小怪,而且似乎在每个 epoch 开始时都会产生开销。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...