【问题标题】:Tensorflow 2 - what is 'index depth' in tensor_scatter_nd_update?Tensorflow 2 - tensor_scatter_nd_update 中的“索引深度”是什么?
【发布时间】:2021-07-25 09:38:41
【问题描述】:

请解释tf.tensor_scatter_nd_update的索引深度是什么。

tf.tensor_scatter_nd_update(
    tensor, indices, updates, name=None
)

为什么一维张量的索引是二维的?

indexes 至少有两个轴,最后一个轴是索引向量的深度。 对于更高等级的输入张量标量更新,可以使用匹配 tf.rank(tensor) 的 index_depth 插入:

tensor = [0, 0, 0, 0, 0, 0, 0, 0]    # tf.rank(tensor) == 1
indices = [[1], [3], [4], [7]]       # num_updates == 4, index_depth == 1   # <--- what is depth and why 2D for 1D tensor?
updates = [9, 10, 11, 12]            # num_updates == 4
print(tf.tensor_scatter_nd_update(tensor, indices, updates))

tensor = [[1, 1], [1, 1], [1, 1]]    # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]]           # num_updates == 2, index_depth == 2
updates = [5, 10]                    # num_updates == 2
print(tf.tensor_scatter_nd_update(tensor, indices, updates))

【问题讨论】:

    标签: tensorflow2.0


    【解决方案1】:

    对于indicesindex depth索引向量的大小或长度。例如:

    indicesA = [[1], [3], [4], [7]] # index vector with 1 element: index_depth = 1
    indicesB = [[0, 1], [2, 0]]     # index vector with 2 element: index_depth = 2
    

    索引的原因是2D是为了保存两个信息,一个是更新的长度(num_updates索引向量的长度 .需要满足两点:

    • indicesindex depth 必须等于 input 张量的 rank
    • updates 的长度必须等于indices长度

    所以,在示例代码中

    # tf.rank(tensor) == 1
    tensor = [0, 0, 0, 0, 0, 0, 0, 0]    
    
    # num_updates == 4, index_depth == 1 | tf.rank(indices).numpy() == 2 
    indices = [[1], [3], [4], [7]]    
    
    # num_updates == 4 | tf.rank(output).numpy() == 1  
    updates = [9, 10, 11, 12]        
    
    output = tf.tensor_scatter_nd_update(tensor, indices, updates)
    tf.Tensor([ 0  9  0 10 11  0  0 12], shape=(8,), dtype=int32)
    

    还有

    # tf.rank(tensor) == 2
    tensor = [[1, 1], [1, 1], [1, 1]]    
    
     # num_updates == 2, index_depth == 2 | tf.rank(indices).numpy() == 2
    indices = [[0, 1], [2, 0]]          
    
    # num_updates == 2 | tf.rank(output).numpy() == 2
    updates = [5, 10]       
                 
    output = tf.tensor_scatter_nd_update(tensor, indices, updates)
    tf.Tensor(
    [[ 1  5]
     [ 1  1]
     [10  1]], shape=(3, 2), dtype=int32)
    
    num_updates, index_depth = tf.convert_to_tensor(indices).shape.as_list()
    [num_updates, index_depth]
    [2, 2]
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-06-15
      • 1970-01-01
      • 2011-12-06
      • 1970-01-01
      • 2022-08-23
      • 1970-01-01
      • 2022-12-25
      • 1970-01-01
      相关资源
      最近更新 更多