【问题标题】:Returning arrays of different lengths in a Tensorflow dataset pipeline在 Tensorflow 数据集管道中返回不同长度的数组
【发布时间】:2020-09-18 20:53:08
【问题描述】:

我正在 python 中使用 Tensorflow 进行对象检测。

我想使用 tensorflow 输入管道来加载批量输入数据。 问题是图像中的对象数量是可变的。

想象一下我想做以下事情。 annotations 是一个包含图像文件名和边界框的数组。标签被排除在外。每个边界框由四个数字表示。

import tensorflow as tf

@tf.function()
def prepare_sample(annotation):
    annotation_parts = tf.strings.split(annotation, sep=' ')
    image_file_name = annotation_parts[0]
    image_file_path = tf.strings.join(["/images/", image_file_name])
    depth_image = tf.io.read_file(image_file_path)
    bboxes = tf.reshape(annotation_parts[1:], shape=[-1,4])
    return depth_image, bboxes

annotations = ['image1.png 1 2 3 4', 'image2.png 1 2 3 4 5 6 7 8']
dataset = tf.data.Dataset.from_tensor_slices(annotations)
dataset = dataset.shuffle(len(annotations))
dataset = dataset.map(prepare_sample)
dataset = dataset.batch(16)

for image, bboxes in dataset:
  pass

在上面的示例中,image1 包含一个对象,而 image2 包含两个对象。 我收到以下错误:

InvalidArgumentError:无法将张量添加到批次:数量 元素不匹配。形状是:[张量]:[1,4],[批次]:[2,4]

这是有道理的。 我正在寻找从映射函数返回不同长度数组的方法。我能做什么?

谢谢!

编辑: 我想我找到了解决方案;我不再收到错误消息。 我将dataset.batch(16) 更改为dataset.padded_batch(16)

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    为了社区的利益在这里回答解决方案,因为这篇文章的作者@Ladislav Ondris 能够解决这个问题。

    dataset.batch(16)更改为dataset.padded_batch(16)后,错误将得到解决。

    下面是相同的修改代码。

    import tensorflow as tf
    
    @tf.function()
    def prepare_sample(annotation):
        annotation_parts = tf.strings.split(annotation, sep=' ')
        image_file_name = annotation_parts[0]
        image_file_path = tf.strings.join(["/images/", image_file_name])
        depth_image = tf.io.read_file(image_file_path)
        bboxes = tf.reshape(annotation_parts[1:], shape=[-1,4])
        return depth_image, bboxes
    
    annotations = ['image1.png 1 2 3 4', 'image2.png 1 2 3 4 5 6 7 8']
    dataset = tf.data.Dataset.from_tensor_slices(annotations)
    dataset = dataset.shuffle(len(annotations))
    dataset = dataset.map(prepare_sample)
    dataset = dataset.padded_batch(16)
    
    for image, bboxes in dataset:
      pass
    

    【讨论】:

      猜你喜欢
      • 2018-05-23
      • 2022-01-15
      • 1970-01-01
      • 2021-02-25
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-01-28
      • 2015-11-28
      相关资源
      最近更新 更多