Pytorch 数据生成器,用于从许多 3D 立方体中提取 2D 图像

问题描述

我正在努力在 PyTorch 中创建一个数据生成器,以从许多以 .dat 格式保存的 3D 立方体中提取 2D 图像

总共有 200 个 3D 立方体,每个立方体的形状为 128*128*128。现在我想从所有这些立方体中沿长度和宽度提取 2D 图像。

例如,a一个大小为 128*128*128 的立方体

所以我想沿长度提取所有 2D 图像,即 [:,i,:],这将使我沿长度获得 128 个 2D 图像,同样我想沿宽度提取,即 [:,:,i],这将给出沿宽度有 128 张 2D 图像。因此,我从 1 个 3D 立方体中总共得到 256 个 2D 图像,我想对所有 200 个立方体重复整个过程,给我 51200 个 2D 图像。

到目前为止,我已经尝试了一个非常基本的实现,它运行良好,但需要大约 10 分钟才能运行。我希望你们帮助我创建一个更优化的实现,同时考虑到时间和空间的复杂性。现在我目前的方法有 O(n2) 的时间复杂度,我们可以进一步分解它以降低时间复杂度

我在当前的实现下面提供

from os.path import join as pjoin
import torch
import numpy as np
import os
from tqdm import tqdm
from torch.utils import data


class DataGenerator(data.Dataset):

    def __init__(self,is_transform=True,augmentations=None):

        self.is_transform = is_transform
        self.augmentations = augmentations
        self.dim = (128,128,128)

        seismicSections = [] #Input
        faultSections = [] #Ground Truth
        for fileName in tqdm(os.listdir(pjoin('train','seis')),total = len(os.listdir(pjoin('train','seis')))):
            unrolledVolSeismic = np.fromfile(pjoin('train','seis',fileName),dtype = np.single) #dat file contains unrolled cube,we need to reshape it
            reshapedVolSeismic = np.transpose(unrolledVolSeismic.reshape(self.dim)) #need to transpose the axis to get height axis at axis = 0,while length (axis = 1),and width(axis = 2)

            unrolledVolFault = np.fromfile(pjoin('train','fault',dtype=np.single)
            reshapedVolFault = np.transpose(unrolledVolFault.reshape(self.dim))

            for idx in range(reshapedVolSeismic.shape[2]):
                seismicSections.append(reshapedVolSeismic[:,idx])
                faultSections.append(reshapedVolFault[:,idx])

            for idx in range(reshapedVolSeismic.shape[1]):
                seismicSections.append(reshapedVolSeismic[:,idx,:])
                faultSections.append(reshapedVolFault[:,:])

        self.seismicSections = seismicSections
        self.faultSections = faultSections

    def __len__(self):
        return len(self.seismicSections)

    def __getitem__(self,index):

        X = self.seismicSections[index]
        Y = self.faultSections[index]

        return X,Y

请帮忙!!!

解决方法

为什么不只将 3D 数据存储在 mem 中,而让 __getitem__ 方法动态“切片”它?

class CachedVolumeDataset(Dataset):
  def __init__(self,...):
    super(...)
    self._volumes_x = # a list of 200 128x128x128 volumes
    self._volumes_y = # a list of 200 128x128x128 volumes

  def __len__(self):
    return len(self._volumes_x) * (128 + 128)

  def __getitem__(self,index):
    # extract volume index from general index:
    vidx = index // (128 + 128)
    # extract slice index
    sidx = index % (128 + 128)
    if sidx < 128:
      # first dim
      x = self._volumes_x[vidx][:,:,sidx]
      y = self._volumes_y[vidx][:,sidx]
    else:
      sidx -= 128
      # second dim
      x = self._volumes_x[vidx][:,sidx,:]
      y = self._volumes_y[vidx][:,:]
    return torch.squeeze(x),torch.squeeze(y)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...