【问题标题】:Save history of model.fit for different epochs保存不同时期的 model.fit 历史
【发布时间】:2020-06-28 19:45:31
【问题描述】:

我正在使用 epoch=10 训练我的模型。我再次用 epoch=3 重新训练。又是第 5 纪元。 所以每次我用 epoch=10、3、5 训练模型。我想结合所有 3 的历史。例如,让 h1 = model.fit 的历史,对于 epoch=10,h2 = model.fit 的历史对于 epoch=3, h3 = model.fit for epoch=5 的历史。

现在在变量 h 中,我想要 h1 + h2 + h3。所有历史记录都附加到单个变量,以便我可以绘制一些图表。

代码是,

start_time = time.time()

model.fit(x=X_train, y=y_train, batch_size=32, epochs=10, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")


start_time = time.time()

model.fit(x=X_train, y=y_train, batch_size=32, epochs=3, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")

start_time = time.time()

model.fit(x=X_train, y=y_train, batch_size=32, epochs=5, validation_data=(X_val, y_val), callbacks=[tensorboard, checkpoint])

end_time = time.time()
execution_time = (end_time - start_time)
print(f"Elapsed time: {hms_string(execution_time)}")

【问题讨论】:

    标签: tensorflow keras neural-network epoch


    【解决方案1】:

    您可以通过创建一个子类tf.keras.callbacks.Callback 并使用该类的对象作为model.fit 的回调来实现此功能。

    import csv
    import tensorflow.keras.backend as K
    from tensorflow import keras
    import os
    
    model_directory='./xyz' # directory to save model history after every epoch 
    
    class StoreModelHistory(keras.callbacks.Callback):
    
      def on_epoch_end(self,batch,logs=None):
        if ('lr' not in logs.keys()):
          logs.setdefault('lr',0)
          logs['lr'] = K.get_value(self.model.optimizer.lr)
    
        if not ('model_history.csv' in os.listdir(model_directory)):
          with open(model_directory+'model_history.csv','a') as f:
            y=csv.DictWriter(f,logs.keys())
            y.writeheader()
    
        with open(model_directory+'model_history.csv','a') as f:
          y=csv.DictWriter(f,logs.keys())
          y.writerow(logs)
    
    
    model.fit(...,callbacks=[StoreModelHistory()])
    

    然后你可以加载 csv 文件并绘制模型的损失、学习率、指标等。

    import pandas as pd
    import matplotlib.pyplot as plt
    
    EPOCH = 10 # number of epochs the model has trained for
    
    history_dataframe = pd.read_csv(model_directory+'model_history.csv',sep=',')
    
    
    # Plot training & validation loss values
    plt.style.use("ggplot")
    plt.plot(range(1,EPOCH+1),
             history_dataframe['loss'])
    plt.plot(range(1,EPOCH+1),
             history_dataframe['val_loss'],
             linestyle='--')
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Test'], loc='upper left')
    plt.show()
    

    【讨论】:

      【解决方案2】:

      每次调用model.fit(),它都会返回一个keras.callbacks.History 对象,其history 属性包含一个字典。字典的键是 loss 用于训练,val_loss 用于验证损失,以及您在编译时可能设置的任何其他 metrics

      因此,在你的情况下,你可以这样做:

      hist1 = model.fit(...)
      
      # other code lines
      
      hist2 = model.fit(...)
      
      # other code lines
      
      hist3 = model.fit(...)
      
      # create an empty dict to save all three history dicts into
      total_history_dict = dict()
      
      for some_key in hist1.keys():
          current_values = [] # to save values from all three hist dicts
          for hist_dict in [hist1.history, hist2.history, hist3.history]:
              current_values += hist_dict[some_key]
          total_history_dict[some_key] = current_values
      

      现在,total_history_dict 是一个字典键,其中像往常一样是 lossval_lossother metrics 和显示 loss/ 的值列表每个时期的指标。 (列表的长度将是对 model.fit 的所有三个调用中的时期数的总和)

      您现在可以使用字典来使用 matplotlib 绘制内容或将其保存到 pandas 数据框等...

      【讨论】:

        【解决方案3】:

        在 2020 年,您可以使用内置 CSVLoggerappend=True

        保存示例:

        epoch,accuracy,loss,val_accuracy,val_loss
        0,0.7649424076080322,0.49990198016166687,0.6675007939338684,0.8114446401596069
        1,0.8209356665611267,0.406406044960022,0.7569596767425537,0.5224416851997375
        

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 2013-02-20
          • 1970-01-01
          • 2015-02-20
          • 1970-01-01
          • 2018-06-24
          • 1970-01-01
          • 2010-11-01
          • 1970-01-01
          相关资源
          最近更新 更多