问题描述
我正在努力在 PyTorch 中创建一个数据生成器,以从许多以 .dat
格式保存的 3D 立方体中提取 2D 图像
总共有 200
个 3D 立方体,每个立方体的形状为 128*128*128
。现在我想从所有这些立方体中沿长度和宽度提取 2D 图像。
例如,a
是一个大小为 128*128*128
的立方体
所以我想沿长度提取所有 2D 图像,即 [:,i,:]
,这将使我沿长度获得 128 个 2D 图像,同样我想沿宽度提取,即 [:,:,i]
,这将给出沿宽度有 128 张 2D 图像。因此,我从 1 个 3D 立方体中总共得到 256 个 2D 图像,我想对所有 200 个立方体重复整个过程,给我 51200 个 2D 图像。
到目前为止,我已经尝试了一个非常基本的实现,它运行良好,但需要大约 10 分钟才能运行。我希望你们帮助我创建一个更优化的实现,同时考虑到时间和空间的复杂性。现在我目前的方法有 O(n2) 的时间复杂度,我们可以进一步分解它以降低时间复杂度
我在当前的实现下面提供
from os.path import join as pjoin
import torch
import numpy as np
import os
from tqdm import tqdm
from torch.utils import data
class DataGenerator(data.Dataset):
def __init__(self,is_transform=True,augmentations=None):
self.is_transform = is_transform
self.augmentations = augmentations
self.dim = (128,128,128)
seismicSections = [] #Input
faultSections = [] #Ground Truth
for fileName in tqdm(os.listdir(pjoin('train','seis')),total = len(os.listdir(pjoin('train','seis')))):
unrolledVolSeismic = np.fromfile(pjoin('train','seis',fileName),dtype = np.single) #dat file contains unrolled cube,we need to reshape it
reshapedVolSeismic = np.transpose(unrolledVolSeismic.reshape(self.dim)) #need to transpose the axis to get height axis at axis = 0,while length (axis = 1),and width(axis = 2)
unrolledVolFault = np.fromfile(pjoin('train','fault',dtype=np.single)
reshapedVolFault = np.transpose(unrolledVolFault.reshape(self.dim))
for idx in range(reshapedVolSeismic.shape[2]):
seismicSections.append(reshapedVolSeismic[:,idx])
faultSections.append(reshapedVolFault[:,idx])
for idx in range(reshapedVolSeismic.shape[1]):
seismicSections.append(reshapedVolSeismic[:,idx,:])
faultSections.append(reshapedVolFault[:,:])
self.seismicSections = seismicSections
self.faultSections = faultSections
def __len__(self):
return len(self.seismicSections)
def __getitem__(self,index):
X = self.seismicSections[index]
Y = self.faultSections[index]
return X,Y
请帮忙!!!
解决方法
为什么不只将 3D 数据存储在 mem 中,而让 __getitem__
方法动态“切片”它?
class CachedVolumeDataset(Dataset):
def __init__(self,...):
super(...)
self._volumes_x = # a list of 200 128x128x128 volumes
self._volumes_y = # a list of 200 128x128x128 volumes
def __len__(self):
return len(self._volumes_x) * (128 + 128)
def __getitem__(self,index):
# extract volume index from general index:
vidx = index // (128 + 128)
# extract slice index
sidx = index % (128 + 128)
if sidx < 128:
# first dim
x = self._volumes_x[vidx][:,:,sidx]
y = self._volumes_y[vidx][:,sidx]
else:
sidx -= 128
# second dim
x = self._volumes_x[vidx][:,sidx,:]
y = self._volumes_y[vidx][:,:]
return torch.squeeze(x),torch.squeeze(y)