# `nn.ModuleList` in PyTorch: A Detailed Guide

## 1. Basic Concepts and Core Features

nn.ModuleList is an important container class in PyTorch for storing and managing multiple nn.Module objects. Unlike a plain Python list, nn.ModuleList is designed specifically for neural-network modules and provides key capabilities that a plain list lacks, summarized below.

### 1.1 Core Features

| Feature | Description | Why it matters |
| --- | --- | --- |
| Automatic parameter registration | Parameters of every module in the list are registered with the parent module | Key feature: guarantees the parameters are trainable |
| Modular organization | Makes it easy to manage many related submodules | Improves code readability and maintainability |
| PyTorch integration | Fully compatible with the PyTorch ecosystem | Supports serialization, device moves, and similar operations |
| Dynamic construction | Modules can be added or removed at runtime (see the sketch below) | Enables flexible construction of complex networks |
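
As a quick illustration of the list-like interface (a minimal sketch; the layer sizes are arbitrary), nn.ModuleList supports indexing, len(), iteration, and in-place methods such as append, extend, and insert:

```python
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8)])
layers.append(nn.Linear(8, 8))                # add a single module
layers.extend([nn.ReLU(), nn.Linear(8, 4)])   # add several modules at once
layers.insert(1, nn.LayerNorm(8))             # insert at a given position

print(len(layers))       # 5
print(layers[0])         # Linear(in_features=8, out_features=8, bias=True)
for layer in layers:     # iteration works like a plain list
    print(type(layer).__name__)
```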

### 1.2 Basic Usage Example

```python
import torch.nn as nn

class SimpleNetwork(nn.Module):
    def __init__(self, num_layers=3, hidden_dim=128):
        super().__init__()
        
        # Store several linear layers in an nn.ModuleList
        self.layers = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) 
            for _ in range(num_layers)
        ])
        
        # Activation function
        self.activation = nn.ReLU()
    
    def forward(self, x):
        # Run the input through each layer in order
        for layer in self.layers:
            x = self.activation(layer(x))
        return x

# Create a network instance
model = SimpleNetwork(num_layers=5, hidden_dim=128)
print(f"Total number of parameters: {sum(p.numel() for p in model.parameters())}")
```

Output:

```text
Total number of parameters: 82560
```

## 2. Key Differences from a Plain Python List

### 2.1 The Parameter Registration Mechanism

This is the most fundamental difference between nn.ModuleList and a plain list:

```python
import torch.nn as nn

class CompareContainers(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Option 1: a plain Python list
        self.normal_list = [
            nn.Linear(10, 20),
            nn.Linear(20, 30),
            nn.Linear(30, 40)
        ]
        
        # Option 2: an nn.ModuleList
        self.module_list = nn.ModuleList([
            nn.Linear(10, 20),
            nn.Linear(20, 30), 
            nn.Linear(30, 40)
        ])
    
    def analyze_parameters(self):
        """Compare the parameters exposed by the two containers."""
        print("=== Parameter analysis ===")
        
        # All parameter tensors registered with the model
        all_params = list(self.parameters())
        print(f"Parameter tensors registered with the model: {len(all_params)}")
        
        # Collect the parameters held in the plain list by hand
        normal_list_params = []
        for module in self.normal_list:
            normal_list_params.extend(list(module.parameters()))
        
        print(f"Parameter tensors inside the plain list: {len(normal_list_params)}")
        print(f"Parameter tensors in the ModuleList: {len(list(self.module_list.parameters()))}")
        
        # Key point: the plain list's parameters never appear in self.parameters(),
        # so they are not seen by the optimizer.
        return len(all_params) == len(list(self.module_list.parameters()))

# Test
model = CompareContainers()
result = model.analyze_parameters()
print(f"Plain-list parameters registered: {not result}")
```

Output:

```text
=== Parameter analysis ===
Parameter tensors registered with the model: 6
Parameter tensors inside the plain list: 6
Parameter tensors in the ModuleList: 6
Plain-list parameters registered: False
```

Key issue: the parameters of modules stored in a plain list are not visible to the optimizer, because PyTorch only discovers parameters through attributes of an nn.Module that are themselves registered modules, parameters, or module containers; a plain list is stored as an ordinary attribute and its contents are never registered.
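
To see this concretely, the sketch below (reusing the CompareContainers class defined above) inspects named_parameters() and an SGD optimizer; only entries prefixed with module_list appear:

```python
import torch

model = CompareContainers()  # the class defined above

# Only the ModuleList's layers show up among the registered parameter names
for name, _ in model.named_parameters():
    print(name)
# module_list.0.weight, module_list.0.bias, ... (nothing from normal_list)

# Consequently the optimizer only trains the ModuleList's parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
print(len(optimizer.param_groups[0]["params"]))  # 6 tensors, all from module_list
```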

### 2.2 Serialization and Device Moves

```python
import torch
import torch.nn as nn

def test_serialization():
    """Exercise serialization and device moves."""
    
    class TestNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            # Because the layers live in a ModuleList, they are part of state_dict()
            self.layers = nn.ModuleList([nn.Linear(10, 20), nn.Linear(20, 10)])
            
        def forward(self, x):
            for layer in self.layers:
                x = torch.relu(layer(x))
            return x
    
    # Create the model
    model = TestNetwork()
    
    # 1. Serialization
    torch.save(model.state_dict(), 'model_checkpoint.pth')
    
    # 2. Device move
    if torch.cuda.is_available():
        model = model.cuda()  # moves every registered parameter to the GPU
        print("Model moved to the GPU successfully")
    
    # 3. Load the checkpoint
    loaded_model = TestNetwork()
    loaded_model.load_state_dict(torch.load('model_checkpoint.pth'))
    print("Model restored from the checkpoint successfully")
    
    return model

test_serialization()
```

## 3. Advanced Usage of nn.ModuleList

### 3.1 Building Network Architectures Dynamically

```python
import torch.nn as nn

class DynamicNetwork(nn.Module):
    """Build the network's layers dynamically."""
    
    def __init__(self, layer_configs):
        """
        Args:
            layer_configs: list of layer configurations, e.g. [(10, 20), (20, 30), (30, 5)]
        """
        super().__init__()
        
        self.layers = nn.ModuleList()
        
        # Add layers dynamically
        for i, (in_dim, out_dim) in enumerate(layer_configs):
            self.layers.append(nn.Linear(in_dim, out_dim))
            
            # Add batch normalization (except after the output layer)
            if i < len(layer_configs) - 1:
                self.layers.append(nn.BatchNorm1d(out_dim))
        
        # Activation function and dropout
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
    
    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            
            # Apply activation and dropout after every hidden Linear layer;
            # the output layer at the end of the list is left untouched
            if isinstance(layer, nn.Linear) and i < len(self.layers) - 2:
                x = self.activation(x)
                x = self.dropout(x)
        
        return x

# Usage example
configs = [(784, 256), (256, 128), (128, 64), (64, 10)]
model = DynamicNetwork(configs)
print(f"Network structure: {model}")
```

Output:

```text
Network structure: DynamicNetwork(
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Linear(in_features=64, out_features=10, bias=True)
  )
  (activation): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
)
```
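
A minimal forward-pass sketch, reusing the model instance created above (the input dimension of 784 matches the first config; any batch size greater than 1 works with BatchNorm1d in training mode):

```python
import torch

x = torch.randn(32, 784)   # a batch of 32 flattened 28x28 inputs
logits = model(x)          # the DynamicNetwork built from configs above
print(logits.shape)        # torch.Size([32, 10])
```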

### 3.2 Conditional Network Construction

```python
import torch.nn as nn


class ConditionalNetwork(nn.Module):
    """Build the network conditionally, based on constructor flags."""

    def __init__(self, num_blocks, use_residual=True, use_attention=False):
        super().__init__()

        self.blocks = nn.ModuleList()
        self.use_residual = use_residual
        self.use_attention = use_attention

        # Build the residual blocks
        for i in range(num_blocks):
            block = self._build_block(hidden_dim=128, block_id=i)
            self.blocks.append(block)

        # Optional: add an attention layer
        # (batch_first=True so the input layout matches the unsqueeze in forward)
        if use_attention:
            self.attention = nn.MultiheadAttention(embed_dim=128, num_heads=4,
                                                   batch_first=True)

    def _build_block(self, hidden_dim, block_id):
        """Build a single block."""
        layers = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ])
        return layers

    def forward(self, x):
        residual = x

        for i, block in enumerate(self.blocks):
            # Run the input through every layer in the block
            for layer in block:
                x = layer(x)

            # Residual connection
            if self.use_residual and x.shape == residual.shape:
                x = x + residual

            residual = x  # update the residual for the next block

        # Attention mechanism
        if self.use_attention and hasattr(self, 'attention'):
            x = x.unsqueeze(1)  # add a sequence dimension: (batch, 1, 128)
            x, _ = self.attention(x, x, x)
            x = x.squeeze(1)

        return x
```
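
A quick instantiation sketch (the hidden width of 128 matches the blocks above; BatchNorm1d needs a batch size greater than 1 in training mode):

```python
import torch

model = ConditionalNetwork(num_blocks=3, use_residual=True, use_attention=True)
x = torch.randn(8, 128)   # (batch, hidden_dim)
out = model(x)
print(out.shape)          # torch.Size([8, 128])
```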

## 4. Comparison with nn.Sequential

### 4.1 Feature Comparison

| Feature | nn.ModuleList | nn.Sequential |
| --- | --- | --- |
| Parameter registration | ✅ Automatic | ✅ Automatic |
| Forward pass | ❌ Must be written by hand (see the sketch after this table) | ✅ Modules run in order automatically |
| Flexibility | ✅ High (arbitrary wiring) | ❌ Low (strictly sequential) |
| Conditional control flow | ✅ Branching is supported | ❌ Not supported |
| Skip connections | ✅ Easy to implement | ❌ Hard to implement |
| Parallel branches | ✅ Supported | ❌ Not supported |
| Looping over layers | ✅ Supported | ❌ Not supported |
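
The sketch below illustrates the forward-pass row of the table (a minimal example with arbitrary layer sizes): the Sequential runs its modules on its own, while the ModuleList only stores them and the containing module decides how to call them.

```python
import torch
import torch.nn as nn

class WithSequential(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

    def forward(self, x):
        return self.body(x)          # the whole chain runs in one call

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.ModuleList([nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8)])

    def forward(self, x):
        for layer in self.body:      # the order (and any branching) is up to you
            x = layer(x)
        return x

x = torch.randn(4, 16)
print(WithSequential()(x).shape, WithModuleList()(x).shape)  # both torch.Size([4, 8])
```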

### 4.2 Practical Examples

#### Multi-head attention in a Transformer

```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """Multi-head attention implemented with nn.ModuleList."""

    def __init__(self, d_model=512, num_heads=8, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Store each head's linear projections in ModuleLists
        self.q_proj_layers = nn.ModuleList([
            nn.Linear(d_model, self.d_k) for _ in range(num_heads)
        ])
        self.k_proj_layers = nn.ModuleList([
            nn.Linear(d_model, self.d_k) for _ in range(num_heads)
        ])
        self.v_proj_layers = nn.ModuleList([
            nn.Linear(d_model, self.d_k) for _ in range(num_heads)
        ])

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Collect the output of every head
        head_outputs = []

        # Loop over the heads stored in the ModuleLists
        for i in range(self.num_heads):
            # Linear projections
            q = self.q_proj_layers[i](query)
            k = self.k_proj_layers[i](key)
            v = self.v_proj_layers[i](value)

            # Scaled dot-product attention scores
            scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))

            if mask is not None:
                scores = scores.masked_fill(mask == 0, -1e9)

            attn_weights = torch.softmax(scores, dim=-1)
            attn_weights = self.dropout(attn_weights)

            # Apply the attention weights
            head_output = torch.matmul(attn_weights, v)
            head_outputs.append(head_output)

        # Concatenate the heads
        multi_head_output = torch.cat(head_outputs, dim=-1)

        # Output projection
        output = self.out_proj(multi_head_output)

        return output


# Usage example
if __name__ == "__main__":
    # Define the hyperparameters
    batch_size = 2
    seq_length = 10
    d_model = 512
    num_heads = 8

    # Create a multi-head attention instance
    multihead_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

    # Create random inputs
    query = torch.rand(batch_size, seq_length, d_model)  # (batch_size, seq_length, d_model)
    key = torch.rand(batch_size, seq_length, d_model)    # (batch_size, seq_length, d_model)
    value = torch.rand(batch_size, seq_length, d_model)  # (batch_size, seq_length, d_model)

    # Optional mask
    mask = torch.ones(batch_size, seq_length, seq_length)  # (batch_size, seq_length, seq_length)

    # Compute the attention output
    output = multihead_attn(query, key, value, mask)

    print("Attention Output Shape:", output.shape)  # expected: (batch_size, seq_length, d_model)
```

Output:

```text
Attention Output Shape: torch.Size([2, 10, 512])
```
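
One consequence of the ModuleList design is that each head's projection stays an ordinary nn.Linear that can be inspected on its own (a small sketch, using the multihead_attn instance from the example above). Most production implementations instead fuse the heads into a single d_model-to-d_model Linear and reshape, which is faster but hides the per-head structure.

```python
# Inspect the first head's query projection
w = multihead_attn.q_proj_layers[0].weight
print(w.shape)  # torch.Size([64, 512]) since d_k = d_model // num_heads = 64
```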

#### A residual (ResNet) block

```python
import torch
import torch.nn as nn


class ResNetBlock(nn.Module):
    """A residual block built with nn.ModuleList."""

    def __init__(self, in_channels, out_channels, num_layers=2, downsample=False):
        super().__init__()

        self.layers = nn.ModuleList()

        # Build several convolutional layers
        for i in range(num_layers):
            in_ch = in_channels if i == 0 else out_channels
            stride = 2 if (downsample and i == 0) else 1

            layer = nn.Sequential(
                nn.Conv2d(in_ch, out_channels, kernel_size=3,
                          stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
            self.layers.append(layer)

        # Projection for the shortcut path
        if downsample or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=2 if downsample else 1, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.downsample = None

    def forward(self, x):
        identity = x

        # Run the input through each layer in order
        for layer in self.layers:
            x = layer(x)

        # Residual connection
        if self.downsample is not None:
            identity = self.downsample(identity)

        # Not in-place: the preceding ReLU(inplace=True) saved x for its backward pass
        x = x + identity
        x = torch.relu(x)

        return x


# Usage example
if __name__ == "__main__":
    # Define the input dimensions
    batch_size = 2
    in_channels = 64
    out_channels = 128
    height, width = 32, 32  # height and width of the input image

    # Create a random input tensor
    input_tensor = torch.rand(batch_size, in_channels, height, width)  # (N, C, H, W)

    # Create a residual block
    resnet_block = ResNetBlock(in_channels=in_channels, out_channels=out_channels, num_layers=2, downsample=True)

    # Compute the output
    output_tensor = resnet_block(input_tensor)

    print("Output Shape:", output_tensor.shape)  # expected: (batch_size, out_channels, height/2, width/2)
```

Output:

```text
Output Shape: torch.Size([2, 128, 16, 16])
```
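
To close the loop with the theme of this article, several of these blocks can themselves be collected in an nn.ModuleList to form a stage (a minimal sketch reusing the ResNetBlock class above; the channel sizes are arbitrary):

```python
import torch
import torch.nn as nn

class ResNetStage(nn.Module):
    """A stage made of several ResNetBlocks, the first of which downsamples."""

    def __init__(self, in_channels, out_channels, num_blocks):
        super().__init__()
        self.blocks = nn.ModuleList([
            ResNetBlock(in_channels if i == 0 else out_channels,
                        out_channels,
                        downsample=(i == 0))
            for i in range(num_blocks)
        ])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

stage = ResNetStage(in_channels=64, out_channels=128, num_blocks=3)
out = stage(torch.rand(2, 64, 32, 32))
print(out.shape)  # torch.Size([2, 128, 16, 16])
```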