【发布时间】:2019-11-10 01:24:09
【问题描述】:
我使用 SQUAD 2.0 训练了 BERT,并使用 BERT-master/run_squad.py 在输出目录中获得了 model.ckpt.data、model.ckpt.meta、model.ckpt.index(F1 分数:81)以及 predictions.json 等
python run_squad.py \
--vocab_file=$BERT_LARGE_DIR/vocab.txt \
--bert_config_file=$BERT_LARGE_DIR/bert_config.json \
--init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt \
--do_train=True \
--train_file=$SQUAD_DIR/train-v2.0.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v2.0.json \
--train_batch_size=24 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=gs://some_bucket/squad_large/ \
--use_tpu=True \
--tpu_name=$TPU_NAME \
--version_2_with_negative=True
我尝试将model.ckpt.meta、model.ckpt.index、model.ckpt.data 复制到$BERT_LARGE_DIR 目录并更改run_squad.py 标志如下,以仅预测答案而不使用数据集进行训练:
python run_squad.py \
--vocab_file=$BERT_LARGE_DIR/vocab.txt \
--bert_config_file=$BERT_LARGE_DIR/bert_config.json \
--init_checkpoint=$BERT_LARGE_DIR/model.ckpt \
--do_train=False \
--train_file=$SQUAD_DIR/train-v2.0.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v2.0.json \
--train_batch_size=24 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=gs://some_bucket/squad_large/ \
--use_tpu=True \
--tpu_name=$TPU_NAME \
--version_2_with_negative=True
抛出bucket directory/model.ckpt不存在错误。
如何利用训练后生成的检查点进行预测?
【问题讨论】:
标签: python tensorflow neural-network google-cloud-tpu bert-language-model