【发布时间】:2019-07-12 05:13:00
【问题描述】:
我有一个不适合内存 (150G) 的庞大数据集,我正在寻找在 pytorch 中使用它的最佳方法。数据集由几个.npz 文件组成,每个文件有 10k 个样本。我试图建立一个Dataset 类
class MyDataset(Dataset):
def __init__(self, path):
self.path = path
self.files = os.listdir(self.path)
self.file_length = {}
for f in self.files:
# Load file in as a nmap
d = np.load(os.path.join(self.path, f), mmap_mode='r')
self.file_length[f] = len(d['y'])
def __len__(self):
raise NotImplementedException()
def __getitem__(self, idx):
# Find the file where idx belongs to
count = 0
f_key = ''
local_idx = 0
for k in self.file_length:
if count < idx < count + self.file_length[k]:
f_key = k
local_idx = idx - count
break
else:
count += self.file_length[k]
# Open file as numpy.memmap
d = np.load(os.path.join(self.path, f_key), mmap_mode='r')
# Actually fetch the data
X = np.expand_dims(d['X'][local_idx], axis=1)
y = np.expand_dims((d['y'][local_idx] == 2).astype(np.float32), axis=1)
return X, y
但实际提取样本时,需要 30 多秒。看起来整个 .npz 已打开,存储在 RAM 中并访问了正确的索引。
如何提高效率?
编辑
这似乎是对.npz文件see post的误解,但是有没有更好的方法?
解决方案建议
正如@covariantmonkey 所建议的,lmdb 可能是一个不错的选择。目前,由于问题来自.npz 文件而不是memmap,我通过将.npz 包文件拆分为几个.npy 文件来重构我的数据集。我现在可以使用与memmap 相同的逻辑,并且速度非常快(加载样本需要几毫秒)。
【问题讨论】:
标签: python numpy dataset pytorch large-data