【发布时间】:2020-02-08 11:02:15
【问题描述】:
我有一个 3 阶张量和另一个相同形状的空张量。我试图找到所有 X、Y 位置的第 3(Z)轴上的最大值,并将 1 插入空张量中的相应位置。我通过以下方式在 numpy 中实现了这一点
a = np.random.rand(5,5,3)>=0.5
empty_tensor = np.zeros((5,5,3))
max_z_indices = a.argmax(axis=-1)
empty_tensor[np.arange(a.shape[0])[:,None],np.arange(a.shape[1]),max_z_indices] = 1
在张量流中我有
a_tf = tf.Variable(a)
empty_tensor_tf = tf.Variable(np.zeros((5,5,3)))
max_z_indices = sess.run(tf.argmax(a_tf,axis=-1))
我知道我可以沿张量 a_tf 的第三维显式编写最大值的 X、Y、Z 索引,并使用 tf.scatter_nd_update 更新 empty_tensor_tf 但我希望找到更好的方法(广播)就像 numpy 代码的最后一行一样。
【问题讨论】:
标签: python python-2.7 tensorflow