【问题标题】:on_epoch_end() not called in keras fit_generator()on_epoch_end() 未在 keras fit_generator() 中调用
【发布时间】:2020-04-25 23:19:03
【问题描述】:

我跟随this tutorial 使用fit_generator() Keras 方法即时生成数据,以训练我的神经网络模型。

我使用keras.utils.Sequence 类创建了一个生成器。对fit_generator() 的调用是:

history = model.fit_generator(generator=EVDSSequence(images_train, TRAIN_BATCH_SIZE, INPUT_IMG_DIR, INPUT_JSON_DIR, SPLIT_CHAR, sizeArray, NCHW, shuffle=True),
                              steps_per_epoch=None, epochs=EPOCHS,
                              validation_data=EVDSSequence(images_valid, VALID_BATCH_SIZE, INPUT_IMG_DIR, INPUT_JSON_DIR, SPLIT_CHAR, sizeArray, NCHW, shuffle=True),
                              validation_steps=None,
                              callbacks=callbacksList, verbose=1,
                              workers=0, max_queue_size=1, use_multiprocessing=False)

steps_per_epochNone,所以每个epoch的步数是通过Keras的__len()__方法计算出来的。

如上链接所述:

这里,on_epoch_end 方法在每个 epoch 的开始和结束时触发一次。如果shuffle 参数设置为True,我们将在每次通过时获得一个新的探索顺序(否则保持线性探索方案)。

我的问题是 on_epoch_end() 方法只在开始时被调用,而不会在每个纪元结束时被调用。 因此,在每个 epoch,批次顺序始终相同。

我尝试在__len__() 方法中使用np.ceil 而不是np.floor,但没有成功。

你知道为什么 on_epoch_end 在每个 epoch 结束时不被调用吗?你能告诉我在每个时期结束(或开始)时调整批次顺序的任何解决方法吗?

非常感谢!

【问题讨论】:

    标签: python tensorflow keras deep-learning training-data


    【解决方案1】:

    我遇到了同样的问题。我不知道为什么会发生这种情况,但有一种解决方法:在__len__() 内调用on_epoch_end(),因为__len__() 将在每个时期被调用。

    【讨论】:

    • 感谢您的回复!无论如何,几天前我已经(临时)以这种方式解决了,但这会打乱所有样本。我最初的目标是只打乱批次提供给网络的顺序,保留单个批次中样品的顺序。
    • 非常感谢,我也遇到了同样的问题,直到我发现我的实验结果很奇怪时才意识到。
    【解决方案2】:

    可能与问题有关:Keras model.fit not calling Sequence.on_epoch_end() #35911

    快速解决方法是使用LambdaCallback(请注意,我使用fit 就足够了,因为不推荐使用fit_generator

    from tf.keras.callbacks import LambdaCallback
    
    model.fit(generator, callbacks=[LambdaCallback(on_epoch_end=generator.on_epoch_end)])
    

    希望对你有帮助!

    【讨论】:

      【解决方案3】:

      而且我发现当您创建 on_predict_end() callback_lambda 时,它不会在预测结束时调用。顺便说一句,predict() 接受一个 callbacks=list(...) 参数。

      此外,您似乎可以像这样测试回调:

      (create your 'model' object)
      callback_batch_end <- callback_lambda(
          on_batch_end = function(batch, logs) {
              cat("Hello world\n")
          }
      )
      callback_batch_end$on_batch_end(1, "x")
      (prints 'Hello world')
      callback_predict_end <- callback_lambda(
          on_predict_end = function(logs) {
              cat("Hello world\n")
          }
      )
      callback_predict_end$on_predict_end("x")
      (prints nothing)
      

      【讨论】:

        猜你喜欢
        • 2020-03-14
        • 1970-01-01
        • 2019-10-27
        • 1970-01-01
        • 1970-01-01
        • 2019-05-18
        • 1970-01-01
        • 2020-03-06
        • 2019-04-25
        相关资源
        最近更新 更多