PyTorch 模型开发全栈指南:从定义、修改到保存的完整闭环

如果你在 PyTorch 中只做「调包侠」,那么永远只是在外围打转;只有把「模型定义 → 修改 → 保存/加载」整条链路打通,才算真正拥有了炼丹炉的钥匙。

本文把官方教程 5.1--5.4 浓缩成一篇逻辑闭环的实战笔记,力求"看完即可落地"。


1. 为什么要有"模型工程化"思维?

阶段 痛点举例 本章解法
快速验证 一行行手写 100 层 CNN? Sequential / 模型块
需求变更 ResNet50 输出从 1000 → 10 类 局部层替换 / 外部输入输出
训练中断 断电后需从头再来 断点续训
部署迁移 8 卡训练 → 1 卡推理报错 统一权重前缀

2. 模型定义:三种姿势,按需选择

2.1 Sequential ------ 极简线性堆叠

python 复制代码
net = nn.Sequential(
    nn.Conv2d(3, 64, 3),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10)
)

适用:快速 PoC、网络无分支。

2.2 ModuleList / ModuleDict ------ 乐高式复用

python 复制代码
class TinyResNet(nn.Module):
    def __init__(self, n_blocks=4):
        super().__init__()
        self.blocks = nn.ModuleList([
            Bottleneck(64, 64) for _ in range(n_blocks)
        ])

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

适用:重复单元、需要动态深度。


3. 模型修改:三大高频需求一次讲透

torchvision.models.resnet50() 为例。

需求 关键 API / 技巧 代码片段
改输出类别 直接替换 fc net.fc = nn.Linear(2048, 10)
加额外输入 forward 里 torch.cat x = torch.cat([net(x), add_var.unsqueeze(1)], 1)
多输出/中间特征 修改 forward 的 return return out, feature

所有修改只需继承 nn.Module 并重写 __init__forward,无需动原始源码。


4. 模型保存与加载:单卡/多卡一次说清

4.1 存什么?

方式 命令 优缺点
仅权重 torch.save(model.state_dict(), path) 轻量、跨环境兼容
整个模型 torch.save(model, path) 含结构,但依赖原始类定义和 Python 版本

实战建议:99% 场景只存权重。

4.2 单卡 ↔ 多卡权重前缀问题

  • 多卡训练 会引入 "module." 前缀
  • 通用解法 :存权重时统一存 model.module.state_dict(),或加载时 strip 前缀:
python 复制代码
state = torch.load('multi_gpu.pth')
new_state = {k[7:]: v for k, v in state.items()}  # 去掉 'module.'
model.load_state_dict(new_state)

4.3 断点续训:把训练状态一起打包

python 复制代码
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'epoch': epoch,
    'best_acc': best_acc
}, 'checkpoint.pth')

# 恢复
ckpt = torch.load('checkpoint.pth')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch'] + 1

5. 一条完整的开发流水线示例

Sequential / ModuleList 改层 加输入 加输出 单卡 多卡 定义网络 训练 需求变更 局部替换 fc 重写 forward + cat return 多个值 保存权重 state_dict 部署环境 直接 load DataParallel + strip 前缀 继续训练 / 推理


6. 小结 & 行动清单

任务场景 立即能做的最小行动
快速搭 baseline nn.Sequential 10 行内出模型
迁移学习 把 ResNet50 的 fc 替换成你的类别数
断电续训练 把 optimizer & epoch 一起写进 checkpoint
8 卡训练 → 单卡推理 保存 model.module.state_dict()

记住一句话:权重是模型的灵魂,结构是容器;容器可以重建,灵魂必须妥善保存。


参考资料

《深入浅出PyTorch》第5章 5.1--5.4(DatawhaleChina 团队)

官方文档:torch.save / torch.load / nn.DataParallel

相关推荐
小oo呆19 小时前
【自然语言处理与大模型】主题建模 Topic Modeling
人工智能·自然语言处理
ycydynq19 小时前
python html 解析的一些写法
linux·python·html
KKKlucifer19 小时前
从被动合规到主动免疫:AI 破解数据智能安全的四大核心场景
人工智能·安全
权泽谦19 小时前
脑肿瘤分割与分类的人工智能研究报告
人工智能·分类·数据挖掘
余俊晖19 小时前
文档图像旋转对VLM OCR的影响及基于Phi-3.5-Vision+分类头的文档方向分类器、及数据构建思路
人工智能·分类·ocr
Cleaner19 小时前
我是如何高效学习大模型的
人工智能·程序员·llm
西猫雷婶19 小时前
CNN的四维Pytorch张量格式
人工智能·pytorch·python·深度学习·神经网络·机器学习·cnn
未来之窗软件服务19 小时前
幽冥大陆(二十三)python语言智慧农业电子秤读取——东方仙盟炼气期
开发语言·python·仙盟创梦ide·东方仙盟·东方仙盟sdk·东方仙盟浏览器
程序员三藏19 小时前
Web自动化测试详细流程和步骤
自动化测试·软件测试·python·selenium·测试工具·职场和发展·测试用例
化作星辰19 小时前
解决 OpenCV imread 在 Windows 中读取包含中文路径图片失败的问题
人工智能·opencv·计算机视觉