【发布时间】:2022-01-09 22:52:32
【问题描述】:
我想创建一个用于训练 TensorFlow 模型的数据管道。数据存储在非常大的 HDF5 文件中 (250+ GB)。
我编写了一个适用于较小输入文件的管道,但在消耗过多 RAM+swap 后最终被内核杀死(通过监控验证了这一点)。
import tensorflow as tf
import h5py
class TestGenerator:
"""
Implements a generator that can be used by tf.data.Dataset.from_generator
to produce a dataset for any test data.
"""
def __init__(self, src, dset):
self.src = src
self.dset = dset
self.output_signature = (
tf.TensorSpec(shape=(2,), dtype=tf.uint64)
)
def __call__(self):
"""This is needed for tf.data.Dataset.from_generator to work."""
with h5py.File(self.src, 'r', swmr=True) as f:
for sample in f[self.dset]:
yield sample[0], sample[1]
gen = TestGenerator('h5file.h5', 'dset_path')
dataset = tf.data.Dataset.from_generator(
gen,
output_signature=gen.output_signature
)
for sample in dataset:
pass
一开始我以为这可能是h5py模块的问题,所以单独测试了一下:
with h5py.File('h5file.h5', 'r', swmr=True) as f:
for sample in f['dset_path']:
pass
这没有问题。由此得出结论,TensorFlow 是造成内存问题的原因。让我恼火的是,我假设 TensorFlow 可以即时获取所需的数据,因此可以避免内存问题。
代码经过测试,适用于较小的文件。我还测试了在迭代之前使用dataset.prefetch 的版本,但结果相同。
TensorFlow 是否会在后台加载整个数据集?
【问题讨论】:
标签: python python-3.x tensorflow dataset generator