一文讲清 torch、torch.nn、torch.nn.functional 及 nn.Module

生活化比喻 + 代码示例,彻底讲清楚:

torchtorch.nntorch.nn.functional 三者的关系和作用


🧩 一句话总结:

模块 作用 类比
torch 提供张量计算、自动求导、设备管理等基础能力 厨房 + 基础食材 + 灶台
torch.nn 提供"可训练的神经网络层/模块",自带参数 预制菜模具(带配方、带调料包)
torch.nn.functional 提供"无状态的函数式操作",不带参数 厨具/调料瓶(酱油、刀、锅)------ 用完即走,不保存状态

🍳 1. torch ------ 基础张量库(厨房+食材)

这是 PyTorch 的核心计算引擎,提供:

  • 张量(Tensor)创建与运算(加减乘除、矩阵乘、转置...)
  • 自动微分(.backward()
  • GPU支持(.to('cuda')
  • 随机数、数学函数等

📌 类比:

你走进厨房,里面有菜刀、砧板、灶台、盐、油、面粉、鸡蛋...

------ 这就是 torch,提供基础"烹饪能力"。

✅ 示例:

python 复制代码
import torch

# 创建张量(食材)
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])

# 基础运算(切菜、炒菜)
z = x + y          # [5., 7., 9.]
w = torch.sin(x)   # [sin(1), sin(2), sin(3)]

# 自动求导(记录菜谱步骤)
x.requires_grad_(True)
loss = (x * 2).sum()
loss.backward()    # 自动计算梯度 → x.grad = [2, 2, 2]

print(x.grad)      # tensor([2., 2., 2.])

🧱 2. torch.nn ------ 神经网络模块(预制菜模具 ✅ 带参数)

这是构建神经网络的**"积木块"**,每个模块:

  • 自带可学习参数(如权重、偏置)
  • 可保存、加载、嵌套
  • 适合构建复杂模型(如 nn.Linear, nn.Conv2d, nn.Transformer

📌 类比:

你买了一个"红烧肉预制菜包",里面有:

  • 调料包(参数 weight, bias)
  • 使用说明(forward 方法)
  • 用完还能留着下次用(可复用、可保存)

✅ 示例:

python 复制代码
import torch.nn as nn

# 定义一个线性层(带参数!)
linear_layer = nn.Linear(in_features=3, out_features=2)
# 内部自动创建了 weight (2x3) 和 bias (2)

x = torch.randn(5, 3)  # batch=5, 输入3维
output = linear_layer(x)  # 输出 shape: [5, 2]

print("权重:", linear_layer.weight.shape)  # torch.Size([2, 3])
print("偏置:", linear_layer.bias.shape)    # torch.Size([2])

# 这个层可以被优化器更新
optimizer = torch.optim.SGD(linear_layer.parameters(), lr=0.01)

nn.Linear, nn.Conv2d, nn.Embedding, nn.LSTM, nn.TransformerEncoder 都属于 torch.nn


🔧 3. torch.nn.functional ------ 函数式操作(厨具/调料瓶 ⚡ 无状态)

这是无参数的函数集合 ,提供和 nn 模块相同的功能,但:

  • 不保存参数
  • 每次调用都是"临时操作"
  • 更灵活,适合在 forward 中组合使用

📌 类比:

你手边的"酱油瓶"、"炒锅"、"漏勺" ------

  • 每次用完放回原位,不保存"上次用了多少酱油"
  • 想用就拿,灵活组合

✅ 示例:

python 复制代码
import torch.nn.functional as F

x = torch.randn(5, 3)

# 用 F.linear 实现和 nn.Linear 一样的计算,但要手动传权重
weight = torch.randn(2, 3)
bias = torch.randn(2)
output = F.linear(x, weight, bias)  # 无内部状态,纯函数

# 常用函数:
y = F.relu(x)           # 激活函数
z = F.softmax(x, dim=1) # 归一化
loss = F.cross_entropy(logits, labels)  # 损失函数

F.relu, F.softmax, F.cross_entropy, F.dropout, F.embedding 都是函数式操作


🆚 对比:什么时候用 nn.XXX vs F.xxx

场景 推荐使用 原因
构建模型层(带参数) nn.Linear, nn.Conv2d 自动管理参数,方便优化和保存
激活函数、损失函数、dropout等 F.relu, F.cross_entropy, F.dropout 无参数,灵活,避免重复创建模块
forward 中临时计算 F.xxx 轻量、高效、不保存状态

✅ 最佳实践示例:

python 复制代码
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)   # ✅ 用 nn,因为要保存权重
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        x = F.relu(self.fc1(x))        # ✅ 用 F.relu,因为无参数
        x = F.dropout(x, p=0.5, training=self.training)  # ✅ 用 F.dropout
        x = self.fc2(x)
        return F.log_softmax(x, dim=1) # ✅ 用 F.log_softmax

