【问题标题】:How to plot class name from tensorflow dataset?如何从张量流数据集中绘制类名?
【发布时间】:2020-02-09 19:03:57
【问题描述】:

我从 TensorFlow 网站获取了一个函数来在我的笔记本中显示一批图像。我想以website 上显示的方式打印它,并带有上面图像的类。 函数代码如下:

def show_batch(image_batch, label_batch):
plt.figure(figsize=(10,10))
for n in range(25):
    ax = plt.subplot(5,5,n+1)
    plt.imshow(image_batch[n])
    plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
    plt.axis('off')

问题在于 plt.title... 行。我收到错误:无法将 1 转换为 dtype bool 的 EagerTensor

我不明白问题出在哪里,因为我完全按照网站教程中的处理方式处理了我的数据。

标签返回一个形状数组:[False False True False] 并且应该根据这个打印类名(我有 4 个类)。但事实并非如此。该函数的其余部分工作得很好,但只显示图像而不显示每个图像所属的类的名称是没有用的。

【问题讨论】:

    标签: python tensorflow tensorflow2.0 tensor tensorflow-datasets


    【解决方案1】:

    我没有找到一个很好的方法来做到这一点,所以我用一个额外的 for 循环来做到这一点。我浏览了标签批次并使用真实值保存了索引。

    def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10,10))
    for n in range(25):
        ax = plt.subplot(5,5,n+1)
        plt.imshow(np.squeeze(image_batch[n]), cmap = 'gray')
        ix = 0
        for a in label_batch[n]:
            if a == 1:
                break;
            else:
                ix+=1
        plt.title(CLASS_NAMES[ix].title())
        plt.axis('off')  
    

    只是为了让例子更清楚:

    • class_names如下[class1, class2, class3, class4]
    • label_batch 是另一个数组[false, false, true, false]
    • 在这种情况下,正确的索引是 2(计数从 0 开始),我需要的类是 class3

    【讨论】:

      猜你喜欢
      • 2021-12-10
      • 2017-11-02
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-09-29
      • 2018-08-04
      • 1970-01-01
      相关资源
      最近更新 更多