【发布时间】:2021-07-30 16:21:46
【问题描述】:
我正在尝试加载数据以优化对象检测模型 + 实例分割。然而,使用 tf.data.Dataset 让我在加载实例分割掩码时有点头疼。 tf.data.Dataset 正在使用服务器上的所有内存(超过 128 GB)和一个小数据集。
有没有办法以更节省内存的方式有效地加载数据,现在我们正在使用这段代码:
train_dataset, train_examples = dataset.load_train_datasets()
ds = (
train_dataset.shuffle(min(100, train_examples), reshuffle_each_iteration=True)
.map(dataset.decode, num_parallel_calls=args.num_parallel_calls)
.map(train_processing.prepare_for_batch, num_parallel_calls=args.num_parallel_calls)
.batch(args.batch_size)
.map(train_processing.preprocess_batch, num_parallel_calls=args.num_parallel_calls)
.prefetch(AUTOTUNE)
)
问题在于,第二次调用 train_processing.prepare_for_batch(采用单个元素)和第三次调用 train_processing.preprocess_batch(采用一批元素)正在创建大量用于分割的二进制掩码,这些掩码占用了所有内存。
有没有办法重新组织映射函数以节省内存?我在想这样的事情:1. 取前 100 个样本,2. 解码样本,3. 为一个样本准备掩码和边界框 4. 取其中一批 5. 每批数据的最终准备 6. FIT ONE step/一批数据 7. 清理内存中的数据
【问题讨论】:
-
dataset.decode、train_processing.prepare_for_batch、train_processing.preprocess_batch中的操作是什么?您有 3 个地图操作将针对每个输入行运行。这些可能很昂贵。
-
你有没有考虑过ImageDataGenerator?我所做的通常是先将所有图像保存到我的硬盘上,然后通过 ImageDataGenerators 进行所有预处理并将批次提供给我的模型。
-
所有地图末尾的 ds.element_spec 是什么?我很好奇你的批量大小、图像大小和掩码数量最终是多少。
标签: python tensorflow tensorflow2.0 tensorflow-datasets