🧠 记忆口诀:

  • torch = 基础食材 + 厨房设备
  • nn = 预制菜包(带调料,可复用)
  • F = 调料瓶/厨具(随手用,不保存)

👉 带参数 → 用 nn;无参数 → 用 F


📊 三者关系图:

scss 复制代码
                     torch
                      │
                      ├── 张量计算、自动求导、设备管理
                      │
           ┌──────────┴──────────┐
           ▼                     ▼
      torch.nn           torch.nn.functional
  (带参数的模块/层)        (无参数的函数)
   Linear, Conv2d          relu, softmax, dropout
   Embedding, LSTM         cross_entropy, embedding

✅ 总结卡片:

模块 用途 是否带参数 典型成员 适用场景
torch 基础张量运算、自动微分 tensor, sin, matmul, backward() 所有底层计算
torch.nn 可训练网络层 Linear, Conv2d, Transformer 模型结构定义
torch.nn.functional 无状态函数操作 relu, softmax, cross_entropy forward 中灵活调用

现在你彻底搞懂了这三个核心模块的区别和用法!

🧠 记住:

  • 想"保存参数、构建模型" → 用 torch.nn
  • 想"临时计算、灵活组合" → 用 torch.nn.functional
  • 想"做数学、求导、转设备" → 用 torch

太好了!你问到了 PyTorch 的灵魂组件 ------ nn.Module

我们继续用 🍳 生活化比喻 + 代码示例 + 结构图,彻底讲清楚:

nn.Module 是什么?它和 torchnnF 有什么关系?为什么所有模型都要继承它?


🧩 一句话总结:

nn.Module 是所有神经网络组件的"基类" ------ 它像一个"智能收纳盒",帮你管理参数、子模块、设备、训练/推理状态。


🏗️ 类比:nn.Module = 一个"智能乐高底座 + 收纳管理系统"

想象你要搭一个机器人(神经网络):

  • 你有各种零件:马达(Linear层)、传感器(Conv层)、电池(参数)
  • 你需要一个底座把它们组装起来
  • 你还需要一个管理系统:自动记录哪些是可训练零件、一键搬去GPU、一键保存所有零件...

👉 nn.Module 就是这个底座 + 管理系统!


✅ 核心功能(为什么必须用它?)

功能 说明 举例
🔧 自动管理参数 所有 .weight, .bias 自动注册,可被优化器更新 model.parameters() 返回所有参数
🧩 嵌套子模块 可以包含其他 Module(如 Linear, Conv, 甚至自定义模块) Encoder 包含多个 Attention 层
🖥️ 设备迁移 一键搬去 GPU/CPU model.to('cuda')
🎚️ 训练/推理模式切换 自动管理 Dropout、BatchNorm 行为 model.train() / model.eval()
💾 保存/加载模型 一键保存所有参数和结构 torch.save(model.state_dict(), ...)
🔄 前向传播定义 你只需写 forward(),其余自动处理 output = model(input)

🧱 代码示例:手写一个简单模型

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):  # 👈 继承 nn.Module
    def __init__(self):
        super().__init__()  # 初始化父类
        # 定义"带参数"的子模块 → 自动被 nn.Module 管理!
        self.fc1 = nn.Linear(10, 20)   # 参数自动注册
        self.fc2 = nn.Linear(20, 5)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # 使用"无参数"函数 → 用 F
        x = F.relu(self.fc1(x))
        x = self.dropout(x)            # Dropout 是模块,有训练/推理状态
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# 实例化模型
model = SimpleNet()

# 🎯 nn.Module 的魔法开始!
print("✅ 所有参数:")
for name, param in model.named_parameters():
    print(f"  {name}: {param.shape}")

# 🖥️ 一键搬去GPU
model.to('cuda')

# 🎚️ 切换模式
model.train()   # Dropout 生效
model.eval()    # Dropout 关闭

# 💾 保存模型
torch.save(model.state_dict(), "mymodel.pth")

