【发布时间】:2019-11-19 01:24:31
【问题描述】:
我正在 keras 中训练一个模型,我想在每个 epoch 之后绘制结果图。我知道 keras 回调提供了“on_epoch_end”函数,如果一个人想在每个 epoch 之后进行一些计算,那么该函数可以被重载,但是我的函数需要一些额外的参数,当给定这些参数时,元类错误会导致代码崩溃。具体如下:
这是我现在的做法,效果很好:-
class NewCallback(Callback):
def on_epoch_end(self, epoch, logs={}): #working fine, printing epoch after each epoch
print("EPOCH IS: "+str(epoch))
epochs=5
batch_size = 16
model_saved=False
if model_saved:
vae.load_weights(args.weights)
else:
# train the autoencoder
vae.fit(x_train,
epochs=epochs,
batch_size=batch_size,
validation_data=(x_test, None),
callbacks=[NewCallback()])
但我想要这样的回调函数:-
class NewCallback(Callback,models,data,batch_size):
def on_epoch_end(self, epoch, logs={}):
print("EPOCH IS: "+str(epoch))
x=models.predict(data)
plt.plot(x)
plt.savefig(epoch+".png")
如果我这样称呼它:
callbacks=[NewCallback(models, data, batch_size=batch_size)]
我收到此错误:
TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases
我正在寻找一个更简单的解决方案来调用我的函数或解决元类的这个错误,非常感谢任何帮助!
【问题讨论】:
标签: keras callback metaclass custom-function