【问题标题】:How does argmax work when given a 3d tensor - tensorflow给定 3d 张量时,argmax 如何工作 - tensorflow
【发布时间】:2021-01-15 17:41:14
【问题描述】:

我想知道在给定 3D 张量时 argmax 是如何工作的。我知道当它有 2D 时会发生什么,但 3D 让我很困惑。

例子:

import tensorflow as tf
import numpy as np

sess = tf.Session()

coordinates = np.random.randint(0, 100, size=(3, 3, 2))
coordinates
Out[20]: 
array([[[15, 23],
        [ 3,  1],
        [80, 56]],
       [[98, 95],
        [97, 82],
        [10, 37]],
       [[65, 32],
        [25, 39],
        [54, 68]]])
sess.run([tf.argmax(coordinates, axis=1)])
Out[21]: 
[array([[2, 2],
        [0, 0],
        [0, 2]], dtype=int64)]



【问题讨论】:

    标签: python numpy tensorflow argmax


    【解决方案1】:

    tf.argmax 根据指定的轴返回最大值的索引。指定轴被压碎,返回每个单位最大值的索引。返回的形状将具有相同的形状,除了将消失的指定轴。我会用tf.reduce_max 做例子,所以我们可以遵循这些值。

    让我们从你的数组开始:

    x = np.array([[[15, 23],
                   [3, 1],
                   [80, 56]],
                  [[98, 95],
                   [97, 82],
                   [10, 37]],
                  [[65, 32],
                   [25, 39],
                   [54, 68]]])
    

    tf.reduce_max(x, axis=0)

                ([[[15, 23],
                              [3, 1],
                                         [80, 56]],
                  [[98, 95],               ^
                     ^   ^    [97, 82],
                                ^  ^     [10, 37]],
                  [[65, 32],
                              [25, 39],
                                         [54, 68]]]) 
                                               ^    
    
    <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[98, 95],
           [97, 82],
           [80, 68]])>
    

    现在tf.reduce_max(x, 1)

                ([[[15, 23], [[98, 95],  [[65, 32],
                                ^   ^       ^
                   [3, 1],    [97, 82],   [25, 39],
               
                   [80, 56]], [10, 37]],  [54, 68]]])
                     ^   ^                      ^
    
    <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[80, 56],
           [98, 95],
           [65, 68]])>
    

    现在tf.reduce_max(x, axis=2)

                ([[[15, 23],
                         ^
                   [3, 1],
                    ^
                   [80, 56]],
                    ^   
                  [[98, 95],
                     ^
                   [97, 82],
                     ^
                   [10, 37]],
                         ^
                  [[65, 32],
                     ^
                   [25, 39],
                         ^
                   [54, 68]]])
                         ^
    
    <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[23,  3, 80],
           [98, 97, 37],
           [65, 39, 68]])>
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-11-07
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多