【发布时间】:2020-02-09 19:03:57
【问题描述】:
我从 TensorFlow 网站获取了一个函数来在我的笔记本中显示一批图像。我想以website 上显示的方式打印它,并带有上面图像的类。 函数代码如下:
def show_batch(image_batch, label_batch):
plt.figure(figsize=(10,10))
for n in range(25):
ax = plt.subplot(5,5,n+1)
plt.imshow(image_batch[n])
plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
plt.axis('off')
问题在于 plt.title... 行。我收到错误:无法将 1 转换为 dtype bool 的 EagerTensor
我不明白问题出在哪里,因为我完全按照网站教程中的处理方式处理了我的数据。
标签返回一个形状数组:[False False True False] 并且应该根据这个打印类名(我有 4 个类)。但事实并非如此。该函数的其余部分工作得很好,但只显示图像而不显示每个图像所属的类的名称是没有用的。
【问题讨论】:
标签: python tensorflow tensorflow2.0 tensor tensorflow-datasets