从 PyTorch 到 ONNX:深度学习模型导出全解析

在模型训练完毕后,我们通常希望将其部署到推理平台中,比如 TensorRT、ONNX Runtime 或移动端框架。而 ONNX(Open Neural Network Exchange)正是 PyTorch 与这些平台之间的桥梁。

本文将以一个图像去噪模型 SimpleDenoiser 为例,手把手带你完成 PyTorch 模型导出为 ONNX 格式的全过程,并解析每一行代码背后的逻辑。

准备工作

我们假设你已经训练好一个图像去噪模型并保存为 .pth 文件,模型结构自编码器实现如下(略):

python 复制代码
class SimpleDenoiser(nn.Module):
    def __init__(self):
        super(SimpleDenoiser, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

导出代码分解

我们现在来看导出脚本的核心逻辑,并分块解释它的每一部分。

1. 导入模块 & 设置路径

python 复制代码
//torch:核心框架

//train.SimpleDenoiser:从训练脚本复用模型结构

//os:用于创建输出目录

import torch
from train import SimpleDenoiser  # 模型结构
import os

2. 导出函数定义

python 复制代码
//这个函数接收三个参数:

//pth_path: 训练得到的模型参数文件路径

//onnx_path: 导出的 ONNX 文件保存路径

//input_size: 模拟推理输入的尺寸(默认 1×3×256×256)
def export_model_to_onnx(pth_path, onnx_path, input_size=(1, 3, 256, 256)):

3. 加载模型和权重

python 复制代码
//自动检测 CUDA 可用性,加载模型到对应设备;

//使用 load_state_dict() 加载训练好的参数;

//model.eval() 让模型切换到推理模式(关闭 Dropout/BatchNorm 更新);
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SimpleDenoiser().to(device)
model.load_state_dict(torch.load(pth_path, map_location=device))
model.eval()

4. 构造假输入(Dummy Input)

python 复制代码
//ONNX 导出需要一个具体的输入样本,我们这里用 torch.randn 生成一个形状为 (1, 3, 256, 256) 的随机图//像;

//输入必须放在同一个设备上(GPU 或 CPU);
dummy_input = torch.randn(*input_size).to(device)

5. 导出为 ONNX

python 复制代码
torch.onnx.export(
    model,  //要导出的模型
    dummy_input,  //示例输入张量
    onnx_path, //	导出路径
    export_params=True,  //是否导出权重
    opset_version=11,  //ONNX 的算子集版本,通常推荐 11 或 13
    do_constant_folding=True,  //优化常量表达式,减小模型体积
    input_names=['input'],  //自定义输入输出张量的名称
    output_names=['output'],  //声明哪些维度可以变动,比如 batch size、图像大小等(部署时更灵活)
    dynamic_axes={
        'input': {0: 'batch_size', 2: 'height', 3: 'width'},
        'output': {0: 'batch_size', 2: 'height', 3: 'width'}
    }
)

6. 创建目录并调用函数

java 复制代码
//确保输出文件夹存在,并调用导出函数生成最终模型。
if __name__ == "__main__":
    os.makedirs("onnx", exist_ok=True)
    export_model_to_onnx("weights/denoiser.pth", "onnx/denoiser.onnx")

导出后如何验证?

bash 复制代码
pip install onnxruntime
python 复制代码
import onnxruntime
import numpy as np

sess = onnxruntime.InferenceSession("onnx/denoiser.onnx")
input = np.random.randn(1, 3, 256, 256).astype(np.float32)
output = sess.run(None, {"input": input})
print("输出 shape:", output[0].shape)

模型预览:

总结

导出 ONNX 模型的流程主要包括:

  1. 加载模型结构 + 权重

  2. 准备 dummy 输入张量

  3. 调用 torch.onnx.export() 进行导出

  4. 设置 dynamic_axes 可变尺寸以增强部署适配性

这套流程适用于大部分视觉模型(分类、去噪、分割等),也是后续进行 TensorRT 推理或移动端部署的基础。

相关推荐
ziwu8 分钟前
【民族服饰识别系统】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积网络+resnet50算法
人工智能·后端·图像识别
ziwu22 分钟前
【卫星图像识别系统】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积网络+resnet50算法
人工智能·tensorflow·图像识别
ISACA中国28 分钟前
ISACA与中国内审协会共同推动的人工智能审计专家认证(AAIA)核心内容介绍
人工智能·审计·aaia·人工智能专家认证·人工智能审计专家认证·中国内审协会
ISACA中国43 分钟前
《第四届数字信任大会》精彩观点:针对AI的攻击技术(MITRE ATLAS)与我国对AI的政策导向解读
人工智能·ai·政策解读·国家ai·风险评估工具·ai攻击·人工智能管理
Coding茶水间44 分钟前
基于深度学习的PCB缺陷检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·计算机视觉
绫语宁1 小时前
以防你不知道LLM小技巧!为什么 LLM 不适合多任务推理?
人工智能·后端
霍格沃兹测试开发学社-小明1 小时前
AI来袭:自动化测试在智能实战中的华丽转身
运维·人工智能·python·测试工具·开源
大千AI助手1 小时前
Softmax函数:深度学习中的多类分类基石与进化之路
人工智能·深度学习·机器学习·分类·softmax·激活函数·大千ai助手
韩曙亮1 小时前
【人工智能】AI 人工智能 技术 学习路径分析 ② ( 深度学习 -> 机器视觉 )
人工智能·深度学习·学习·ai·机器视觉
九千七5261 小时前
sklearn学习(3)数据降维
人工智能·python·学习·机器学习·sklearn