【发布时间】:2022-01-19 18:55:39
【问题描述】:
如何找到训练 keras 模型的 epoch 数?
-
我使用callback_early_stopping() 提前停止训练以避免过度拟合。
-
我一直在使用callback_csv_logger() 来记录训练表现。但有时,我训练 100 多个 keras 模型,仅仅为了了解每个模型的 epoch 数而记录整个训练是没有意义的。
library(keras)
library(kerasR)
library(tidyverse)
# Data
x = matrix(data = runif(30000), nrow = 10000, ncol = 3)
y = ifelse(rowSums(x) > 1.5 + runif(10000), 1, 0)
y = to_categorical(y)
# keras model
model <- keras_model_sequential() %>%
layer_dense(units = 50, activation = "relu", input_shape = ncol(x)) %>%
layer_dense(units = ncol(y), activation = "softmax")
model %>%
compile(loss = "categorical_crossentropy",
optimizer = optimizer_rmsprop(),
metrics = "accuracy")
model %>%
fit(x, y,
epochs = 1000,
batch_size = 128,
validation_split = 0.2,
callbacks = callback_early_stopping(monitor = "val_loss", patience = 5),
verbose = 1)
【问题讨论】: