【问题标题】:Checkpoint deep learning models in KerasKeras 中的检查点深度学习模型
【发布时间】:2017-06-15 17:12:55
【问题描述】:

我需要帮助来实现 Keras 中的检查点功能。我要训练一个大型数据集,所以为了做到这一点,首先我使用鸢尾花数据集训练了一个模型:http://machinelearningmastery.com/multi-class-classification-tutorial-keras-deep-learning-library/

由于我自己的数据集与它非常相似,唯一的区别是我的数据集更大。

检查点功能:http://machinelearningmastery.com/check-point-deep-learning-models-keras/

我理解使用 pima-indians 数据集的示例。 现在我正在尝试在 iris-flower 脚本中实现相同的检查点功能。这是我到目前为止所尝试的。

import numpy
from pandas import *
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import np_utils
from sklearn.model_selection import cross_val_score, KFold
from sklearn.preprocessing import LabelEncoder
from sklearn.pipeline import Pipeline
from keras.callbacks import ModelCheckpoint

seed = 7
numpy.random.seed(seed)

dataframe = read_csv("iris.csv", header=None)
dataset = dataframe.values
X = dataset[:,0:4].astype(float)
Y = dataset[:,4]

# encode class value as integers
encoder = LabelEncoder()
encoder.fit(Y)
encoded_Y = encoder.transform(Y)
dummy_y = np_utils.to_categorical(encoded_Y)

def baseline_model():
    model = Sequential()
    model.add(Dense(4, input_dim=4, init='normal', activation='relu'))
    model.add(Dense(3, init='normal', activation='sigmoid'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

filepath="weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]

estimator = KerasClassifier(build_fn=baseline_model, validation_split=0.33, nb_epoch=200, batch_size=5, callbacks=callbacks_list, verbose=0)
kfold = KFold(n_splits=10, shuffle=True, random_state=seed)
results = cross_val_score(estimator, X, dummy_y, cv=kfold)
print("Baseline: %.2f%% (%.2f%%)" % (results.mean()*100, results.std()*100))

此脚本产生了以下错误。我不知道如何解决它,或者我在脚本中的安排是错误的。

RuntimeError: Cannot clone object <keras.wrappers.scikit_learn.KerasClassifier object at 0x10e120fd0>, as the constructor does not seem to set parameter callbacks

我希望有人可以帮助我解决这个问题。谢谢。

【问题讨论】:

  • 你知道是哪一行导致了错误吗?

标签: python keras


【解决方案1】:

我遇到了同样的错误,但在 NN 层中设置了“单位”参数。

RuntimeError: Cannot clone object <tensorflow.python.keras.wrappers.scikit_learn.KerasClassifier object at 0x1496b97f0>, as the constructor either does not set or modifies parameter units

将 scikit-learn 从 0.23.1 降级到 0.21.2 为我解决了这个问题。

在 github 上查看此问题:Link

【讨论】:

    【解决方案2】:

    使用模型本身,而不是 KerasClassifier。

    model = baseline_model()
    #declare your callback methods here
    model.fit(x,y, batch_size=32, verbose=0, epochs=10, shuffle=True, validation_split = 0.1, callbacks = <your list of callbacks>)
    

    【讨论】:

      【解决方案3】:

      我认为问题在于您的 baseline_model() 函数没有返回它正在创建的模型;它应该是这样的:

      def baseline_model():
          model = Sequential()
          model.add(Dense(4, input_dim=4, init='normal', activation='relu'))
          model.add(Dense(3, init='normal', activation='sigmoid'))
          model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
          return model
      

      【讨论】:

      • 感谢您为我指出这一点。我在脚本中添加了“返回模型”,但发生了同样的错误。
      • @Ling Ah 对,好像是you are not the only one with this problem。查看this questionthe source,它可能是包装程序编程中的一个错误。你可以试试estimator = KerasClassifier(build_fn=baseline_model, **{validation_split=0.33, nb_epoch=200, batch_size=5, callbacks=callbacks_list, verbose=0}),不过如果有帮助的话,我想知道。
      • 感谢您的建议,但它在“validation_split=0.33”处返回语法错误。我将具有单个“=”的行更改为“==”,但出现另一个错误,提示未定义“validation_split”
      • @Ling Whops,对不起,我的错:S 我的意思是estimator = KerasClassifier(build_fn=baseline_model, **{'validation_split': 0.33, 'nb_epoch': 200, 'batch_size': 5, 'callbacks': callbacks_list, 'verbose': 0})。虽然,正如我所说,我不确定这是否会解决它(我不确定 SciPy 元训练器是如何实现的)。
      • 不,那也行不通。它返回与问题相同的运行时错误。也许没有办法实现它。无论如何,谢谢你的帮助。
      猜你喜欢
      • 2015-12-31
      • 2021-05-26
      • 2018-02-03
      • 2021-04-08
      • 1970-01-01
      • 2019-12-10
      • 2021-04-04
      • 1970-01-01
      • 2020-05-27
      相关资源
      最近更新 更多