【问题标题】:(tf.)keras loading saved model weights with trainable word embeddings(tf.)keras 使用可训练的词嵌入加载保存的模型权重
【发布时间】:2020-11-05 17:07:05
【问题描述】:

我在 (tf.)Keras 中加载模型权重时遇到问题。

我的模型只是一个带有预训练词嵌入的简单 LSTM 模型,但我在训练时让词嵌入可训练

我使用以下代码保存了模型权重:

mc = ModelCheckpoint(filepath, save_weights_only=True, monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)

我检查了文件路径中是否存在hdf5文件,大小约为18MB。

后来,我尝试使用以下代码加载权重:

model = build_model() #the function that I used to make the model in Training process
model = model.load_weights(filepath)

但是,model.load_weights(filepath) 返回 None

问题1。 这些代码有问题吗?如果不是,这可能是因为我让词嵌入可训练?

Question2. = 在这种情况下,修改后的词嵌入保存在哪里?是和hdf5文件中的其他参数一起保存的吗?如果是这种情况,我该如何加载这个经过微调的词嵌入?

【问题讨论】:

  • 我刚刚尝试了相同的代码,但无法训练单词嵌入。但是,仍然得到相同的结果(返回无)。

标签: python tensorflow keras word-embedding


【解决方案1】:

要提取词嵌入,您需要首先从所需模型中提取嵌入层

embed_layer = model.get_layer('embedding_26') #embedding_26 is generated name of embedding layer

提取经过训练的词嵌入

embed_layer.get_weights()

>>> [array([[ 9.0566e-01, -7.1792e-01, -1.9574e-01, ...,  1.1230e-03,
          2.8188e-02,  3.0385e-01],
        [ 5.8560e-01, -3.6964e-01,  6.3480e-02, ...,  5.6656e-01,
         -3.6404e-01, -2.5202e-01],
        [ 4.5269e-01, -6.2509e-01,  1.6866e-01, ..., -5.0146e-01,
          2.9764e-01,  1.4548e-01],
        ...,
        [-1.0632e-01,  6.8057e-01, -1.5388e+00, ..., -4.8493e-01,
          3.2478e-01, -1.1330e-01],
        [ 7.6822e-01,  7.1786e-01,  5.8778e-01, ...,  1.6097e-01,
          8.9411e-02,  8.4237e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], dtype=float32)]

我不确定您是否可以直接从文件中加载权重,但您可以这样做:

model = load_model('best_model.h5')
weights = model.get_weights()  # load weights of a model

然后您可以使用它在相同架构的另一个模型中加载它

model2.set_weights(weights)

【讨论】:

  • 感谢您的回答!谢谢你,我知道我可以通过这种方式使用微调嵌入。但是对于我提出的问题,您是否知道保存的模型参数(hdf5文件)是否包含词嵌入值,以及为什么我在问题中编写的代码不起作用?谢谢!
  • 在答案中更新
猜你喜欢
  • 2019-12-05
  • 2019-11-14
  • 1970-01-01
  • 2018-08-03
  • 2019-06-04
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多