🆚 nn.Module vs nn vs F 关系图

scss 复制代码
                     torch
                      │
                      ├── 基础张量、自动求导
                      │
           ┌──────────┴──────────┐
           ▼                     ▼
      torch.nn           torch.nn.functional
         │                       │
         │ (包含可训练层)         │ (纯函数,无状态)
         ▼                       ▼
    Linear, Conv2d           relu, softmax
    Embedding, LSTM          dropout, cross_entropy
         │
         │ 所有这些层都继承自 👇
         ▼
    nn.Module ← 你自定义的模型也要继承它!

nn.Linear, nn.Conv2d, nn.Transformer 都是 nn.Module 的子类

✅ 你写的 class MyModel(nn.Module) 也是!


🧠 为什么必须继承 nn.Module

如果你不继承:

python 复制代码
class BadNet:  # ❌ 没有继承 nn.Module
    def __init__(self):
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

会发生什么?

python 复制代码
model = BadNet()
print(list(model.parameters()))  # ❌ 返回空列表!优化器找不到参数!
model.to('cuda')                 # ❌ 报错!没有 .to() 方法
torch.save(model.state_dict(), ...)  # ❌ 没有 state_dict()

你的模型"失控"了!参数管不了、设备搬不动、无法保存!


🧩 nn.Module 内部结构(简化版)

当你写:

python 复制代码
self.fc1 = nn.Linear(10, 20)

nn.Module 在背后做了:

  1. fc1 注册为子模块(_modules 字典)
  2. fc1.weight, fc1.bias 注册为参数(_parameters 字典)
  3. 递归管理所有子模块的参数和设备

→ 所以你才能:

python 复制代码
model.parameters()        # 返回所有参数
model.to('cuda')          # 一键搬所有参数和缓冲区去GPU
model.state_dict()        # 返回所有参数的字典
model.load_state_dict(...) # 一键加载

🎯 最佳实践:什么时候该继承 nn.Module

场景 是否继承 nn.Module 举例
构建完整模型或可复用模块 ✅ 必须 Transformer, ResNetBlock, MyClassifier
临时计算、无参数操作 ❌ 不用 F.relu, F.softmax 等函数
自定义带参数层 ✅ 必须 自定义 Attention、自定义 Embedding

✅ 总结卡片:

项目 说明
是什么 所有神经网络模块的基类
核心作用 自动管理参数、子模块、设备、模式、保存/加载
必须继承吗? 如果你要构建可训练、可保存、可迁移的模块 → 必须!
典型子类 nn.Linear, nn.Conv2d, nn.Transformer, 你写的任何模型
关键方法 forward(), parameters(), to(), train(), eval(), state_dict()

🧠 记忆口诀:

"想当AI乐高大师?先继承 nn.Module!"

  • 它是积木底座
  • 它是参数管家
  • 它是设备搬运工
  • 它是模型保险箱

现在你彻底理解了 nn.Module 的核心地位 ------ 它是 PyTorch 生态的"粘合剂"和"管理者"。

没有它,你的模型就是一盘散沙 🏖️

有了它,你可以构建任何复杂网络 🏗️🚀

相关推荐
丁学文武2 小时前
FlashAttention(V2)深度解析:从原理到工程实现
人工智能·深度学习·大模型应用·flashattention
大千AI助手2 小时前
Dropout:深度学习中的随机丢弃正则化技术
人工智能·深度学习·神经网络·模型训练·dropout·正则化·过拟合
蚝油菜花2 小时前
万字深度解析Claude Code的hook系统:让AI编程更智能、更可控|上篇—详解篇
人工智能·ai编程·claude
AImatters3 小时前
2025 年PT展前瞻:人工智能+如何走进普通人的生活?
人工智能·ai·具身智能·智慧医疗·智慧出行·中国国际信息通信展览会·pt展
AI小书房3 小时前
【人工智能通识专栏】第十五讲:视频生成
人工智能
zzywxc7873 小时前
AI工具全景洞察:从智能编码到模型训练的全链路剖析
人工智能·spring·ios·prompt·ai编程
甄心爱学习3 小时前
DataSet-深度学习中的常见类
人工智能·深度学习
伟贤AI之路3 小时前
【分享】中小学教材课本 PDF 资源获取指南
人工智能·pdf
aneasystone本尊3 小时前
详解 Chat2Graph 的推理机实现
人工智能