【问题标题】:Keras CNN ClassifierKeras CNN 分类器
【发布时间】:2020-03-06 17:07:10
【问题描述】:

如果您愿意帮助我,我确实对 Keras 的 CNN 有疑问,我将非常感激。

免责声明:我是 CNN 和 Keras 的菜鸟,我现在只是在学习它们。


我的数据:

2 类(狗和猫)

训练:每个类别 30 张图片

测试:每个类别 14 张图片

有效:每个类别 30 张图片


我的代码:

data_path = Path("../data")

train_path = data_path / "train"
test_path = data_path / "test"
valid_path = data_path / "valid"

train_batch = ImageDataGenerator().flow_from_directory(directory=train_path,
                                                       target_size=(200, 200),
                                                       classes=animals,
                                                       batch_size=10)

valid_batch = ImageDataGenerator().flow_from_directory(directory=valid_path,
                                                       target_size=(200, 200),
                                                       classes=animals,
                                                       batch_size=10)

test_path = ImageDataGenerator().flow_from_directory(directory=test_path,
                                                     target_size=(200, 200),
                                                     classes=animals,
                                                     batch_size=4)

imgs, labels = next(train_batch)

model = Sequential(
    [Conv2D(32, (3, 3), activation="relu", input_shape=(200, 200, 3)), Flatten(),
     Dense(len(animals), activation='softmax')])

model.compile(Adam(lr=.0001), loss='categorical_crossentropy', metrics=['accuracy'])

model.fit_generator(train_path, steps_per_epoch=4, validation_data=valid_batch, validation_steps=3, epochs=5, verbose=2)

这是我的错误信息:

我已将路径替换为“”

Traceback (most recent call last):
  File "", line 191, in <module>
    model.fit_generator(train_path, steps_per_epoch=4, validation_data=valid_batch, validation_steps=3, epochs=5, verbose=2)
  File "y", line 91, in wrapper
    return func(*args, **kwargs)
  File "", line 1732, in fit_generator
    initial_epoch=initial_epoch)
  File "", line 185, in fit_generator
    generator_output = next(output_generator)
  File "", line 742, in get
    six.reraise(*sys.exc_info())
  File "", line 693, in reraise
    raise value
  File "", line 711, in get
    inputs = future.get(timeout=30)
  File "", line 657, in get
    raise self._value
  File "", line 121, in worker
    result = (True, func(*args, **kwds))
  File "", line 650, in next_sample
    return six.next(_SHARED_SEQUENCES[uid])
TypeError: 'PosixPath' object is not an iterator

谁能向我解释一下我做错了什么?另外,如果这是一个离题的问题,请告诉我在哪里可以问。

【问题讨论】:

  • 我相信你应该通过 train_batch 而不是 train_path。试试这个:model.fit_generator(train_batch, steps_per_epoch=4, validation_data=valid_batch, validation_steps=3, epochs=5, verbose=2)
  • 哦,完美,你是对的,那是错误。非常感谢你:D
  • 我只是把我的评论变成了答案,因为你告诉我它可以帮助你解决问题!

标签: python python-3.x keras theano


【解决方案1】:

您遇到的问题是您没有通过生成器进行训练,而是通过了文件的路径(您使用的是 train_path 而不是 train_batch强>。

而在使用.fit_generator() 时,您需要为对象传递一个生成器:

model.fit_generator(train_batch, steps_per_epoch=4, validation_data=valid_batch, validation_steps=3, epochs=5, verbose=2)

【讨论】:

    【解决方案2】:

    这行不是必须的

    imgs, labels = next(train_batch)

    来自docs fit_generator 第一个参数是生成器对象,而不是您提供的字符串。像这样

    model.fit_generator(train_path, steps_per_epoch=4, validation_data=valid_batch, validation_steps=3, epochs=5, verbose=2)

    【讨论】:

    • 我确实评论了该链接,但没有任何变化,一切都完全一样
    • 你是否也更新了你的model.fit_generator 函数,正如我上面所说的那样?
    猜你喜欢
    • 1970-01-01
    • 2021-11-18
    • 1970-01-01
    • 1970-01-01
    • 2021-07-22
    • 1970-01-01
    • 2020-10-21
    • 2019-09-07
    • 2018-06-13
    相关资源
    最近更新 更多