【发布时间】:2020-09-18 20:53:08
【问题描述】:
我正在 python 中使用 Tensorflow 进行对象检测。
我想使用 tensorflow 输入管道来加载批量输入数据。 问题是图像中的对象数量是可变的。
想象一下我想做以下事情。 annotations 是一个包含图像文件名和边界框的数组。标签被排除在外。每个边界框由四个数字表示。
import tensorflow as tf
@tf.function()
def prepare_sample(annotation):
annotation_parts = tf.strings.split(annotation, sep=' ')
image_file_name = annotation_parts[0]
image_file_path = tf.strings.join(["/images/", image_file_name])
depth_image = tf.io.read_file(image_file_path)
bboxes = tf.reshape(annotation_parts[1:], shape=[-1,4])
return depth_image, bboxes
annotations = ['image1.png 1 2 3 4', 'image2.png 1 2 3 4 5 6 7 8']
dataset = tf.data.Dataset.from_tensor_slices(annotations)
dataset = dataset.shuffle(len(annotations))
dataset = dataset.map(prepare_sample)
dataset = dataset.batch(16)
for image, bboxes in dataset:
pass
在上面的示例中,image1 包含一个对象,而 image2 包含两个对象。 我收到以下错误:
InvalidArgumentError:无法将张量添加到批次:数量 元素不匹配。形状是:[张量]:[1,4],[批次]:[2,4]
这是有道理的。 我正在寻找从映射函数返回不同长度数组的方法。我能做什么?
谢谢!
编辑:
我想我找到了解决方案;我不再收到错误消息。
我将dataset.batch(16) 更改为dataset.padded_batch(16)。
【问题讨论】:
标签: python tensorflow