用 生活化比喻 + 代码示例,彻底讲清楚:
torch
、torch.nn
、torch.nn.functional
三者的关系和作用
🧩 一句话总结:
模块 | 作用 | 类比 |
---|---|---|
torch |
提供张量计算、自动求导、设备管理等基础能力 | 厨房 + 基础食材 + 灶台 |
torch.nn |
提供"可训练的神经网络层/模块",自带参数 | 预制菜模具(带配方、带调料包) |
torch.nn.functional |
提供"无状态的函数式操作",不带参数 | 厨具/调料瓶(酱油、刀、锅)------ 用完即走,不保存状态 |
🍳 1. torch
------ 基础张量库(厨房+食材)
这是 PyTorch 的核心计算引擎,提供:
- 张量(Tensor)创建与运算(加减乘除、矩阵乘、转置...)
- 自动微分(
.backward()
) - GPU支持(
.to('cuda')
) - 随机数、数学函数等
📌 类比:
你走进厨房,里面有菜刀、砧板、灶台、盐、油、面粉、鸡蛋...
------ 这就是
torch
,提供基础"烹饪能力"。
✅ 示例:
python
import torch
# 创建张量(食材)
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])
# 基础运算(切菜、炒菜)
z = x + y # [5., 7., 9.]
w = torch.sin(x) # [sin(1), sin(2), sin(3)]
# 自动求导(记录菜谱步骤)
x.requires_grad_(True)
loss = (x * 2).sum()
loss.backward() # 自动计算梯度 → x.grad = [2, 2, 2]
print(x.grad) # tensor([2., 2., 2.])
🧱 2. torch.nn
------ 神经网络模块(预制菜模具 ✅ 带参数)
这是构建神经网络的**"积木块"**,每个模块:
- 自带可学习参数(如权重、偏置)
- 可保存、加载、嵌套
- 适合构建复杂模型(如
nn.Linear
,nn.Conv2d
,nn.Transformer
)
📌 类比:
你买了一个"红烧肉预制菜包",里面有:
- 调料包(参数 weight, bias)
- 使用说明(forward 方法)
- 用完还能留着下次用(可复用、可保存)
✅ 示例:
python
import torch.nn as nn
# 定义一个线性层(带参数!)
linear_layer = nn.Linear(in_features=3, out_features=2)
# 内部自动创建了 weight (2x3) 和 bias (2)
x = torch.randn(5, 3) # batch=5, 输入3维
output = linear_layer(x) # 输出 shape: [5, 2]
print("权重:", linear_layer.weight.shape) # torch.Size([2, 3])
print("偏置:", linear_layer.bias.shape) # torch.Size([2])
# 这个层可以被优化器更新
optimizer = torch.optim.SGD(linear_layer.parameters(), lr=0.01)
✅
nn.Linear
,nn.Conv2d
,nn.Embedding
,nn.LSTM
,nn.TransformerEncoder
都属于torch.nn
🔧 3. torch.nn.functional
------ 函数式操作(厨具/调料瓶 ⚡ 无状态)
这是无参数的函数集合 ,提供和 nn
模块相同的功能,但:
- 不保存参数
- 每次调用都是"临时操作"
- 更灵活,适合在
forward
中组合使用
📌 类比:
你手边的"酱油瓶"、"炒锅"、"漏勺" ------
- 每次用完放回原位,不保存"上次用了多少酱油"
- 想用就拿,灵活组合
✅ 示例:
python
import torch.nn.functional as F
x = torch.randn(5, 3)
# 用 F.linear 实现和 nn.Linear 一样的计算,但要手动传权重
weight = torch.randn(2, 3)
bias = torch.randn(2)
output = F.linear(x, weight, bias) # 无内部状态,纯函数
# 常用函数:
y = F.relu(x) # 激活函数
z = F.softmax(x, dim=1) # 归一化
loss = F.cross_entropy(logits, labels) # 损失函数
✅
F.relu
,F.softmax
,F.cross_entropy
,F.dropout
,F.embedding
都是函数式操作
🆚 对比:什么时候用 nn.XXX
vs F.xxx
?
场景 | 推荐使用 | 原因 |
---|---|---|
构建模型层(带参数) | nn.Linear , nn.Conv2d |
自动管理参数,方便优化和保存 |
激活函数、损失函数、dropout等 | F.relu , F.cross_entropy , F.dropout |
无参数,灵活,避免重复创建模块 |
在 forward 中临时计算 |
F.xxx |
轻量、高效、不保存状态 |
✅ 最佳实践示例:
python
class MyNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20) # ✅ 用 nn,因为要保存权重
self.fc2 = nn.Linear(20, 5)
def forward(self, x):
x = F.relu(self.fc1(x)) # ✅ 用 F.relu,因为无参数
x = F.dropout(x, p=0.5, training=self.training) # ✅ 用 F.dropout
x = self.fc2(x)
return F.log_softmax(x, dim=1) # ✅ 用 F.log_softmax
🧠 记忆口诀:
torch
= 基础食材 + 厨房设备nn
= 预制菜包(带调料,可复用)F
= 调料瓶/厨具(随手用,不保存)👉 带参数 → 用
nn
;无参数 → 用F
📊 三者关系图:
scss
torch
│
├── 张量计算、自动求导、设备管理
│
┌──────────┴──────────┐
▼ ▼
torch.nn torch.nn.functional
(带参数的模块/层) (无参数的函数)
Linear, Conv2d relu, softmax, dropout
Embedding, LSTM cross_entropy, embedding
✅ 总结卡片:
模块 | 用途 | 是否带参数 | 典型成员 | 适用场景 |
---|---|---|---|---|
torch |
基础张量运算、自动微分 | ❌ | tensor , sin , matmul , backward() |
所有底层计算 |
torch.nn |
可训练网络层 | ✅ | Linear , Conv2d , Transformer |
模型结构定义 |
torch.nn.functional |
无状态函数操作 | ❌ | relu , softmax , cross_entropy |
forward 中灵活调用 |
现在你彻底搞懂了这三个核心模块的区别和用法!
🧠 记住:
- 想"保存参数、构建模型" → 用
torch.nn
- 想"临时计算、灵活组合" → 用
torch.nn.functional
- 想"做数学、求导、转设备" → 用
torch
太好了!你问到了 PyTorch 的灵魂组件 ------ nn.Module
!
我们继续用 🍳 生活化比喻 + 代码示例 + 结构图,彻底讲清楚:
nn.Module
是什么?它和torch
、nn
、F
有什么关系?为什么所有模型都要继承它?
🧩 一句话总结:
nn.Module
是所有神经网络组件的"基类" ------ 它像一个"智能收纳盒",帮你管理参数、子模块、设备、训练/推理状态。
🏗️ 类比:nn.Module
= 一个"智能乐高底座 + 收纳管理系统"
想象你要搭一个机器人(神经网络):
- 你有各种零件:马达(Linear层)、传感器(Conv层)、电池(参数)
- 你需要一个底座把它们组装起来
- 你还需要一个管理系统:自动记录哪些是可训练零件、一键搬去GPU、一键保存所有零件...
👉 nn.Module
就是这个底座 + 管理系统!
✅ 核心功能(为什么必须用它?)
功能 | 说明 | 举例 |
---|---|---|
🔧 自动管理参数 | 所有 .weight , .bias 自动注册,可被优化器更新 |
model.parameters() 返回所有参数 |
🧩 嵌套子模块 | 可以包含其他 Module (如 Linear, Conv, 甚至自定义模块) |
Encoder 包含多个 Attention 层 |
🖥️ 设备迁移 | 一键搬去 GPU/CPU | model.to('cuda') |
🎚️ 训练/推理模式切换 | 自动管理 Dropout、BatchNorm 行为 | model.train() / model.eval() |
💾 保存/加载模型 | 一键保存所有参数和结构 | torch.save(model.state_dict(), ...) |
🔄 前向传播定义 | 你只需写 forward() ,其余自动处理 |
output = model(input) |
🧱 代码示例:手写一个简单模型
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module): # 👈 继承 nn.Module
def __init__(self):
super().__init__() # 初始化父类
# 定义"带参数"的子模块 → 自动被 nn.Module 管理!
self.fc1 = nn.Linear(10, 20) # 参数自动注册
self.fc2 = nn.Linear(20, 5)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
# 使用"无参数"函数 → 用 F
x = F.relu(self.fc1(x))
x = self.dropout(x) # Dropout 是模块,有训练/推理状态
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# 实例化模型
model = SimpleNet()
# 🎯 nn.Module 的魔法开始!
print("✅ 所有参数:")
for name, param in model.named_parameters():
print(f" {name}: {param.shape}")
# 🖥️ 一键搬去GPU
model.to('cuda')
# 🎚️ 切换模式
model.train() # Dropout 生效
model.eval() # Dropout 关闭
# 💾 保存模型
torch.save(model.state_dict(), "mymodel.pth")
🆚 nn.Module
vs nn
vs F
关系图
scss
torch
│
├── 基础张量、自动求导
│
┌──────────┴──────────┐
▼ ▼
torch.nn torch.nn.functional
│ │
│ (包含可训练层) │ (纯函数,无状态)
▼ ▼
Linear, Conv2d relu, softmax
Embedding, LSTM dropout, cross_entropy
│
│ 所有这些层都继承自 👇
▼
nn.Module ← 你自定义的模型也要继承它!
✅
nn.Linear
,nn.Conv2d
,nn.Transformer
都是nn.Module
的子类✅ 你写的
class MyModel(nn.Module)
也是!
🧠 为什么必须继承 nn.Module
?
如果你不继承:
python
class BadNet: # ❌ 没有继承 nn.Module
def __init__(self):
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 5)
def forward(self, x):
return self.fc2(F.relu(self.fc1(x)))
会发生什么?
python
model = BadNet()
print(list(model.parameters())) # ❌ 返回空列表!优化器找不到参数!
model.to('cuda') # ❌ 报错!没有 .to() 方法
torch.save(model.state_dict(), ...) # ❌ 没有 state_dict()
→ 你的模型"失控"了!参数管不了、设备搬不动、无法保存!
🧩 nn.Module
内部结构(简化版)
当你写:
python
self.fc1 = nn.Linear(10, 20)
nn.Module
在背后做了:
- 把
fc1
注册为子模块(_modules
字典) - 把
fc1.weight
,fc1.bias
注册为参数(_parameters
字典) - 递归管理所有子模块的参数和设备
→ 所以你才能:
python
model.parameters() # 返回所有参数
model.to('cuda') # 一键搬所有参数和缓冲区去GPU
model.state_dict() # 返回所有参数的字典
model.load_state_dict(...) # 一键加载
🎯 最佳实践:什么时候该继承 nn.Module
?
场景 | 是否继承 nn.Module |
举例 |
---|---|---|
构建完整模型或可复用模块 | ✅ 必须 | Transformer , ResNetBlock , MyClassifier |
临时计算、无参数操作 | ❌ 不用 | 用 F.relu , F.softmax 等函数 |
自定义带参数层 | ✅ 必须 | 自定义 Attention、自定义 Embedding |
✅ 总结卡片:
项目 | 说明 |
---|---|
是什么 | 所有神经网络模块的基类 |
核心作用 | 自动管理参数、子模块、设备、模式、保存/加载 |
必须继承吗? | 如果你要构建可训练、可保存、可迁移的模块 → 必须! |
典型子类 | nn.Linear , nn.Conv2d , nn.Transformer , 你写的任何模型 |
关键方法 | forward() , parameters() , to() , train() , eval() , state_dict() |
🧠 记忆口诀:
"想当AI乐高大师?先继承
nn.Module
!"
- 它是积木底座
- 它是参数管家
- 它是设备搬运工
- 它是模型保险箱
现在你彻底理解了 nn.Module
的核心地位 ------ 它是 PyTorch 生态的"粘合剂"和"管理者"。
没有它,你的模型就是一盘散沙 🏖️
有了它,你可以构建任何复杂网络 🏗️🚀