【Title】: Trouble saving tf.keras model with Bert (huggingface) classifier
【Posted】: 2020-04-26 14:14:57
【Question】:

I am training a binary classifier that uses Bert (huggingface). The model looks like this:

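import tensorflow as tf
from tensorflow.keras import Model, optimizers
from tensorflow.keras.layers import Dense, Input
from transformers import TFBertModel

def get_model(lr=0.00001):
    # Token IDs for sequences of 512 tokens; note the trailing comma in the shape tuple
    inp_bert = Input(shape=(512,), dtype="int32")
    # [0] selects the sequence output: one hidden vector per token
    bert = TFBertModel.from_pretrained('bert-base-multilingual-cased')(inp_bert)[0]
    # Keep only the first ([CLS]) token's vector as the document encoding
    doc_encodings = tf.squeeze(bert[:, 0:1, :], axis=1)
    out = Dense(1, activation="sigmoid")(doc_encodings)
    model = Model(inp_bert, out)
    adam = optimizers.Adam(learning_rate=lr)
    model.compile(optimizer=adam, loss="binary_crossentropy", metrics=["accuracy"])
    return model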

After fine-tuning it on my classification task, I want to save the model.

model.save("best_model.h5")

However, this raises a NotImplementedError:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-55-8c5545f0cd9b> in <module>()
----> 1 model.save("best_spam.h5")
      2 # import transformers

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
    973     """
    974     saving.save_model(self, filepath, overwrite, include_optimizer, save_format,
--> 975                       signatures, options)
    976 
    977   def save_weights(self, filepath, overwrite=True, save_format=None):

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    110           'or using `save_weights`.')
    111     hdf5_format.save_model_to_hdf5(
--> 112         model, filepath, overwrite, include_optimizer)
    113   else:
    114     saved_model_save.save(model, filepath, overwrite, include_optimizer,

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
     97 
     98   try:
---> 99     model_metadata = saving_utils.model_metadata(model, include_optimizer)
    100     for k, v in model_metadata.items():
    101       if isinstance(v, (dict, list, tuple)):

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    163   except NotImplementedError as e:
    164     if require_config:
--> 165       raise e
    166 
    167   metadata = dict(

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    160   model_config = {'class_name': model.__class__.__name__}
    161   try:
--> 162     model_config['config'] = model.get_config()
    163   except NotImplementedError as e:
    164     if require_config:

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    885     if not self._is_graph_network:
    886       raise NotImplementedError
--> 887     return copy.deepcopy(get_network_config(self))
    888 
    889   @classmethod

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_network_config(network, serialize_layer_fn)
   1940           filtered_inbound_nodes.append(node_data)
   1941 
-> 1942     layer_config = serialize_layer_fn(layer)
   1943     layer_config['name'] = layer.name
   1944     layer_config['inbound_nodes'] = filtered_inbound_nodes

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    138   if hasattr(instance, 'get_config'):
    139     return serialize_keras_class_and_config(instance.__class__.__name__,
--> 140                                             instance.get_config())
    141   if hasattr(instance, '__name__'):
    142     return instance.__name__

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    884   def get_config(self):
    885     if not self._is_graph_network:
--> 886       raise NotImplementedError
    887     return copy.deepcopy(get_network_config(self))
    888 

NotImplementedError: 

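I know that huggingface provides a model.save_pretrained() method for TFBertModel, but I would prefer to keep it wrapped in a tf.keras.Model, because I plan to add other components/features to this network. Can anyone suggest a way to save the current model?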

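【Comments】:

  • Based on some discussions on TensorFlow's GitHub page, I think this is a TensorFlow 2.0 issue; try upgrading or downgrading TensorFlow.
  • Also, model.save("model_name",save_format='tf') should work.
  • model.save("model_name",save_format='tf') solved my problem. Thanks! If you post your comment as an answer, I will accept it.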

Tags: python tensorflow2.0 huggingface-transformers


【Solution 1】:

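This is indeed a TensorFlow 2.0 issue. As the traceback shows, the ".h5" path goes through save_model_to_hdf5, which calls get_config() on every layer; the nested TFBertModel is not a graph network, so its get_config() raises NotImplementedError.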

Save in the TensorFlow SavedModel format instead: model.save("model_name",save_format='tf')
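The save_format='tf' call above can be wrapped in a small helper. This is only a sketch for the TF 2.x tf.keras API discussed in this thread: the names `savedmodel_dir`, `save_as_savedmodel` and `HDF5_EXTENSIONS` are hypothetical, and the extension list is an assumption based on the TF 2.x source, where an HDF5-style extension such as ".h5" appears to select the HDF5 writer regardless of save_format. The helper therefore strips such an extension before saving.

```python
import os

# Assumption: extensions that tf.keras (TF 2.x) treats as HDF5 targets.
HDF5_EXTENSIONS = (".h5", ".hdf5", ".keras")


def savedmodel_dir(path):
    """Return `path` without an HDF5-style extension, for use as a SavedModel directory."""
    root, ext = os.path.splitext(path)
    return root if ext.lower() in HDF5_EXTENSIONS else path


def save_as_savedmodel(model, path):
    """Save `model` in SavedModel format and return the directory that was used."""
    export_dir = savedmodel_dir(path)
    # save_format="tf" selects the SavedModel writer, which does not require
    # every nested layer to implement get_config().
    model.save(export_dir, save_format="tf")
    return export_dir
```

With the model from the question, save_as_savedmodel(model, "best_model.h5") would write to a directory named best_model. Reloading would typically go through tf.keras.models.load_model on that directory; whether the restored object behaves exactly like the original may depend on the TensorFlow and transformers versions in use.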

Alternatively, you can try upgrading or downgrading TensorFlow.
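The error-message fragment visible in the traceback (save.py, line 110) also mentions `save_weights` as another route. A minimal sketch of that route follows, assuming the architecture can be rebuilt in code (for example with get_model from the question); the helper names `save_weights_only` and `restore_from_weights` are hypothetical, and this was not part of the accepted answer.

```python
def save_weights_only(model, path):
    """Write only the model's weights; the architecture is not stored."""
    model.save_weights(path)
    return path


def restore_from_weights(build_fn, path):
    """Rebuild the architecture with `build_fn`, then load the saved weights into it."""
    model = build_fn()  # e.g. get_model from the question
    model.load_weights(path)
    return model
```

The trade-off is that the optimizer state and the model definition are not saved, so the loading side needs the same get_model code (and the same pretrained checkpoint name) to reconstruct the network before calling load_weights.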
