解决方案
我已经编写了一个如何执行计算的示例。此示例中的所有输入都被编码为 tf.constant ,但当然您可以替换您的变量。
主要技巧是矩阵乘法。首先是 input_y 重新装入为2d矩阵(称为 to_top5 )。第二个是 corrict_preds 加权_matrix 。
代码 h1>
import tensorflow as tf
input_y = tf.constant( [5,2,9,1] , dtype=tf.int32 )
predictions = tf.constant( [[9,3,5,2,1],[8,9,0,6,5],[1,9,3,4,5],[1,2,3,4,5]])
to_top5 = tf.constant( [[1,1,1,1,1]] , dtype=tf.int32 )
input_y_for_top5 = tf.matmul( tf.reshape(input_y,[-1,1]) , to_top5 )
correct_preds = tf.cast( tf.equal( input_y_for_top5 , predictions ) , dtype=tf.float16 )
weighted_matrix = tf.constant( [[5.],[4.],[3.],[2.],[1.]] , dtype=tf.float16 )
weighted = tf.matmul(correct_preds,weighted_matrix)
recall = tf.reduce_sum(correct_preds) / tf.cast( tf.reduce_sum(input_y) , tf.float16)
precision = tf.reduce_sum(correct_preds) / tf.constant(5.0,dtype=tf.float16)
## training
# Run tensorflow and print the result
with tf.Session() as sess:
print "\n\n=============\n\n"
print "\ninput_y_for_top5"
print sess.run(input_y_for_top5)
print "\ncorrect_preds"
print sess.run(correct_preds)
print "\nweighted"
print sess.run(weighted)
print "\nrecall"
print sess.run(recall)
print "\nprecision"
print sess.run(precision)
print "\n\n=============\n\n"
输出 h1>
=============
input_y_for_top5
[[5 5 5 5 5]
[2 2 2 2 2]
[9 9 9 9 9]
[1 1 1 1 1]]
correct_preds
[[ 0. 0. 1. 0. 0.]
[ 0. 0. 0. 0. 0.]
[ 0. 1. 0. 0. 0.]
[ 1. 0. 0. 0. 0.]]
weighted
[[ 3.]
[ 0.]
[ 4.]
[ 5.]]
recall
0.17651
precision
0.6001
=============
总结
上述实施例显示了批量强>尺寸为4。
第一个批处理具有 y_label 5 ,这意味着具有索引为5的元素是第一个批次的正确标签。此外,第一个批处理的预测是[9,3,5,2,1],这意味着预测功能认为第9个元素是最有可能的,然后元素3是下一个最有可能等等的。
假设我们想要一个批量大小为3的示例,然后使用以下代码
input_y = tf.constant( [5,2,9] , dtype=tf.int32 )
predictions = tf.constant( [[9,3,5,2,1],[8,9,0,6,5],[1,9,3,4,5]])
如果我们将上述行替换为程序,我们可以看到它确实可以正确计算一切批量3的一切。