【发布时间】:2020-02-01 22:23:35
【问题描述】:
我正在使用一个数据集,我从每个 tfrecord 文件中解析四个张量。每隔一段时间,四个张量中的一个会是空的,我希望能够过滤掉这个张量并将其余的张量发送到 tf.data 管道的下一步。我将四个张量保存在字典中,我希望能够做这样的事情。
@tf.function
def filter_and_reshape(tensor_dict, shape):
return {k: tf.reshape(t, shape)
for k, t in tensor_dict.items() if not tf.equal(tf.size(t), 0)}
tensor_dict 是我刚刚从文件中解析出来的张量的字典,但还没有恢复到原来的形状。
不幸的是,这不起作用,因为 tf.equal(tf.size(t), 0) 返回张量而不是布尔值,而且签名似乎无法解决问题。
有没有其他方法可以做到这一点?
【问题讨论】:
-
tf.data.Dataset.filter 不能解决你的问题吗?
-
否,因为那是为了过滤掉整条记录。我想过滤掉记录中的张量。
标签: python tensorflow tensorflow2.0 tensorflow-datasets