【问题标题】:how can I find the number of epochs for which keras model was trained?如何找到训练 keras 模型的 epoch 数?
【发布时间】:2022-01-19 18:55:39
【问题描述】:

如何找到训练 keras 模型的 epoch 数?

  1. 我使用callback_early_stopping() 提前停止训练以避免过度拟合。

  2. 我一直在使用callback_csv_logger() 来记录训练表现。但有时,我训练 100 多个 keras 模型,仅仅为了了解每个模型的 epoch 数而记录整个训练是没有意义的。

library(keras)
library(kerasR)
library(tidyverse)


# Data
x = matrix(data = runif(30000), nrow = 10000, ncol = 3)
y = ifelse(rowSums(x) > 1.5 + runif(10000), 1, 0)
y = to_categorical(y)

# keras model
model <- keras_model_sequential() %>%   
  layer_dense(units = 50, activation = "relu", input_shape = ncol(x)) %>%
  layer_dense(units = ncol(y), activation = "softmax")

model %>%
  compile(loss = "categorical_crossentropy", 
          optimizer = optimizer_rmsprop(), 
          metrics = "accuracy")

model %>% 
  fit(x, y, 
      epochs = 1000,
      batch_size = 128,
      validation_split = 0.2, 
      callbacks = callback_early_stopping(monitor = "val_loss", patience = 5),
      verbose = 1)

【问题讨论】:

    标签: r keras


    【解决方案1】:

    要打印 epoch 的数量(无论您想要什么),您都可以使用回调。 这是一个例子:

    class print_log_Callback(Callback):
      def __init__(self, logpath, steps):
        self.logpath = logpath
        self.losslst = np.zeros(steps)
    
      def on_train_batch_end(self, batch, logs=None):
        self.losslst[batch] = logs["loss"]
        with open(logpath, 'a') as writefile:
          with redirect_stdout(writefile):
            print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))
            writefile.write("\n")
    
      def on_test_batch_end(self, batch, logs=None):
        with open(logpath, 'a') as writefile:
          with redirect_stdout(writefile):
            print("For batch {}, val_loss is {:7.2f}.".format(batch, logs["loss"]))
            writefile.write("\n")
    
      def on_epoch_end(self, epoch, logs=None):
        with open(logpath, 'a') as writefile:
          with redirect_stdout(writefile):
            print("The val_loss  for epoch {} is {:7.2f}.".format(epoch, logs['val_loss']))
            writefile.write("\n")
            print("The mean train loss is: ", np.mean(self.losslst))
            writefile.write("\n")
            writefile.write("\n")
    
        self.losslst = np.zeros(steps)
    

    你这样称呼它:

    print_log_Callback(logpath=logpath, steps=int(steps))
    

    其中 logpath 是您编写代码的文本文件的路径,steps 是步骤数。

    这个回调基本上将网络的整个历史记录打印在一个文本文件上。

    每批次和每个时期结束后的损失。

    如果您只需要时代,您可以只使用方法on_epoch_end 并删除其他所有内容。

    如果你想在每个 epoch 之后打印损失,你可以使用这个修改后的版本:

    class print_log_Callback(Callback):
      def __init__(self, logpath, steps):
        self.logpath = logpath
        self.losslst = np.zeros(steps)
    
      def on_train_batch_end(self, batch, logs=None):
        self.losslst[batch] = logs["loss"]
    
      def on_epoch_end(self, epoch, logs=None):
        with open(logpath, 'a') as writefile:
          with redirect_stdout(writefile):
            print("The val_loss  for epoch {} is {:7.2f}.".format(epoch, logs['val_loss']))
            writefile.write("\n")
            print("The mean train loss is: ", np.mean(self.losslst))
            writefile.write("\n")
            writefile.write("\n")
    
        self.losslst = np.zeros(steps)
    

    您可以修改此回调以同时打印指标:例如,只需打印 logs["accuracy"]

    【讨论】:

    • 重点是,通过自定义回调,您可以使用变量epoch 打印实际的纪元数,并以更灵活的方式执行其他操作
    【解决方案2】:

    我在 python 中使用 tensorflow keras,但我的初始搜索将在历史记录中保存来自拟合后模型相关日志记录的所有信息(损失、验证损失、准确性、F1 等)

    我怀疑这在 R 中是一样的 -

    根据:https://keras.rstudio.com/articles/training_visualization.html

    只需将历史变量分配给您的模型拟合调用,例如:

    history <- model %>% 
      fit(x, y, 
          epochs = 1000,
          batch_size = 128,
          validation_split = 0.2, 
          callbacks = callback_early_stopping(monitor = "val_loss", patience = 5),
          verbose = 1)
    

    将历史记录转换为数据框 (as.data.frame(history)),您可以在其中找到指标 - 指标的长度与模型训练的 epoch 数相同

    【讨论】:

    • 谢谢你,我用callbacks_csv_logger()做了一些非常相似的事情。训练日志的长度是 epoch 的数量。有没有办法直接保存epochs数
    • 我不知道(即使在 python 中,我也不确定是否有直接的方法可以在不查询指标长度的情况下访问它)
    猜你喜欢
    • 1970-01-01
    • 2020-10-17
    • 2019-08-27
    • 1970-01-01
    • 1970-01-01
    • 2019-06-18
    • 2021-01-31
    • 2018-10-12
    • 1970-01-01
    相关资源
    最近更新 更多