场景
使用tensorflow将TF模型转化成PyTorch模型
步骤
获取如下三个文件:
- src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py
:这个是将tensorflow2.x Bert模型转化成PyTorch可用的模型。 - src/transformers/models/bert/modeling_bert.py
:Bert模型使用例子。 - BERT-Base, Multilingual Cased (New, recommended)
:基于Bert的多语言预训练模型。
这里假设已经安装过PyTorch了。
开始转化TF2模型位PyTorch模型:
# 安装依赖
pip3 install tensorflow transformers
export BERT_BASE_DIR=~/Downloads/nlp_bert/multi_cased_L-12_H-768_A-12
transformers-cli convert --model_type bert \
--tf_checkpoint $BERT_BASE_DIR/bert_model.ckpt \
--config $BERT_BASE_DIR/bert_config.json \
--pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin
这里的pytorch_model.bin就是TF2的已经训练好的模型转化过来的PyTorch模型。