【发布时间】:2020-04-06 23:54:21
【问题描述】:
假设我有一个包含两个张量的批次,并且补丁中的张量大小为 3。
data = [[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]]
现在我想从补丁中的每个张量中提取一个基于索引的单个元素:
index = [0, 2]
因此输出应该是
out = [0.3, -0.7] # Get index 0 from the first tensor in the batch and index 2 from the second tensor in the batch.
当然,这应该可以扩展到大批量。 index 的维度等于批量大小。
我尝试申请tf.gather 和tf.gather_nd,但没有得到我想要的结果。
例如下面的代码打印0.7 而不是上面指定的期望结果:
data = [[0.3, 0.5, 0.7], [-0.3, -0.5, 0.7]]
index = [0, 2]
out = tf.gather_nd(data, index)
print(out.numpy())
【问题讨论】:
标签: numpy tensorflow