【发布时间】:2018-05-20 10:47:53
【问题描述】:
我正在安装一个 train_generator,并且我想通过一个自定义回调来计算我的 validation_generator 上的自定义指标。
如何在自定义回调中访问参数 validation_steps 和 validation_data?
不在self.params,在self.model也找不到。这就是我想做的。欢迎任何不同的方法。
model.fit_generator(generator=train_generator,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
validation_data=validation_generator,
validation_steps=validation_steps,
callbacks=[CustomMetrics()])
class CustomMetrics(keras.callbacks.Callback):
def on_epoch_end(self, batch, logs={}):
for i in validation_steps:
# features, labels = next(validation_data)
# compute custom metric: f(features, labels)
return
keras:2.1.1
更新
我设法将我的验证数据传递给自定义回调的构造函数。但是,这会导致令人讨厌的“内核似乎已经死机。它将自动重新启动。”信息。我怀疑这是否是正确的方法。有什么建议吗?
class CustomMetrics(keras.callbacks.Callback):
def __init__(self, validation_generator, validation_steps):
self.validation_generator = validation_generator
self.validation_steps = validation_steps
def on_epoch_end(self, batch, logs={}):
self.scores = {
'recall_score': [],
'precision_score': [],
'f1_score': []
}
for batch_index in range(self.validation_steps):
features, y_true = next(self.validation_generator)
y_pred = np.asarray(self.model.predict(features))
y_pred = y_pred.round().astype(int)
self.scores['recall_score'].append(recall_score(y_true[:,0], y_pred[:,0]))
self.scores['precision_score'].append(precision_score(y_true[:,0], y_pred[:,0]))
self.scores['f1_score'].append(f1_score(y_true[:,0], y_pred[:,0]))
return
metrics = CustomMetrics(validation_generator, validation_steps)
model.fit_generator(generator=train_generator,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
validation_data=validation_generator,
validation_steps=validation_steps,
shuffle=True,
callbacks=[metrics],
verbose=1)
【问题讨论】:
-
我认为没有好的选择。如果您查看 keras 中 _fit_loop 的代码,它并没有真正将 validation_steps 和 validation_data 传递给回调。
-
在(批量开始时)上使用 next(validation_generatro) 怎么样,这会比你的方式更好吗?我的意思是,在这种情况下,我不知道 next(val_generator) 是否会进行下一次迭代,或者它总是从头开始随机开始并且永远不会覆盖所有验证数据。
-
如果您查看 Keras TensorBoard 回调,似乎有一种方法可以从模型中获取验证数据,但我无法在代码中找到它发生的位置:github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/python/…
-
我在这里提供一个可能的答案:stackoverflow.com/a/59697739/880783