Pytorch模型转Tensorflow模型

引言

最近收到领导布置的一个小任务,需要将pth文件转换到tensorflow模型。这里将采用领导给的pytorchocr代码,来记录下如何进行模型转换。

一、基本思想

想要将pytorch训练得到的pth文件转换到tensorflow的pb文件,基本思想就是先将pth转成onnx,再又onnx转换到pb文件。整体没什么难度,我在下面直接给出代码。

二、转换代码

转换代码如下,我这里采用的是pytorchocr网络,有了具体网络结构就可以进行onnx转换,最后到tensorflow转换。

复制代码
import torch
import torchvision
import tensorflow as tf
from onnx_tf.backend import prepare
from pytorchocr.base_ocr_v20 import BaseOCRV20
import onnx

det_model_path = "E:\File\ckpt\pytorch\det.pth"
# det_model_path = r"D:\File_save\pycharm\demo\Script\ckpt\pytorch\rec.pth"

class PPOCRv3DetConverter(BaseOCRV20):
    def __init__(self, config, **kwargs):
        super(PPOCRv3DetConverter, self).__init__(config, **kwargs)
        # self.load_paddle_weights(paddle_pretrained_model_path)
        self.net.eval()
        # self.training = False


cfg = {'model_type':'det',
           'algorithm':'DB',
           'Transform':None,
           'Backbone':{'name':'MobileNetV3', 'model_name':'large', 'scale':0.5, 'disable_se':True},
           'Neck':{'name':'RSEFPN', 'out_channels':96, 'shortcut': True},
           'Head':{'name':'DBHead', 'k':50}

       }

# cfg = {'model_type': 'rec',
#        'algorithm': 'CRNN',
#        'Transform': None,
#        'Backbone': {'name': 'MobileNetV1Enhance',
#                     'scale': 0.5,
#                     'last_conv_stride': [1, 2],
#                     'last_pool_type': 'avg'},
#        'Neck': {'name': 'SequenceEncoder',
#                 'dims': 64,
#                 'depth': 2,
#                 'hidden_dims': 120,
#                 'use_guide': True,
#                 'encoder_type': 'svtr'},
#        'Head': {'name': 'CTCHead', 'fc_decay': 2e-05}
#        }

converter = PPOCRv3DetConverter(cfg)
converter.load_state_dict(torch.load(det_model_path))
# det
dummy_input = torch.randn(1, 3, 640, 640)
# rec
# dummy_input = torch.randn(1, 3, 48, 320)

det_model = converter.net

# 设置模型为评估模式
converter.net.eval()

# 导出模型为ONNX格式
# torch.onnx.export(det_model, dummy_input, 'rec_model.onnx', verbose=False, opset_version=11)
# torch.onnx.export(det_model, dummy_input, 'det_model.onnx', verbose=False, opset_version=12)

# 加载ONNX模型
onnx_model = onnx.load('E:\File\ckpt\onnx\det_model.onnx')
# onnx_model = onnx.load(r'D:\File_save\pycharm\demo\Script\ckpt\rec_model.onnx')
#
# 转换为TensorFlow模型
tf_model = prepare(onnx_model, strict=False)
#
# 保存为TensorFlow模型文件
tf_model.export_graph('det_model.pb')
# tf_model.export_graph('rec_model.pb')
# #
# tf.compat.v1.reset_default_graph()
# with tf.compat.v1.Session() as sess:
#     with tf.gfile.GFile('det_model.onnx', 'rb') as f:
#         graph_def = tf.compat.v1.GraphDef()
#         graph_def.ParseFromString(f.read())
#         tf.import_graph_def(graph_def, name='')
#
#     # 保存为.pb文件
#     tf.io.write_graph(sess.graph, '', 'det_model.pb', as_text=False)
相关推荐
我爱一条柴ya10 分钟前
【AI大模型】神经网络反向传播:核心原理与完整实现
人工智能·深度学习·神经网络·ai·ai编程
万米商云14 分钟前
企业物资集采平台解决方案:跨地域、多仓库、百部门——大型企业如何用一套系统管好百万级物资?
大数据·运维·人工智能
新加坡内哥谈技术17 分钟前
Google AI 刚刚开源 MCP 数据库工具箱,让 AI 代理安全高效地查询数据库
人工智能
慕婉030719 分钟前
深度学习概述
人工智能·深度学习
大模型真好玩20 分钟前
准确率飙升!GraphRAG如何利用知识图谱提升RAG答案质量(额外篇)——大规模文本数据下GraphRAG实战
人工智能·python·mcp
198921 分钟前
【零基础学AI】第30讲:生成对抗网络(GAN)实战 - 手写数字生成
人工智能·python·深度学习·神经网络·机器学习·生成对抗网络·近邻算法
6confim21 分钟前
AI原生软件工程师
人工智能·ai编程·cursor
阿里云大数据AI技术21 分钟前
Flink Forward Asia 2025 主旨演讲精彩回顾
大数据·人工智能·flink
i小溪22 分钟前
在使用 Docker 时,如果容器挂载的数据目录(如 `/var/moments`)位于数据盘,只要服务没有读写,数据盘是否就不会被唤醒?
人工智能·docker
程序员NEO25 分钟前
Spring AI 对话记忆大揭秘:服务器重启,聊天记录不再丢失!
人工智能·后端