一文讲清 nn.Sequential 等容器类

我们用生活化比喻 + 简单代码 + 场景对比 ,向初学者彻底讲清楚 PyTorch 中 torch.nn 的这几个"容器类":

Sequential

ModuleList

ModuleDict

ParameterList

ParameterDict


🎯 一句话总览:

这些都是 PyTorch 提供的"收纳盒",帮你组织神经网络中的层或参数,让代码更整洁、更灵活、更容易训练!


一、基础概念回顾

在 PyTorch 中:

  • 所有网络层(如 Linear, Conv2d, ReLU)都是 nn.Module 的子类。
  • 所有可学习参数(如权重、偏置)都是 nn.Parameter(本质是带 requires_grad=True 的 Tensor)。
  • 如果你想把多个层或参数"打包"在一起,就要用这些"收纳盒"。

📦 1. nn.Sequential ------ "流水线式收纳盒"

🧩 比喻:

组装流水线:输入数据从第一个模块进去,依次经过每个模块,最后输出。顺序固定,不能跳过。

✅ 适用场景:

  • 简单前馈网络(如全连接、简单CNN)
  • 模块顺序执行,无需条件分支或跳转

🧑‍💻 使用方法:

python 复制代码
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.Softmax(dim=1)
)

x = torch.randn(1, 784)
output = model(x)  # 自动依次执行每个层
print(output.shape)  # torch.Size([1, 10])

⚠️ 注意:

  • 不能写条件判断、循环、跳转
  • 不能访问中间层输出(除非拆开写或用 hook)
  • 适合"直筒型"网络

📦 2. nn.ModuleList ------ "可编程的模块列表"

🧩 比喻:

工具箱里的扳手组:你可以按编号取出第几个扳手,也可以循环遍历、动态增减。

✅ 适用场景:

  • 需要 for 循环遍历层(如多层 LSTM、ResNet 的多个残差块)
  • 层数动态决定(比如根据配置文件创建 N 个层)
  • 需要按索引访问特定层

🧑‍💻 使用方法:

python 复制代码
class MyModel(nn.Module):
    def __init__(self, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(100, 100) for _ in range(num_layers)
        ])
        self.relu = nn.ReLU()

    def forward(self, x):
        for layer in self.layers:  # ✅ 可以用 for 循环!
            x = self.relu(layer(x))
        return x

model = MyModel(num_layers=5)
x = torch.randn(2, 100)
output = model(x)
print(output.shape)  # torch.Size([2, 100])

⚠️ 重要:

  • ❌ 不能直接调用 ModuleList(input) ------ 它不是函数!
  • ✅ 必须在 forward 里手动遍历或索引调用
  • ✅ 支持 .append(), .extend(), len(), 索引访问等

📦 3. nn.ModuleDict ------ "带名字的模块字典"

🧩 比喻:

工具墙上的标签挂钩:每个工具(模块)都有名字,你可以按名字取用。

✅ 适用场景:

  • 多个不同功能的模块,需要用名字区分(如不同任务的 head)
  • 动态选择模块(如根据输入选择不同分支)
  • 配置化模型结构(如 config['encoder'] = 'resnet')

🧑‍💻 使用方法:

python 复制代码
class MultiTaskModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoders = nn.ModuleDict({
            'image': nn.Conv2d(3, 16, 3),
            'text': nn.Linear(100, 50),
            'audio': nn.LSTM(80, 64, batch_first=True)
        })
        self.classifier = nn.Linear(50, 10)

    def forward(self, x, task_type):
        if task_type == 'image':
            x = self.encoders['image'](x).mean(dim=[2,3])
        elif task_type == 'text':
            x = self.encoders['text'](x)
        elif task_type == 'audio':
            x, _ = self.encoders['audio'](x)
            x = x.mean(dim=1)
        return self.classifier(x)

model = MultiTaskModel()
x_img = torch.randn(1, 3, 32, 32)
output = model(x_img, 'image')
print(output.shape)  # torch.Size([1, 10])

⚠️ 注意:

  • ✅ 支持字典操作:keys(), values(), items(), ['key'], .get()
  • ❌ 不能直接调用 ModuleDict(input)

📦 4. nn.ParameterList ------ "参数列表收纳盒"

🧩 比喻:

一排可调电阻:每个都是独立可学习参数,你可以编号访问、循环调整。

✅ 适用场景:

  • 一组自定义可学习参数,数量动态或较多
  • 不想写 self.param1, self.param2, ...
  • 需要注册到模型中,让 optimizer 能更新它们

🧑‍💻 使用方法:

python 复制代码
class LearnableBiases(nn.Module):
    def __init__(self, num_biases):
        super().__init__()
        # 创建多个可学习偏置参数
        self.biases = nn.ParameterList([
            nn.Parameter(torch.randn(1)) for _ in range(num_biases)
        ])

    def forward(self, x):
        for bias in self.biases:
            x = x + bias  # 每个样本加上所有偏置
        return x

