【发布时间】:2021-07-25 09:38:41
【问题描述】:
请解释tf.tensor_scatter_nd_update的索引深度是什么。
tf.tensor_scatter_nd_update(
tensor, indices, updates, name=None
)
为什么一维张量的索引是二维的?
indexes 至少有两个轴,最后一个轴是索引向量的深度。 对于更高等级的输入张量标量更新,可以使用匹配 tf.rank(tensor) 的 index_depth 插入:
tensor = [0, 0, 0, 0, 0, 0, 0, 0] # tf.rank(tensor) == 1
indices = [[1], [3], [4], [7]] # num_updates == 4, index_depth == 1 # <--- what is depth and why 2D for 1D tensor?
updates = [9, 10, 11, 12] # num_updates == 4
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tensor = [[1, 1], [1, 1], [1, 1]] # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]] # num_updates == 2, index_depth == 2
updates = [5, 10] # num_updates == 2
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
【问题讨论】:
标签: tensorflow2.0