【发布时间】:2019-08-19 02:56:24
【问题描述】:
我为特定领域的 bert 模型下载了 tensorflow 检查点,并将 zip 文件提取到包含以下三个文件的文件夹 pretrained_bert 中
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
我使用以下代码将 tensorflow 检查点转换为 pytorch
import torch
from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert
tf_checkpoint_path="pretrained_bert/model.ckpt"
bert_config_file = "bert-base-cased-config.json"
pytorch_dump_path="pytorch_bert"
config = BertConfig.from_json_file(bert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = BertForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
我在运行上面的代码时遇到了这个错误
NotFoundError:不成功的 TensorSliceReader 构造函数:失败 查找 pretrained_bert/model.ckpt 的任何匹配文件
非常感谢任何帮助............
【问题讨论】:
-
如你所说,我已经改变了路径,但我得到了同样的错误
-
对不起,我弄错了。你实际上已经正确地给予了它。但是,我确实看到您已经给出了相对路径。尝试提供绝对路径。
-
谢谢,指定绝对路径后就可以了。
-
嗨@KalyanKatikapalli,我遇到了同样的问题。我已经放了绝对路径,但仍然给我同样的错误。你有什么建议吗?谢谢
-
在代码中,必须为变量“tf_checkpoint_path”、“bert_config_file”和“pytorch_dump_path”指定绝对路径。指定相对路径时,模型找不到对应的文件。
标签: python tensorflow pytorch