【问题标题】:Tensorflow: Filter a tensor out of a list of tensorsTensorflow:从张量列表中过滤出一个张量
【发布时间】:2020-02-01 22:23:35
【问题描述】:

我正在使用一个数据集,我从每个 tfrecord 文件中解析四个张量。每隔一段时间,四个张量中的一个会是空的,我希望能够过滤掉这个张量并将其余的张量发送到 tf.data 管道的下一步。我将四个张量保存在字典中,我希望能够做这样的事情。

@tf.function
def filter_and_reshape(tensor_dict, shape):
    return {k: tf.reshape(t, shape)
            for k, t in tensor_dict.items() if not tf.equal(tf.size(t), 0)}

tensor_dict 是我刚刚从文件中解析出来的张量的字典,但还没有恢复到原来的形状。

不幸的是,这不起作用,因为 tf.equal(tf.size(t), 0) 返回张量而不是布尔值,而且签名似乎无法解决问题。

有没有其他方法可以做到这一点?

【问题讨论】:

  • tf.data.Dataset.filter 不能解决你的问题吗?
  • 否,因为那是为了过滤掉整条记录。我想过滤掉记录中的张量。

标签: python tensorflow tensorflow2.0 tensorflow-datasets


【解决方案1】:

tf.data.Dataset 生成的所有元素必须具有相同的结构。例如,如果数据集生成带有键“a”、“b”、“c”、“d”的字典,则数据集生成的所有元素将始终具有这四个键。你不能产生一些带有 4 个键的元素和其他带有 3 个键的元素。

我建议更改模型代码以忽略可能为空的张量(如果它为空)。

要将 4 个可能为空的张量相加,您可以这样做

import tensorflow as tf
tf.enable_v2_behavior()

a = tf.constant([1, 1, 1])
b = tf.constant([2, 2, 2])
c = tf.constant([], dtype=tf.int32)
d = tf.constant([7, 7, 7])

def add_py_fn(a, b, c, d):
  s = tf.zeros((3,), dtype=tf.int32)
  for t in [a, b, c, d]:
    if tf.size(t) > 0:
      s += t
  return s

def add_fn(a, b, c, d):
  return tf.py_function(add_py_fn, [a, b, c, d], tf.int32)

ds = tf.data.Dataset.from_tensors((a, b, c, d))
ds = ds.map(add_fn)
print(next(iter(ds)))

需要 py_function 转义舱口,因为运行时无法判断 c 永远不会添加到 s,因此它会抛出形状错误。

【讨论】:

  • 您说得对,数据集的输出必须始终相同,但您可以在数据集管道本身中更改数据的形状。例如,您可以获取 4 个元素,然后将它们添加在一起并将其传递给模型。我要求的是当其中一个为空时仅添加 3。
  • 我已经更新了答案,以展示一种仅在元素非空时才添加元素的方法。
  • 谢谢!您实际上不需要 pyfunc 来添加。您可以使用 tf.cond 检查是否为空,然后在这种情况下返回一个零张量并添加它。我更感兴趣的是能够过滤掉其中一个张量,但也许这是不可能的。我认为签名可能会有所帮助,但也许没有。
猜你喜欢
  • 2017-12-23
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2016-06-20
  • 2020-12-04
相关资源
最近更新 更多