【发布时间】:2020-10-19 10:52:22
【问题描述】:
import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
@tf.function
def call(self, enc_input, dec_input, training, mask1, mask2, mask3):
x = self.dense1(enc_input)
return self.dense2(x)
x = tf.random.normal((10,20))
model = MyModel()
y = model(x, x, False, None, None, None)
tf.keras.models.save_model(model, '/saved')
当我尝试保存模型时,即使我传递了所有参数也会引发错误。
tf__call() missing 4 required positional arguments: 'training', 'mask1', 'mask2', and 'mask3'
如何保存整个模型而不仅仅是保存权重?
【问题讨论】:
-
call应该有一个固定的签名,比如def call(self, inputs, training): ...。因此,您需要使用inputs参数传递这些额外的输入,而不是为该方法创建额外的参数。
标签: tensorflow keras tensorflow2.0 tf.keras