在带有pytorch数据加载器的起点和终点的奇异数组上创建生成器

问题描述

我正在处理一个pytorch项目,我的数据保存在zarr中。

zarr上进行随机访问的成本很高,但是由于zarr使用逐块缓存,因此迭代非常快。为了利用这一事实,我将IterableDataset与多个工作程序一起使用:

class Data(IterableDataset):
    def __init__(self,path,start=None,end=None):
        super(Data,self).__init__()
        store = zarr.DirectoryStore(path)
        self.array = zarr.open(store,mode='r')

        if start is None:
            start = 0
        if end is None:
            end = self.array.shape[0]

        assert end > start

        self.start = start
        self.end = end

    def __iter__(self):
        return islice(self.array,self.start,self.end)

问题在于self.array的行数约为10e9,并且对于连续的工作者来说,self.startself.end自然会变大,从而生成{{ 1}}在我的训练/验证过程中花费了大量时间,因为itertools.islice(array,start,end)仍然必须遍历不需要的元素,直到到达islice。一旦为每个工人创建了一个生成器,这就像一个咒语,但是到达那里需要太长时间。

是否有更好的方法来创建这种生成器?还是在start中有一种更聪明的使用zarr方法

解决方法

我对zarr进行了一小段潜水,看来这很容易从zarr内部启用。我打开了一个问题here,与此同时,我制作了一个实现功能array.islice(start,end)的{​​{3}}。

数据集__iter__方法如下:

def __iter__(self):
    return self.array.islice(self.start,self.end)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...