【问题标题】:Batched Gather/GatherND批量收集/GatherND
【发布时间】:2018-07-03 04:58:40
【问题描述】:

我想知道是否有办法在 TensorFlow 中执行以下操作,使用 gather_nd 或类似的东西。

我有两个张量:

  • values 形状为[128, 100]
  • indices 形状为[128, 3]

indices 的每一行都包含沿values 的第二维的索引(对于同一行)。我想使用indices 索引values。例如,我想要这样做的东西(使用松散的符号来表示张量):

values  = [[0, 0, 0, 1, 1, 0, 1], 
           [1, 1, 0, 0, 1, 0, 0]]
indices = [[2, 3, 6], 
           [0, 2, 3]]
batched_gather(values, indices) = [[0, 1, 1], [1, 0, 0]]

此操作将遍历valuesindices 的每一行,并使用indices 行对values 行执行聚集。

在 TensorFlow 中是否有一种简单的方法可以做到这一点?

谢谢!

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    不确定这是否符合“简单”的条件,但您可以使用gather_nd

    def batched_gather(values, indices):
        row_indices = tf.range(0, tf.shape(values)[0])[:, tf.newaxis]
        row_indices = tf.tile(row_indices, [1, tf.shape(indices)[-1]])
        indices = tf.stack([row_indices, indices], axis=-1)
        return tf.gather_nd(values, indices)
    

    解释:想法是构造索引向量,例如[0, 1],意思是“第0行第1列的值”。
    列索引已在函数的 indices 参数中给出。
    行索引是从 0 到例如的简单级数。 128(在您的示例中),但根据每行的列索引数重复(平铺)(在您的示例中为 3;如果此数字是固定的,则可以对其进行硬编码,而不是使用 tf.shape)。
    然后堆叠行和列索引以生成索引向量。在您的示例中,生成的索引将是

    array([[[0, 2],
            [0, 3],
            [0, 6]],
    
           [[1, 0],
            [1, 2],
            [1, 3]]])
    

    gather_nd 会产生所需的结果。

    【讨论】:

    • 谢谢!这确实可以满足我的需要。我想知道是否还有一种更有效的方法可以避免创建中间范围和平铺张量。这有点像不必要的低效率。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2012-07-12
    • 2013-08-07
    • 1970-01-01
    • 2022-11-03
    相关资源
    最近更新 更多