【问题标题】:InvalidArgumentError with model.fit in TensorflowTensorflow 中带有 model.fit 的 InvalidArgumentError
【发布时间】:2021-04-07 14:08:40
【问题描述】:

使用 CNN 进行图像分类。当model.fit()被调用时,它开始训练模型一段时间,在执行过程中被中断并返回错误信息。

错误信息如下

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  Input size should match (header_size + row_size * abs_height) but they differ by 2
     [[{{node decode_image/DecodeImage}}]]
     [[IteratorGetNext]]
     [[IteratorGetNext/_4]]
  (1) Invalid argument:  Input size should match (header_size + row_size * abs_height) but they differ by 2
     [[{{node decode_image/DecodeImage}}]]
     [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_8873]

Function call stack:
train_function -> train_function

更新:我的建议是检查数据集的元数据。它帮助解决了我的问题。

【问题讨论】:

  • 代码存在一些问题,但我注意到的主要问题是您正在为训练数据集和测试数据集加载相同的目录。
  • @yudhiesh 你的意思是训练集和验证集?是的,它们是使用 image_dataset_from_directory() 和不同子集从同一目录加载的。测试集在另一个文件夹中分离。由于它与问题关系不大,所以我没有包括它。
  • 很抱歉,这实际上是正确的。我将添加一个包含更改的答案。
  • @yudhiesh 没关系。稍后我会尝试分享访问数据集的链接。
  • 你没有具体说明你是如何修复它的?你提到检查元数据,但要寻找什么?你发现了什么?你究竟做了什么来修复它?

标签: python tensorflow keras deep-learning google-colaboratory


【解决方案1】:

您不必指定参数 label_mode 。为了使用SparseCategoricalCrossentropy 作为损失函数,您需要将其设置为int。 如果您不指定它,则将其设置为Noneper the documentation

您还需要根据从中读取图像的目录结构将参数labels 指定为inferred

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  labels="inferred",
  label_mode="int",
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
  
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  labels="inferred",
  label_mode="int",
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

【讨论】:

  • 感谢您的提醒。我只是尝试运行它,但仍然返回相同的错误...
  • 您能打印出img_height, img_width, IMG_SHAPE 的值并将其添加到问题中吗?
  • 就在那儿。 img_heightimg_width 分别为 180。 IMG_SHAPE(180, 180, 3)
  • 我无法检查数据集中的图像,因为它太大了,但我猜测图像中的输入形状与您在创建模型时指定的大小不同。
  • 我想这可能不是我在导入数据集时指定输入形状inputs = tf.keras.Input(shape=(180, 180, 3)) 的原因,这与img_heightimg_width 相同。而且它也没有解释为什么当其中一个类被删除时它工作得很好......
猜你喜欢
  • 2023-03-22
  • 2017-12-02
  • 1970-01-01
  • 1970-01-01
  • 2021-05-19
  • 1970-01-01
  • 1970-01-01
  • 2017-08-02
  • 1970-01-01
相关资源
最近更新 更多