【问题标题】:How to plot accuracy curve in Tensorflow如何在 Tensorflow 中绘制精度曲线
【发布时间】:2019-03-14 23:05:01
【问题描述】:

我正在关注this tutorial 为 MNIST 分类构建一个简单的网络。我想绘制相同的损失和准确度曲线。我看到了this SO post,得到了一个不错的损失曲线。但我不知道如何为准确性做同样的事情。我在optimise function中尝试了以下代码

session.run(optimizer, feed_dict=feed_dict_train)
if i % 100 == 0:
    loss=session.run(cost, feed_dict=feed_dict_train)
    acc = session.run(accuracy, feed_dict=feed_dict_train)
    loss_list.append(loss)
    acc_list.append(acc)
    plt.plot(loss_list,acc_list)

得到错误

InvalidArgumentError:您必须使用 dtype int64 和 shape [?] 为占位符张量“y_true_cls”提供一个值 [[节点y_true_cls(定义于:2)]]

由于某种原因,我无法在 optimise function 中运行 accuracy。如何获得准确度曲线?

【问题讨论】:

  • 您需要提供第二个值的错误状态。考虑为您的问题添加完整代码
  • @Sharky 代码与 GitHub 链接相同,我只是在优化函数中添加了额外的代码

标签: python tensorflow matplotlib


【解决方案1】:

y_true_cls 说明您需要提供真正的类标签。来自您提到的博客:-

feed_dict_test = {x: data.x_test,
                  y_true: data.y_test,
                  y_true_cls: data.y_test_cls}
def print_accuracy():
    # Use TensorFlow to compute the accuracy.
    acc = session.run(accuracy, feed_dict=feed_dict_test)

    # Print the accuracy.
    print("Accuracy on test-set: {0:.1%}".format(acc))

如您所见,准确度由 feed_dict_test 提供,其中 y_true_cls 为“True_Labels”。 只有在预测和真实类标签的帮助下,我们才能找到准确度。

【讨论】:

  • 哦,是的。这就是错误的原因。但是那我如何得到准确度曲线呢?
猜你喜欢
  • 2016-11-07
  • 2021-08-17
  • 1970-01-01
  • 2016-08-24
  • 1970-01-01
  • 2022-06-30
  • 2017-02-11
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多