【发布时间】:2020-11-10 02:34:13
【问题描述】:
我最近开始使用 tensorflow 研究 CNN,发现 tfrecords 对加快训练非常有帮助,但是我在数据 API 方面遇到了困难。
解析后,我的数据集由(图像,标签)元组组成,这对于训练来说很好,但是我试图在另一个数据集中提取图像以调用 keras.predict() 。
我试过这个解决方案:
test_set = get_set_tfrecord(test_path, _parse_function, num_parallel_calls = 4)
lab = []
f = True
for image, label in test_set.take(600):
if f:
img = tf.data.Dataset.from_tensors(image)
f = False
else:
img = img.concatenate(tf.data.Dataset.from_tensors(image))
lab.append(label.numpy())
天真,不是很好的代码,但它的工作原理是为了执行连接(即堆叠),它将每个图像加载到 RAM 中。
这样做的正确方法是什么?
【问题讨论】:
-
您必须使用 tfRecord 吗?我认为首先使用 tf.Dataset 是有意义的,并且只有当您仍然受 I/O 限制时才继续使用 tfRecord。无论如何,可能会有像 tf.Dataset 这样的批处理功能?
标签: python tensorflow machine-learning keras computer-vision