【问题标题】:locations of maximum values of tensorflow tensor across an axis张量流张量的最大值在轴上的位置
【发布时间】:2020-02-08 11:02:15
【问题描述】:

我有一个 3 阶张量和另一个相同形状的空张量。我试图找到所有 X、Y 位置的第 3(Z)轴上的最大值,并将 1 插入空张量中的相应位置。我通过以下方式在 numpy 中实现了这一点

a = np.random.rand(5,5,3)>=0.5
empty_tensor = np.zeros((5,5,3))
max_z_indices = a.argmax(axis=-1)
empty_tensor[np.arange(a.shape[0])[:,None],np.arange(a.shape[1]),max_z_indices] = 1

在张量流中我有

a_tf = tf.Variable(a)
empty_tensor_tf = tf.Variable(np.zeros((5,5,3)))
max_z_indices = sess.run(tf.argmax(a_tf,axis=-1))

我知道我可以沿张量 a_tf 的第三维显式编写最大值的 X、Y、Z 索引,并使用 tf.scatter_nd_update 更新 empty_tensor_tf 但我希望找到更好的方法(广播)就像 numpy 代码的最后一行一样。

【问题讨论】:

    标签: python python-2.7 tensorflow


    【解决方案1】:

    您可以使用tf.reduce_max 获取每个z-index 的最大值,然后根据cond 使用tf.where 将其转换为1 或0。

    import tensorflow as tf
    
    # tf_a is (5,5,3) tensor
    max_val = tf.reduce_max(tf_a, axis=-1,keepdims=True)
    cond = tf.equal(tf_a, max_val)
    res = tf.where(cond, tf.ones_like(tf_a), tf.zeros_like(tf_a))
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2016-05-01
      • 2023-01-11
      • 2018-02-24
      • 1970-01-01
      • 1970-01-01
      • 2017-07-31
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多