【问题标题】:Keras loadmodel for custom model with custom layers - Transformer documentation exampleKeras loadmodel 用于具有自定义层的自定义模型 - Transformer 文档示例
【发布时间】:2021-02-22 10:08:21
【问题描述】:

我正在运行以下示例:

https://keras.io/examples/nlp/text_classification_with_transformer/

我已经按照描述创建并训练了一个模型,并且效果很好:

inputs = layers.Input(shape=(maxlen,))
embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
x = embedding_layer(inputs)
transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
x = transformer_block(x,training=True)
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dropout(0.1)(x)
x = layers.Dense(20, activation="relu")(x)
x = layers.Dropout(0.1)(x)
outputs = layers.Dense(2, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=outputs)


"""
## Train and Evaluate
"""

model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(
    x_train, y_train, batch_size=1024, epochs=1, validation_data=(x_val, y_val)
)

model.save('SPAM.h5')

如何在 Keras 中正确保存和加载此类自定义模型?

我试过了

 best_model=tf.keras.models.load_model('SPAM.h5')
 ValueError: Unknown layer: TokenAndPositionEmbedding

但模型似乎错过了自定义层。但是以下也不起作用

best_model=tf.keras.models.load_model('SPAM.h5',custom_objects={"TokenAndPositionEmbedding": TokenAndPositionEmbedding()})
 
TypeError: __init__() missing 3 required positional arguments:
 'maxlen', 'vocab_size', and 'embed_dim'

同样通过类也解决不了。

best_model=tf.keras.models.load_model('SPAM.h5',
 custom_objects={"TokenAndPositionEmbedding": TokenAndPositionEmbedding})
 TypeError: __init__() got an unexpected keyword argument 'name'



 best_model=tf.keras.models.load_model('SPAM.h5',
{"TokenAndPositionEmbedding":
TokenAndPositionEmbedding,'TransformerBlock':TransformerBlock,
'MultiHeadSelfAttention':MultiHeadSelfAttention})

【问题讨论】:

  • 也许这可以帮助你:Save and serialize: custom objects
  • 为什么在 tf 而不是在 keras 文档中?
  • 除非你真的考虑使用 tensorflow 之外的其他后端(我什至不确定是否仍然支持 theano),否则 tensorflow 网站上的文档通常更完整/最新。
  • 为了保存/加载具有自定义层的模型或子类模型,您应该覆盖 get_config 和可选的 from_config 方法。此外,您应该使用注册自定义对象,以便 Keras 知道它。
  • 我已经按照了,但还是同样的错误

标签: python tensorflow keras serialization transformer


【解决方案1】:

基于this answer,你需要在每个类(TokenAndPositionEmbedding 和 TransformerBlock)中添加这个方法(get_config):

变压器块:

def get_config(self):
    config = super().get_config().copy()
    config.update({
        'embed_dim': self.embed_dim,
        'num_heads': self.num_heads,
        'ff_dim': self.ff_dim,
        'rate': self.rate
    })
    return config

并将构造函数更改为

def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
    super(TransformerBlock, self).__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.ff_dim = ff_dim
    self.rate = rate
    self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
    self.ffn = keras.Sequential(
        [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
    )
    self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
    self.dropout1 = layers.Dropout(rate)
    self.dropout2 = layers.Dropout(rate)

TokenAndPositionEmbedding:

同样,将其添加到类中

def get_config(self):
    config = super().get_config().copy()
    config.update({
        'maxlen': self.maxlen,
        'vocab_size': self.vocab_size,
        'embed_dim': self.embed_dim
    })
    return config

并将构造函数替换为:

def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
    super(TokenAndPositionEmbedding, self).__init__()
    self.maxlen = maxlen
    self.vocab_size = vocab_size
    self.embed_dim = embed_dim
    self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
    self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

无法保存和加载自定义图层的原因在答案链接中进行了说明。加载时只做:

x = load_model('model.h5', custom_objects = {"TransformerBlock": TransformerBlock, "TokenAndPositionEmbedding": TokenAndPositionEmbedding})

【讨论】:

    猜你喜欢
    • 2021-03-12
    • 2020-09-28
    • 1970-01-01
    • 1970-01-01
    • 2018-11-06
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多