问题描述
我正在处理一个pytorch项目,我的数据保存在zarr
中。
在zarr
上进行随机访问的成本很高,但是由于zarr
使用逐块缓存,因此迭代非常快。为了利用这一事实,我将IterableDataset
与多个工作程序一起使用:
class Data(IterableDataset):
def __init__(self,path,start=None,end=None):
super(Data,self).__init__()
store = zarr.DirectoryStore(path)
self.array = zarr.open(store,mode='r')
if start is None:
start = 0
if end is None:
end = self.array.shape[0]
assert end > start
self.start = start
self.end = end
def __iter__(self):
return islice(self.array,self.start,self.end)
问题在于self.array
的行数约为10e9
,并且对于连续的工作者来说,self.start
和self.end
自然会变大,从而生成{{ 1}}在我的训练/验证过程中花费了大量时间,因为itertools.islice(array,start,end)
仍然必须遍历不需要的元素,直到到达islice
。一旦为每个工人创建了一个生成器,这就像一个咒语,但是到达那里需要太长时间。
是否有更好的方法来创建这种生成器?还是在start
中有一种更聪明的使用zarr
的方法?
解决方法
我对zarr进行了一小段潜水,看来这很容易从zarr内部启用。我打开了一个问题here,与此同时,我制作了一个实现功能array.islice(start,end)
的{{3}}。
数据集__iter__
方法如下:
def __iter__(self):
return self.array.islice(self.start,self.end)