【问题标题】:Basic classification in tensorflow tutorialtensorflow教程中的基本分类
【发布时间】:2019-08-02 04:33:34
【问题描述】:

我刚开始学习 tensorflow,正在他们的官方页面上编写基本分类教程。

Basic Classification Tutorial

从下面的一段代码

def plot_image(i, predictions_array, true_label, img):
predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
plt.grid(False)
plt.xticks([])
plt.yticks([])

plt.imshow(img, cmap=plt.cm.binary)

predicted_label = np.argmax(predictions_array)
if predicted_label == true_label:
color = 'blue'
else:
color = 'red'

plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                            100*np.max(predictions_array),
                            class_names[true_label]),
                            color=color)

def plot_value_array(i, predictions_array, true_label):
predictions_array, true_label = predictions_array[i], true_label[i]
plt.grid(False)
plt.xticks([])
plt.yticks([])
thisplot = plt.bar(range(10), predictions_array, color="#777777")
plt.ylim([0, 1])
predicted_label = np.argmax(predictions_array)

thisplot[predicted_label].set_color('red')
thisplot[true_label].set_color('blue')

以下是测试数据的样本结果。

案例 1:

案例 2

虽然系统100%预测,但为什么结果显示为红色?

在预测的标签中,它没有显示任何其他类。

【问题讨论】:

    标签: python-3.x tensorflow


    【解决方案1】:

    代码中存在错误。

    • "i" 将用于 predictions_array,取第一个数组元素(它是一个数组)。这部分没问题,但它始终是预测数组所需的索引 0。解决此问题的两种方法:在调用时将其传递为“predictions_array [0]”,就像我在下面所做的那样。或者修改函数以包含“predictions_array = predictions_array[0]”

    • 由于预测数组的“i”必须为 0,因此在原始代码中将始终检查 test_labels[0]。当您预测 9 以外的值时,这将在所有情况下给出红色(因为它认为这是一个错误的预测)。因此,通过 i 作为测试图像的索引将为您提供正确的标签。

    修改功能的建议:

    def plot_value_array(i, predictions_array, true_label):
      print(true_label)
      true_label = true_label[i]
      plt.grid(False)
      plt.xticks([])
      plt.yticks([])
      thisplot = plt.bar(range(10), predictions_array, color="#777777")
      plt.ylim([0, 1])
      predicted_label = np.argmax(predictions_array)
    
      thisplot[predicted_label].set_color('red')
      thisplot[true_label].set_color('blue')
    

    并修改了调用,其中“1”是我在这种情况下测试的图像(将其设为变量,这样您在测试时不必输入两次)。 换句话说:如果 img = test_images[1] 我必须将 1 传递给函数。

    plot_value_array(1, predictions_single[0], test_labels)
    plt.xticks(range(10), class_names, rotation=45)
    plt.show()
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-10-25
      • 1970-01-01
      • 2021-06-08
      • 2022-01-02
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多