PyTorch 系列 之 nn.Module:所有模型的骨架

开篇:Module 不是"普通父类"

写 PyTorch 模型,最终都会回到 nn.Module。

它不是一个空壳父类。它是模型的总管:管参数、管子模块、管状态、管调用、管保存加载。

你写的 forward 只是计算逻辑。真正让模型"像模型一样工作"的,是 nn.Module 背后的注册和调用机制。

把 Module 放到 PyTorch 体系里看

1. 官方定位:Module 是神经网络模块的基类

PyTorch 官方文档说得很直接:torch.nn.Module 是所有神经网络模块的基类,自己的模型也应该继承它。

更关键的是:Module 可以包含其他 Module。也就是说,模型天然是一棵树。

这棵树的价值很大。只要树建好了,PyTorch 就能递归找参数、迁移设备、切换模式、保存状态。

模型是 Module 组件树

2. 写 Module,只记住两个函数

init:放结构,放状态,放需要被 PyTorch 管理的对象。

forward:放计算,描述输入如何一步步变成输出。

简单说:init 管"有什么",forward 管"怎么走"。

init 与 forward 的分工

复制代码
import torch
import torch.nn as nn
class TinyNet(nn.Module):
 def __init__(self):
 super().__init__()
 self.fc1 = nn.Linear(10, 32)
 self.act = nn.ReLU()
 self.fc2 = nn.Linear(32, 2)
 def forward(self, x):
 x = self.fc1(x)
 x = self.act(x)
 return self.fc2(x)
model = TinyNet()
y = model(torch.randn(4, 10))

这段代码很短,但已经触发了 Module 的核心能力:fc1、act、fc2 都被注册成子模块;Linear 里的 weight 和 bias 会被递归找到。

3. 自动注册:不是魔法,是 setattr

为什么 self.fc = nn.Linear(...) 之后,model.parameters() 能找到里面的权重?

因为 nn.Module 重写了 setattr。你给属性赋值时,它会看 value 的类型。

如果是 nn.Parameter,放进 _parameters;如果是 nn.Module,放进 _modules;如果是 Buffer 或 register_buffer,放进 _buffers;其他对象才只是普通属性。

setattr 如何把对象分流到不同注册表

4. 四类成员:命运完全不同

很多 PyTorch Bug,本质是对象放错了地方。

想让优化器更新,必须是 Parameter。想随模型保存但不被优化,应该是 Buffer。想嵌套层,必须是 Module 或容器 Module。

Parameter、Buffer、Module、普通属性对比

5. 源码入口一:Module.init 先准备内部字典

源码里,Module.init 会先创建几个关键容器:_parameters、_buffers、_modules,以及 hooks 相关字典。

这就是为什么自定义模型第一行通常要写 super().init()。没有这一步,后续注册没有地方可放。

复制代码
# torch/nn/modules/module.py 的核心思路
super().__setattr__("training", True)
super().__setattr__("_parameters", {})
super().__setattr__("_buffers", {})
super().__setattr__("_modules", {})

注意:这里源码用的是 super().setattr,不是 self.xxx = xxx。这样可以绕开 Module.setattr 的注册逻辑,避免初始化时自己拦截自己。

6. 源码入口二:model(x) 会走 call

官方文档提醒:虽然计算配方写在 forward 里,但应该调用 Module 实例本身,而不是直接调用 forward。

原因很简单:model(x) 会进入 call 和 _call_impl,那里会处理 forward hooks、pre hooks、编译调用等逻辑。直接 model.forward(x) 会绕开这些能力。

model(x) 的真实调用路径

这条链路也解释了很多高级能力:特征提取可以靠 hook,调试可以靠 hook,torch.compile 也可以挂在 Module 调用入口上。

7. parameters():优化器为什么能拿到所有权重

parameters() 返回的是 Module 参数迭代器,通常直接交给优化器。

它默认 recurse=True,会递归进入所有子模块。所以你只需要写 optimizer = torch.optim.Adam(model.parameters()),不需要手动收集每一层权重。

复制代码
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for name, p in model.named_parameters():
 print(name, p.shape)

关键点:被注册的参数才会出现。普通 Python list 里的层不会自动注册。需要用 nn.ModuleList 或 nn.Sequential。

8. Buffer:不学习,但要跟着模型走

