【问题标题】:Pytorch Data Generator for extracting 2D images from many 3D cubePytorch 数据生成器,用于从许多 3D 立方体中提取 2D 图像
【发布时间】:2021-08-29 12:53:50
【问题描述】:

我正在努力在 PyTorch 中创建一个数据生成器,以便从以 .dat 格式保存的许多 3D 立方体中提取 2D 图像

共有200 3D 立方体,每个立方体具有128*128*128 形状。现在我想从所有这些立方体中沿长度和宽度提取二维图像。

例如,a 是一个大小为 128*128*128 的立方体

所以我想沿长度提取所有 2D 图像,即[:, i, :],这将获得沿长度的 128 个 2D 图像,同样我想沿宽度提取,即,[:, :, i],这将给我 128 个 2D 图像沿宽度。因此,我从 1 个 3D 立方体中总共获得了 256 个 2D 图像,我想对所有 200 个立方体重复整个过程,给我 51200 个 2D 图像。

到目前为止,我已经尝试了一个非常基本的实现,它运行良好,但运行大约需要 10 分钟。我希望你们帮助我创建一个更优化的实现,同时牢记时间和空间的复杂性。现在我目前的方法的时间复杂度为 O(n2),我们可以进一步降低它以降低时间复杂度

我在当前实现下提供

from os.path import join as pjoin
import torch
import numpy as np
import os
from tqdm import tqdm
from torch.utils import data


class DataGenerator(data.Dataset):

    def __init__(self, is_transform=True, augmentations=None):

        self.is_transform = is_transform
        self.augmentations = augmentations
        self.dim = (128, 128, 128)

        seismicSections = [] #Input
        faultSections = [] #Ground Truth
        for fileName in tqdm(os.listdir(pjoin('train', 'seis')), total = len(os.listdir(pjoin('train', 'seis')))):
            unrolledVolSeismic = np.fromfile(pjoin('train', 'seis', fileName), dtype = np.single) #dat file contains unrolled cube, we need to reshape it
            reshapedVolSeismic = np.transpose(unrolledVolSeismic.reshape(self.dim)) #need to transpose the axis to get height axis at axis = 0, while length (axis = 1), and width(axis = 2)

            unrolledVolFault = np.fromfile(pjoin('train', 'fault', fileName),dtype=np.single)
            reshapedVolFault = np.transpose(unrolledVolFault.reshape(self.dim))

            for idx in range(reshapedVolSeismic.shape[2]):
                seismicSections.append(reshapedVolSeismic[:, :, idx])
                faultSections.append(reshapedVolFault[:, :, idx])

            for idx in range(reshapedVolSeismic.shape[1]):
                seismicSections.append(reshapedVolSeismic[:, idx, :])
                faultSections.append(reshapedVolFault[:, idx, :])

        self.seismicSections = seismicSections
        self.faultSections = faultSections

    def __len__(self):
        return len(self.seismicSections)

    def __getitem__(self, index):

        X = self.seismicSections[index]
        Y = self.faultSections[index]

        return X, Y

请帮忙!!!

【问题讨论】:

    标签: python computer-vision pytorch data-generation


    【解决方案1】:

    为什么不在 mem 中只存储 3D 数据,而让__getitem__ 方法动态“切片”它?

    class CachedVolumeDataset(Dataset):
      def __init__(self, ...):
        super(...)
        self._volumes_x = # a list of 200 128x128x128 volumes
        self._volumes_y = # a list of 200 128x128x128 volumes
    
      def __len__(self):
        return len(self._volumes_x) * (128 + 128)
    
      def __getitem__(self, index):
        # extract volume index from general index:
        vidx = index // (128 + 128)
        # extract slice index
        sidx = index % (128 + 128)
        if sidx < 128:
          # first dim
          x = self._volumes_x[vidx][:, :, sidx]
          y = self._volumes_y[vidx][:, :, sidx]
        else:
          sidx -= 128
          # second dim
          x = self._volumes_x[vidx][:, sidx, :]
          y = self._volumes_y[vidx][:, sidx, :]
        return torch.squeeze(x), torch.squeeze(y)
    

    【讨论】:

      猜你喜欢
      • 2014-07-21
      • 2014-09-06
      • 1970-01-01
      • 2020-10-26
      • 2013-12-17
      • 2021-07-27
      • 2015-10-28
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多