【发布时间】:2021-08-29 12:53:50
【问题描述】:
我正在努力在 PyTorch 中创建一个数据生成器,以便从以 .dat 格式保存的许多 3D 立方体中提取 2D 图像
共有200 3D 立方体,每个立方体具有128*128*128 形状。现在我想从所有这些立方体中沿长度和宽度提取二维图像。
例如,a 是一个大小为 128*128*128 的立方体
所以我想沿长度提取所有 2D 图像,即[:, i, :],这将获得沿长度的 128 个 2D 图像,同样我想沿宽度提取,即,[:, :, i],这将给我 128 个 2D 图像沿宽度。因此,我从 1 个 3D 立方体中总共获得了 256 个 2D 图像,我想对所有 200 个立方体重复整个过程,给我 51200 个 2D 图像。
到目前为止,我已经尝试了一个非常基本的实现,它运行良好,但运行大约需要 10 分钟。我希望你们帮助我创建一个更优化的实现,同时牢记时间和空间的复杂性。现在我目前的方法的时间复杂度为 O(n2),我们可以进一步降低它以降低时间复杂度
我在当前实现下提供
from os.path import join as pjoin
import torch
import numpy as np
import os
from tqdm import tqdm
from torch.utils import data
class DataGenerator(data.Dataset):
def __init__(self, is_transform=True, augmentations=None):
self.is_transform = is_transform
self.augmentations = augmentations
self.dim = (128, 128, 128)
seismicSections = [] #Input
faultSections = [] #Ground Truth
for fileName in tqdm(os.listdir(pjoin('train', 'seis')), total = len(os.listdir(pjoin('train', 'seis')))):
unrolledVolSeismic = np.fromfile(pjoin('train', 'seis', fileName), dtype = np.single) #dat file contains unrolled cube, we need to reshape it
reshapedVolSeismic = np.transpose(unrolledVolSeismic.reshape(self.dim)) #need to transpose the axis to get height axis at axis = 0, while length (axis = 1), and width(axis = 2)
unrolledVolFault = np.fromfile(pjoin('train', 'fault', fileName),dtype=np.single)
reshapedVolFault = np.transpose(unrolledVolFault.reshape(self.dim))
for idx in range(reshapedVolSeismic.shape[2]):
seismicSections.append(reshapedVolSeismic[:, :, idx])
faultSections.append(reshapedVolFault[:, :, idx])
for idx in range(reshapedVolSeismic.shape[1]):
seismicSections.append(reshapedVolSeismic[:, idx, :])
faultSections.append(reshapedVolFault[:, idx, :])
self.seismicSections = seismicSections
self.faultSections = faultSections
def __len__(self):
return len(self.seismicSections)
def __getitem__(self, index):
X = self.seismicSections[index]
Y = self.faultSections[index]
return X, Y
请帮忙!!!
【问题讨论】:
标签: python computer-vision pytorch data-generation