【发布时间】:2020-02-18 22:27:35
【问题描述】:
是否可以为 TensorFlows 的 py_function 指定嵌套输出类型?
作为特定情况,我希望py_function 的返回类型为((tf.float32, tf.float32), (tf.float32, tf.float32)),其中各个元素不一定具有相同的尺寸。有没有办法为py_function 指定这个?
正如对为什么这在我的案例中有用的一些见解一样,我有一个带有文件路径列表的tf.data.Dataset。 py_function 采用这些文件路径之一,并从文件中生成一个负样本和正样本以及相应的标签,从而产生 ((positive_data, positive_label), (negative_data, negative_label))(注意,标签不一定是单个值,但它们的形状也不相同作为输入数据)。这个py_function 可以映射到数据集,并且(使用上述结构)有一层扁平化以生成具有(data, label) 结构化元素的训练数据集。虽然可以有一种解决方法,将数据和标签堆叠在 py_function 中,然后再取消堆叠(或者从 py_function 开始完全非结构化,然后才配对),但这会导致设置混乱和混乱。如果py_function 可以直接输出((tf.float32, tf.float32), (tf.float32, tf.float32)) 类型,这将导致更干净的设置。
【问题讨论】:
标签: python tensorflow tensorflow-datasets