【问题标题】:Determining if A Value is in a Set in TensorFlow确定一个值是否在 TensorFlow 的集合中
【发布时间】:2016-01-05 17:54:49
【问题描述】:

tf.logical_ortf.logical_andtf.select 函数非常有用。

但是,假设您有值 x,并且您想查看它是否在 set(a, b, c, d, e) 中。在 python 中,您只需编写:

if x in set([a, b, c, d, e]):
  # Do some action.

据我所知,在 TensorFlow 中执行此操作的唯一方法是将“tf.logical_or”与“tf.equal”一起嵌套。我在下面只提供了这个概念的一个迭代:

tf.logical_or(
    tf.logical_or(tf.equal(x, a), tf.equal(x, b)),
    tf.logical_or(tf.equal(x, c), tf.equal(x, d))
)

我觉得在 TensorFlow 中一定有更简单的方法可以做到这一点。有吗?

【问题讨论】:

    标签: python set tensorflow


    【解决方案1】:

    要提供更具体的答案,假设您要检查张量 x 的最后一个维度是否包含来自一维张量 s 的任何值,您可以执行以下操作:

    tile_multiples = tf.concat([tf.ones(tf.shape(tf.shape(x)), dtype=tf.int32), tf.shape(s)], axis=0)
    x_tile = tf.tile(tf.expand_dims(x, -1), tile_multiples)
    x_in_s = tf.reduce_any(tf.equal(x_tile, s), -1))
    

    例如,对于sx

    s = tf.constant([3, 4])
    x = tf.constant([[[1, 2, 3, 0, 0], 
                      [4, 4, 4, 0, 0]], 
                     [[3, 5, 5, 6, 4], 
                      [4, 7, 3, 8, 9]]])
    

    x 的形状为[2, 2, 5]s 的形状为[2] 所以tile_multiples = [1, 1, 1, 2],这意味着我们会将x 的最后一个维度平铺2 次(s 中的每个元素一次)新维度。所以,x_tile 看起来像:

    [[[[1 1]
       [2 2]
       [3 3]
       [0 0]
       [0 0]]
    
      [[4 4]
       [4 4]
       [4 4]
       [0 0]
       [0 0]]]
    
     [[[3 3]
       [5 5]
       [5 5]
       [6 6]
       [4 4]]
    
      [[4 4]
       [7 7]
       [3 3]
       [8 8]
       [9 9]]]]
    

    x_in_s 会将每个平铺值与s 中的一个值进行比较。如果任何平铺值在s 中,最后一个暗淡的tf.reduce_any 将返回true,给出最终结果:

    [[[False False  True False False]
      [ True  True  True False False]]
    
     [[ True False False False  True]
      [ True False  True False False]]]
    

    【讨论】:

      【解决方案2】:

      看看这个相关的问题:Count number of "True" values in boolean Tensor

      您应该能够构建一个由 [a, b, c, d, e] 组成的张量,然后使用 tf.equal(.) 检查是否有任何行等于 x

      【讨论】:

      • 感谢您的洞察力。 Reduce_sum 是最好的方法。
      • 你也可以使用tf.listdiff来完成同样的事情。
      • @dga 只显示不同之处,而不是相似之处?
      • 对于 Tensorflow 的新手来说,很难看到您的链接帖子如何解决 OP 的情况。也许您可以在这里发布完整的代码 sn-p?
      【解决方案3】:

      这里有两个解决方案,我们要检查query是否在whitelist

      whitelist = tf.constant(["CUISINE", "DISH", "RESTAURANT", "ADDRESS"])
      query = "RESTAURANT"
      
      #use broadcasting for element-wise tensor operation
      broadcast_equal = tf.equal(whitelist, query)
      
      #method 1: using tensor ops
      broadcast_equal_int = tf.cast(broadcast_equal, tf.int8)
      broadcast_sum = tf.reduce_sum(broadcast_equal_int)
      
      #method 2: using some tf.core API
      nz_cnt = tf.count_nonzero(broadcast_equal)
      
      sess.run([broadcast_equal, broadcast_sum, nz_cnt])
      #=> [array([False, False,  True, False]), 1, 1]
      

      因此,如果输出为 > 0,则该项目在集合中。

      【讨论】:

      • querywhitelist 有多个元素时,它是如何工作的?
      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2011-05-23
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2023-02-18
      • 1970-01-01
      相关资源
      最近更新 更多