【问题标题】:How to save a keras subclassed model with positional parameters in Call() method?如何在 Call() 方法中保存带有位置参数的 keras 子类模型?
【发布时间】:2020-10-19 10:52:22
【问题描述】:
import tensorflow as tf

class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    @tf.function
    def call(self, enc_input, dec_input, training, mask1, mask2, mask3):
        x = self.dense1(enc_input)
        return self.dense2(x)

x = tf.random.normal((10,20))

model = MyModel()

y = model(x, x, False, None, None, None)

tf.keras.models.save_model(model, '/saved')

当我尝试保存模型时,即使我传递了所有参数也会引发错误。

tf__call() missing 4 required positional arguments: 'training', 'mask1', 'mask2', and 'mask3'

如何保存整个模型而不仅仅是保存权重?

【问题讨论】:

  • call 应该有一个固定的签名,比如def call(self, inputs, training): ...。因此,您需要使用 inputs 参数传递这些额外的输入,而不是为该方法创建额外的参数。

标签: tensorflow keras tensorflow2.0 tf.keras


【解决方案1】:

我认为进行以下更改会起作用

 #def call(self, enc_input, dec_input, training, mask1, mask2, mask3):
 def call(self, enc_input, dec_input, training=False, mask1=None, mask2=None, mask3=None):

经过挖掘,我认为函数参数上会发生sanity check,如果未指定位置参数,则 x 后面的参数将被视为 **kwargs 参数(我对此不太确定)。 但是为了它,如果你不想设置参数的默认映射,你可以解压它们,这样每个参数都会放在它对应的位置,如下所示:

y = model(*[x,x,False,None,None,None])

【讨论】:

    猜你喜欢
    • 2020-07-10
    • 1970-01-01
    • 2020-05-28
    • 2020-03-21
    • 1970-01-01
    • 2018-08-02
    • 2018-05-02
    • 2022-07-25
    • 1970-01-01
    相关资源
    最近更新 更多