Buffer 是模型状态,但不是可学习参数。

典型例子是 BatchNorm 的 running_mean。它会影响计算,也需要保存和加载,但不应该交给优化器更新。

复制代码
class RunningStat(nn.Module):
 def __init__(self):
 super().__init__()
 self.register_buffer("running_mean", torch.zeros(10))
 def forward(self, x):
 return x - self.running_mean

persistent=True 的 Buffer 会进 state_dict。persistent=False 的 Buffer 会跟着 to()/cuda() 迁移,但不会被保存。

9. state_dict:模型保存的核心

PyTorch 推荐保存 state_dict,而不是把整个模型对象直接打包。

state_dict 里包含参数和持久 Buffer。key 是名字,value 是 Tensor。名字来自 Module 树的路径。

state_dict 如何收集模型状态

复制代码
torch.save(model.state_dict(), "model.pt")
new_model = TinyNet()
new_model.load_state_dict(torch.load("model.pt"))

理解 state_dict,后面学断点续训、迁移学习、分布式保存、部署导出都会轻松很多。

10. train() 和 eval():只切模式,不跑训练

model.train() 不会自动开始训练。model.eval() 也不会自动开始推理。

它们只是递归设置每个 Module 的 training 标志。某些层会读这个标志,比如 Dropout 和 BatchNorm。

train/eval 的真实作用

一个常见误解:eval() 不等于 no_grad()。eval() 只影响层行为;no_grad() 才影响 Autograd 是否记录梯度。

11. to():为什么整棵模型能搬到 GPU

model.to(device) 能生效,是因为 Module 会递归处理自己的参数、Buffer 和子模块。

普通属性不会自动被迁移。普通 Tensor 如果没注册成 Buffer,也不会被 Module 管。

复制代码
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyNet().to(device)
x = torch.randn(4, 10, device=device)
y = model(x)

记住:需要跟着模型走的 Tensor,要么是 Parameter,要么注册成 Buffer。

12. 源码级主线:六个动作串起来

读 Module 源码,不要从头到尾硬啃。按动作读:初始化、注册、调用、遍历、状态、模式/迁移。

最重要的三个问题:

对象赋值后,去了 _parameters、_buffers、_modules,还是普通 dict

模型调用时,是走 model(x),还是绕过了 call

保存加载时,这个对象会不会出现在 state_dict 里?

13. 常见坑:先看有没有被注册

当参数没有更新、模型保存后丢东西、to(cuda) 后设备不一致,先不要怀疑玄学。

先检查注册。

复制代码
print(model)
print(dict(model.named_parameters()).keys())
print(dict(model.named_buffers()).keys())
print(model.state_dict().keys())

总结

nn.Module 是 PyTorch 模型的骨架。

它把散落的层、参数、状态组织成一棵可训练、可迁移、可保存、可调试的树。

掌握 Module,后面讲 Linear、Conv、Transformer、训练循环、模型保存、分布式训练,都会有同一套底层语言。


内容来源:PyTorch 系列 之 nn.Module:所有模型的骨架:功能变化与行业影响解析_热闻岛

相关推荐
AI玫瑰助手1 小时前
Python模块:from...import...导入指定内容
开发语言·python·信息可视化
海天一色y1 小时前
深入理解 Function Calling、MCP 与 Skills:AI Agent 的三层能力架构
人工智能·mcp·skills
小星AI1 小时前
FastMCP 2.0 实战:10 分钟给 Claude Code 装上手
人工智能·agent
小森林之主1 小时前
Python re 模块速查:从实战对比中掌握正则表达式
python·正则表达式·性能测试·re模块·编程实战
昨日之日20061 小时前
Higgs Audio v3 - 超自然多语言情感TTS,一键克隆声音 一键整合包下载
人工智能·音视频
极客老王说Agent1 小时前
2026全业务链条断层破解:智能体如何重构端到端业务闭环
人工智能·ai·chatgpt·重构
云烟成雨TD1 小时前
Spring AI 1.x 系列【61】Spring AI 2.0 升级指南
java·人工智能·spring
Luhui Dev1 小时前
几何图,现在可以用 API 一句话生成
人工智能·数学·luhuidev
郭wes代码1 小时前
Win10 拒绝访问、长期关机自动维护与声音图标灰色故障解决记录
windows·python·开源