model = LearnableBiases(3)
x = torch.tensor([1.0, 2.0, 3.0])
output = model(x)
print(output)        # tensor([4.xxx, 5.xxx, 6.xxx])
print(model.biases)  # 包含3个 Parameter

⚠️ 注意:

  • 里面的元素必须是 nn.Parameter
  • 会自动被 model.parameters() 收录 → 优化器能更新
  • 支持 .append(), 索引等

📦 5. nn.ParameterDict ------ "带名字的参数字典"

🧩 比喻:

控制面板上的旋钮组:每个旋钮都有名字(如"亮度"、"音量"),按名字调节。

✅ 适用场景:

  • 多个有语义名称的可学习参数
  • 根据配置动态创建参数
  • 需要按名字访问/更新参数

🧑‍💻 使用方法:

python 复制代码
class TaskSpecificParams(nn.Module):
    def __init__(self):
        super().__init__()
        self.task_params = nn.ParameterDict({
            'task_a_weight': nn.Parameter(torch.ones(10)),
            'task_b_bias': nn.Parameter(torch.zeros(1)),
            'global_scale': nn.Parameter(torch.tensor(1.0))
        })

    def forward(self, x, task_name):
        if task_name == 'task_a':
            return x * self.task_params['task_a_weight']
        elif task_name == 'task_b':
            return x + self.task_params['task_b_bias']
        else:
            return x * self.task_params['global_scale']

model = TaskSpecificParams()
x = torch.randn(5, 10)
output = model(x, 'task_a')
print(output.shape)  # torch.Size([5, 10])

⚠️ 注意:

  • 会自动注册参数 → model.parameters() 包含它们
  • 支持字典操作

🆚 对比总结表:

容器 存什么 是否可调用 是否支持索引/循环 是否支持命名访问 典型场景
Sequential Module ✅ 是 简单直筒网络
ModuleList Module ❌ 否 循环层、动态层数
ModuleDict Module ❌ 否 多分支、按名选模块
ParameterList Parameter ❌ 否 一组无名参数
ParameterDict Parameter ❌ 否 一组有名参数(如任务特定参数)

🧠 给初学者的黄金建议:

  1. 90% 情况用 Sequential 或普通 Module 属性就够了
  2. ✅ 当你需要 for 循环层 → 用 ModuleList
  3. ✅ 当你需要 按名字选模块 → 用 ModuleDict
  4. ✅ 自定义参数时,优先考虑是否能封装成 nn.Linear 等标准层
  5. ✅ 如果必须用自定义参数 → 用 ParameterList/Dict,别忘了它们是 nn.Parameter

🎁 小测验(巩固理解):

下面哪个写法是错误的?

python 复制代码
# A
model = nn.Sequential(nn.Linear(10,5), nn.ReLU())
out = model(x)

# B
layers = [nn.Linear(10,5), nn.ReLU()]  # ❌ 普通 list,不会被注册!
out = x
for layer in layers:
    out = layer(out)

# C
layers = nn.ModuleList([nn.Linear(10,5), nn.ReLU()])
out = x
for layer in layers:  # ✅ 正确!
    out = layer(out)

👉 答案:B 是错的!

因为普通 Python list 里的 Module 不会被 PyTorch 识别为子模块model.parameters() 找不到参数,优化器不会更新它们!

相关推荐
DisonTangor2 小时前
MiniMax 开源一个为极致编码与智能体工作流打造的迷你模型——MiniMax-M2
人工智能·语言模型·开源·aigc
Giser探索家4 小时前
无人机桥梁巡检:以“空天地”智慧之力守护交通生命线
大数据·人工智能·算法·安全·架构·无人机
不会学习的小白O^O4 小时前
双通道深度学习框架可实现从无人机激光雷达点云中提取橡胶树冠
人工智能·深度学习·无人机
恒点虚拟仿真4 小时前
虚拟仿真实训破局革新:打造无人机飞行专业实践教学新范式
人工智能·无人机·ai教学·虚拟仿真实训·无人机飞行·无人机专业虚拟仿真·无人机飞行虚拟仿真
鲜枣课堂5 小时前
华为最新光通信架构AI-OTN,如何应对AI浪潮?
人工智能·华为·架构
格林威5 小时前
AOI在新能源电池制造领域的应用
人工智能·数码相机·计算机视觉·视觉检测·制造·工业相机
dxnb225 小时前
Datawhale25年10月组队学习:math for AI+Task5解析几何
人工智能·学习
DooTask官方号5 小时前
DooTask 1.3.38 版本更新:MCP 服务器与 AI 工具深度融合,开启任务管理新体验
运维·服务器·人工智能·开源软件·dootask
Coovally AI模型快速验证7 小时前
OmniNWM:突破自动驾驶世界模型三大瓶颈,全景多模态仿真新标杆(附代码地址)
人工智能·深度学习·机器学习·计算机视觉·自动驾驶·transformer