【问题标题】:How do I get the first 3 max numbers from tensor如何从张量中获取前 3 个最大数字
【发布时间】:2021-11-19 11:59:44
【问题描述】:

如何从y_classe = tf.argmax(preds, axis=1, output_type=tf.int32) 获得前 3 个最大数字?

【问题讨论】:

    标签: python tensorflow tensor


    【解决方案1】:

    你可以排序取前3个:

    import tensorflow as tf
    
    a = [1, 10, 26.9, 2.8, 166.32, 62.3]
    sorted_a = tf.sort(a,direction='DESCENDING')
    max_3 = tf.gather(sorted_a, [0,1,2])
    print(max_3)
    

    【讨论】:

    • 我在预测函数中使用这个函数。 preds = model([img_feat, ques_feat]) y_classe = tf.argmax(preds, axis=1, output_type=tf.int32) sorted_a = tf.sort(preds, direction='DESCENDING') max_3 = tf.gather(sorted_a, [0, 1, 2]) 但它显示错误:ensorflow.python.framework.errors_impl.InvalidArgumentError: indices[1] = 1 is not in [0, 1) [Op:GatherV2]
    【解决方案2】:

    你可以使用tf.math.top_k:

    import tensorflow as tf
    
    y_pred = [[-18.6, 0.51, 2.94, -12.8]]
    
    max_entries = 3
    
    values, indices = tf.math.top_k(y_pred, k=max_entries)
    print(values)
    print(indices)
    
    tf.Tensor([[  2.94   0.51 -12.8 ]], shape=(1, 3), dtype=float32)
    tf.Tensor([[2 1 3]], shape=(1, 3), dtype=int32)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2012-08-06
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多