PyTorch 模型开发全栈指南:从定义、修改到保存的完整闭环

如果你在 PyTorch 中只做「调包侠」,那么永远只是在外围打转;只有把「模型定义 → 修改 → 保存/加载」整条链路打通,才算真正拥有了炼丹炉的钥匙。

本文把官方教程 5.1--5.4 浓缩成一篇逻辑闭环的实战笔记,力求"看完即可落地"。


1. 为什么要有"模型工程化"思维?

阶段 痛点举例 本章解法
快速验证 一行行手写 100 层 CNN? Sequential / 模型块
需求变更 ResNet50 输出从 1000 → 10 类 局部层替换 / 外部输入输出
训练中断 断电后需从头再来 断点续训
部署迁移 8 卡训练 → 1 卡推理报错 统一权重前缀

2. 模型定义:三种姿势,按需选择

2.1 Sequential ------ 极简线性堆叠

python 复制代码
net = nn.Sequential(
    nn.Conv2d(3, 64, 3),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10)
)

适用:快速 PoC、网络无分支。

2.2 ModuleList / ModuleDict ------ 乐高式复用

python 复制代码
class TinyResNet(nn.Module):
    def __init__(self, n_blocks=4):
        super().__init__()
        self.blocks = nn.ModuleList([
            Bottleneck(64, 64) for _ in range(n_blocks)
        ])

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

适用:重复单元、需要动态深度。


3. 模型修改:三大高频需求一次讲透

torchvision.models.resnet50() 为例。

需求 关键 API / 技巧 代码片段
改输出类别 直接替换 fc net.fc = nn.Linear(2048, 10)
加额外输入 forward 里 torch.cat x = torch.cat([net(x), add_var.unsqueeze(1)], 1)
多输出/中间特征 修改 forward 的 return return out, feature

所有修改只需继承 nn.Module 并重写 __init__forward,无需动原始源码。


4. 模型保存与加载:单卡/多卡一次说清

4.1 存什么?

方式 命令 优缺点
仅权重 torch.save(model.state_dict(), path) 轻量、跨环境兼容
整个模型 torch.save(model, path) 含结构,但依赖原始类定义和 Python 版本

实战建议:99% 场景只存权重。

4.2 单卡 ↔ 多卡权重前缀问题

  • 多卡训练 会引入 "module." 前缀
  • 通用解法 :存权重时统一存 model.module.state_dict(),或加载时 strip 前缀:
python 复制代码
state = torch.load('multi_gpu.pth')
new_state = {k[7:]: v for k, v in state.items()}  # 去掉 'module.'
model.load_state_dict(new_state)

4.3 断点续训:把训练状态一起打包

python 复制代码
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'epoch': epoch,
    'best_acc': best_acc
}, 'checkpoint.pth')

# 恢复
ckpt = torch.load('checkpoint.pth')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch'] + 1

5. 一条完整的开发流水线示例

Sequential / ModuleList 改层 加输入 加输出 单卡 多卡 定义网络 训练 需求变更 局部替换 fc 重写 forward + cat return 多个值 保存权重 state_dict 部署环境 直接 load DataParallel + strip 前缀 继续训练 / 推理


6. 小结 & 行动清单

任务场景 立即能做的最小行动
快速搭 baseline nn.Sequential 10 行内出模型
迁移学习 把 ResNet50 的 fc 替换成你的类别数
断电续训练 把 optimizer & epoch 一起写进 checkpoint
8 卡训练 → 单卡推理 保存 model.module.state_dict()

记住一句话:权重是模型的灵魂,结构是容器;容器可以重建,灵魂必须妥善保存。


参考资料

《深入浅出PyTorch》第5章 5.1--5.4(DatawhaleChina 团队)

官方文档:torch.save / torch.load / nn.DataParallel

相关推荐
天上的光18 分钟前
17.迁移学习
人工智能·机器学习·迁移学习
合作小小程序员小小店21 分钟前
SDN安全开发环境中常见的框架,工具,第三方库,mininet常见指令介绍
python·安全·生成对抗网络·网络安全·网络攻击模型
后台开发者Ethan24 分钟前
Python需要了解的一些知识
开发语言·人工智能·python
北京_宏哥34 分钟前
Python零基础从入门到精通详细教程11 - python数据类型之数字(Number)-浮点型(float)详解
前端·python·面试
猫头虎44 分钟前
猫头虎AI分享|一款Coze、Dify类开源AI应用超级智能体快速构建工具:FastbuildAI
人工智能·开源·prompt·github·aigc·ai编程·ai-native
重启的码农1 小时前
ggml 介绍 (6) 后端 (ggml_backend)
c++·人工智能·神经网络
重启的码农1 小时前
ggml介绍 (7)后端缓冲区 (ggml_backend_buffer)
c++·人工智能·神经网络
数据智能老司机1 小时前
面向企业的图学习扩展——图简介
人工智能·机器学习·ai编程
盼小辉丶1 小时前
PyTorch生成式人工智能——使用MusicGen生成音乐
pytorch·python·深度学习·生成模型
mit6.8242 小时前
[AI React Web] 包与依赖管理 | `axios`库 | `framer-motion`库
前端·人工智能·react.js