PyTorch 模型开发全栈指南:从定义、修改到保存的完整闭环

如果你在 PyTorch 中只做「调包侠」,那么永远只是在外围打转;只有把「模型定义 → 修改 → 保存/加载」整条链路打通,才算真正拥有了炼丹炉的钥匙。

本文把官方教程 5.1--5.4 浓缩成一篇逻辑闭环的实战笔记,力求"看完即可落地"。


1. 为什么要有"模型工程化"思维?

阶段 痛点举例 本章解法
快速验证 一行行手写 100 层 CNN? Sequential / 模型块
需求变更 ResNet50 输出从 1000 → 10 类 局部层替换 / 外部输入输出
训练中断 断电后需从头再来 断点续训
部署迁移 8 卡训练 → 1 卡推理报错 统一权重前缀

2. 模型定义:三种姿势,按需选择

2.1 Sequential ------ 极简线性堆叠

python 复制代码
net = nn.Sequential(
    nn.Conv2d(3, 64, 3),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10)
)

适用:快速 PoC、网络无分支。

2.2 ModuleList / ModuleDict ------ 乐高式复用

python 复制代码
class TinyResNet(nn.Module):
    def __init__(self, n_blocks=4):
        super().__init__()
        self.blocks = nn.ModuleList([
            Bottleneck(64, 64) for _ in range(n_blocks)
        ])

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

适用:重复单元、需要动态深度。


3. 模型修改:三大高频需求一次讲透

torchvision.models.resnet50() 为例。

需求 关键 API / 技巧 代码片段
改输出类别 直接替换 fc net.fc = nn.Linear(2048, 10)
加额外输入 forward 里 torch.cat x = torch.cat([net(x), add_var.unsqueeze(1)], 1)
多输出/中间特征 修改 forward 的 return return out, feature

所有修改只需继承 nn.Module 并重写 __init__forward,无需动原始源码。


4. 模型保存与加载:单卡/多卡一次说清

4.1 存什么?

方式 命令 优缺点
仅权重 torch.save(model.state_dict(), path) 轻量、跨环境兼容
整个模型 torch.save(model, path) 含结构,但依赖原始类定义和 Python 版本

实战建议:99% 场景只存权重。

4.2 单卡 ↔ 多卡权重前缀问题

  • 多卡训练 会引入 "module." 前缀
  • 通用解法 :存权重时统一存 model.module.state_dict(),或加载时 strip 前缀:
python 复制代码
state = torch.load('multi_gpu.pth')
new_state = {k[7:]: v for k, v in state.items()}  # 去掉 'module.'
model.load_state_dict(new_state)

4.3 断点续训:把训练状态一起打包

python 复制代码
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'epoch': epoch,
    'best_acc': best_acc
}, 'checkpoint.pth')

# 恢复
ckpt = torch.load('checkpoint.pth')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch'] + 1

5. 一条完整的开发流水线示例

Sequential / ModuleList 改层 加输入 加输出 单卡 多卡 定义网络 训练 需求变更 局部替换 fc 重写 forward + cat return 多个值 保存权重 state_dict 部署环境 直接 load DataParallel + strip 前缀 继续训练 / 推理


6. 小结 & 行动清单

任务场景 立即能做的最小行动
快速搭 baseline nn.Sequential 10 行内出模型
迁移学习 把 ResNet50 的 fc 替换成你的类别数
断电续训练 把 optimizer & epoch 一起写进 checkpoint
8 卡训练 → 单卡推理 保存 model.module.state_dict()

记住一句话:权重是模型的灵魂,结构是容器;容器可以重建,灵魂必须妥善保存。


参考资料

《深入浅出PyTorch》第5章 5.1--5.4(DatawhaleChina 团队)

官方文档:torch.save / torch.load / nn.DataParallel

相关推荐
梨落秋霜1 小时前
Python入门篇【文件处理】
android·java·python
kisshuan123961 小时前
【深度学习】使用RetinaNet+X101-32x4d_FPN_GHM模型实现茶芽检测与识别_1
人工智能·深度学习
Java 码农1 小时前
RabbitMQ集群部署方案及配置指南03
java·python·rabbitmq
Learn Beyond Limits1 小时前
解构语义:从词向量到神经分类|Decoding Semantics: Word Vectors and Neural Classification
人工智能·算法·机器学习·ai·分类·数据挖掘·nlp
崔庆才丨静觅2 小时前
0代码生成4K高清图!ACE Data Platform × SeeDream 专属方案:小白/商家闭眼冲
人工智能·api
qq_356448372 小时前
机器学习基本概念与梯度下降
人工智能
张登杰踩3 小时前
VIA标注格式转Labelme标注格式
python
水如烟3 小时前
孤能子视角:关系性学习,“喂饭“的小孩认知
人工智能
徐_长卿3 小时前
2025保姆级微信AI群聊机器人教程:教你如何本地打造私人和群聊机器人
人工智能·机器人
XyX——3 小时前
【福利教程】一键解锁 ChatGPT / Gemini / Spotify 教育权益!TG 机器人全自动验证攻略
人工智能·chatgpt·机器人