【问题标题】:Calculate accuracy over the entire training set计算整个训练集的准确率
【发布时间】:2017-10-09 21:03:02
【问题描述】:

我正在使用张量流对卷积神经网络进行第一次测试。我使用编程指南中的队列运行器调整了推荐的方法(请参阅下面的会话定义)。输出是来自 cnn 的最后一个结果(这里只给出最后一步)。 label_batch_vector 是训练标签批次。

output = tf.matmul(h_pool2_flat, W_fc1) + b_fc1
label_batch_vector = tf.one_hot(label_batch, 33)

correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(label_batch_vector, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

print_accuracy = tf.Print(accuracy, [accuracy])

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables (like the epoch counter).
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps or whatever
        sess.run(train_step)
        sess.run(print_accuracy)

except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()

我的问题是为每个批次计算准确性,我希望为每个时期计算它。我需要执行以下操作:初始化一个 epoch_accuracy 张量,对于该 epoch 中计算的每个批次精度,将其添加到 epoch_accuracy。在 epoch 结束时显示计算的训练集精度。但是,我在我实现的这个队列线程中没有找到任何这样的例子(这实际上是 TensorFlow 推荐的方法)。谁能帮忙?

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    要计算数据流的准确性(您的批次序列,请点击此处),您可以使用 tensorflow 中的 tf.metrics.accuracy 函数。请参阅其文档here

    你这样定义操作

    _, accuracy = tf.metrics.accuracy(y_true, y_pred)
    

    那么你可以这样更新准确率:

    sess.run(accuracy)
    

    PS:tf.metrics 中的所有函数(auc、recall 等)都支持流式传输

    【讨论】:

    • 我还没有测试过,但我感觉它会为所有时期计算一个单一的值。我需要每个时代的价值。
    • 我现在已经测试并给出了我产生输入的方式,它不会那么容易计算每个时期。
    • 我有一个问题:在 sess.run(train_step) 之后的每批之后,网络模型中的权重都会不断更新。那么计算为准确度的不是整个数据集上最新权重的准确度,而是作为计算准确度平均值的准确度,直到当前时间发生变化的权重?我理解正确吗?
    • 我设法实施了最终解决方案。结论是,使用流式传输无法测试准确性并优化同一会话中的权重。我在训练后保存了模型,并从不同的脚本开始了队列,加载了模型,然后只测试了准确性。我将答案标记为正确,因为按照建议,我使用 tf.metrics.accuracy 来计算准确性。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2018-09-27
    • 2017-10-05
    • 2023-03-17
    • 2018-07-20
    • 2021-10-13
    • 2018-12-26
    • 2020-04-07
    相关资源
    最近更新 更多