PyTorch中nn.Module详解

直接print(dir(nn.Module)),得到如下内容:

一、模型结构与参数

  1. parameters()

    • 用途:返回模块的所有可训练参数(如权重、偏置)。

    • 示例

      python 复制代码
      for param in model.parameters():
          print(param.shape)
  2. named_parameters()

    • 用途:返回带名称的参数迭代器,便于调试和访问特定参数。

    • 示例

      python 复制代码
      for name, param in model.named_parameters():
          if 'weight' in name:
              print(name, param.shape)
  3. children()

    • 用途:返回直接子模块的迭代器。

    • 示例

      python 复制代码
      for child in model.children():
          print(type(child))
  4. modules()

    • 用途:递归返回所有子模块(包括自身)。

    • 示例

      python 复制代码
      for module in model.modules():
          if isinstance(module, nn.Conv2d):
              print(module.kernel_size)

二、模型状态与模式

  1. train()eval()

    • 用途:切换训练/推理模式(影响Dropout、BatchNorm等层)。

    • 示例

      python 复制代码
      model.train()  # 训练模式
      model.eval()   # 推理模式
  2. training

    • 用途 :布尔属性,指示当前模式(True 为训练,False 为推理)。

    • 示例

      python 复制代码
      print(model.training)  # 输出:True/False

三、模型保存与加载

  1. state_dict()

    • 用途 :返回包含模型所有参数的字典(OrderedDict)。

    • 示例

      python 复制代码
      torch.save(model.state_dict(), 'model.pth')
  2. load_state_dict()

    • 用途:从字典加载模型参数。

    • 示例

      python 复制代码
      model.load_state_dict(torch.load('model.pth'))

四、设备与数据类型

  1. to()

    • 用途:将模型移动到指定设备(如GPU)或转换数据类型。

    • 示例

      python 复制代码
      model.to('cuda')          # 移动到GPU
      model.to(torch.float16)   # 转换为半精度
  2. cpu()cuda()

    • 用途:快捷方法,分别将模型移动到CPU或GPU。

    • 示例

      python 复制代码
      model.cuda()  # 等价于 model.to('cuda')

五、前向传播与计算

  1. forward()

    • 用途:定义模型的前向传播逻辑(需在自定义模块中重写)。

    • 示例

      python 复制代码
      class MyModel(nn.Module):
          def forward(self, x):
              return self.layer(x)
  2. __call__()

    • 用途 :调用模型实例时触发(内部调用 forward(),支持钩子函数)。

    • 示例

      python 复制代码
      output = model(x)  # 等价于 output = model.forward(x)

六、参数初始化与优化

  1. zero_grad()

    • 用途:清空所有参数的梯度(通常在每个训练步骤前调用)。

    • 示例

      python 复制代码
      optimizer.zero_grad()  # 等价于 model.zero_grad()
  2. requires_grad_()

    • 用途:设置参数是否需要梯度(用于冻结部分模型)。

    • 示例

      python 复制代码
      for param in model.parameters():
          param.requires_grad = False  # 冻结所有参数

七、调试与信息

  1. extra_repr()

    • 用途:自定义模块打印信息(需在子类中重写)。

    • 示例

      python 复制代码
      class MyModel(nn.Module):
          def extra_repr(self):
              return f"hidden_size={self.hidden_size}"
  2. dump_patches()

    • 用途:打印模型的补丁信息(用于调试版本差异)。

八、其他实用方法

  1. apply()

    • 用途:递归应用函数到所有子模块(如初始化权重)。

    • 示例

      python 复制代码
      def init_weights(m):
          if isinstance(m, nn.Conv2d):
              nn.init.kaiming_normal_(m.weight)
      model.apply(init_weights)
  2. register_forward_hook()

    • 用途:注册前向传播钩子(用于捕获中间输出,调试或特征提取)。

总结

日常使用中,最频繁的方法包括:

  • 模型构建parameters(), children(), modules()
  • 训练与推理train(), eval(), zero_grad(), forward()
  • 保存与加载state_dict(), load_state_dict()
  • 设备管理to(), cuda(), cpu()

其他方法根据具体需求选择使用,例如钩子函数用于高级调试,apply() 用于统一初始化。

与nn.Sequential对比:

1. 继承关系与基础属性

  • nn.Module

    • 是所有神经网络模块的基类,提供最基础的功能(如参数管理、钩子机制)。
    • 包含核心属性:_parameters, _modules, _buffers 等。
  • nn.Sequential

    • nn.Module 的子类,继承了所有基础功能。
    • 额外添加了与顺序执行相关的属性(如 __getitem__append)。

2. 核心差异对比

功能类别 nn.Module nn.Sequential
模块构建 需要手动实现 forward 方法 自动按顺序执行子模块,无需定义 forward
子模块访问 通过属性名(如 self.conv1 通过索引或命名(如 model[0]
动态修改 需手动管理子模块 支持 appendextendinsert 等操作
适用场景 复杂网络结构(如ResNet、U-Net) 简单顺序结构(如LeNet卷积部分)

3. 具体方法对比

3.1 公共方法(两者都有)
python 复制代码
# 模型参数与结构
['parameters', 'named_parameters', 'children', 'modules', 'named_children', 'named_modules']

# 模型状态
['train', 'eval', 'training', 'zero_grad', 'requires_grad_']

# 设备与数据类型
['to', 'cpu', 'cuda', 'float', 'double', 'half', 'bfloat16']

# 保存与加载
['state_dict', 'load_state_dict']

# 钩子机制
['register_forward_hook', 'register_backward_hook']
3.2 nn.Sequential 特有的方法
python 复制代码
# 列表操作(动态修改模块顺序)
['__getitem__', '__setitem__', '__delitem__', '__len__', 'append', 'extend', 'insert', 'pop']

# 索引相关
['_get_item_by_idx']
3.3 nn.Module 特有的方法
python 复制代码
# 自定义实现
['forward', 'extra_repr']

# 高级管理
['add_module', 'register_module', 'register_parameter', 'register_buffer']

4. 示例对比

4.1 创建模型
python 复制代码
# nn.Module(需自定义 forward)
class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        return self.relu(self.conv(x))

# nn.Sequential(自动按顺序执行)
seq_model = nn.Sequential(
    nn.Conv2d(3, 64, 3),
    nn.ReLU()
)
4.2 访问子模块
python 复制代码
# nn.Module
custom_model.conv  # 通过属性名访问

# nn.Sequential
seq_model[0]       # 通过索引访问
seq_model.append(nn.MaxPool2d(2))  # 动态添加模块

5. 总结

特性 nn.Module nn.Sequential
灵活性 高(自定义任意逻辑) 低(仅支持顺序执行)
代码复杂度 较高(需手动实现 forward 低(自动处理前向传播)
动态修改 不支持直接操作(需手动管理) 支持 appendinsert 等操作
适用场景 复杂网络、分支结构、自定义操作 简单堆叠模块(如CNN的卷积部分)

建议:

  • 对于简单的顺序网络,优先使用 nn.Sequential 以减少代码量。
  • 对于包含复杂逻辑(如残差连接、多输入输出)的网络,使用 nn.Module 自定义实现。
相关推荐
gddkxc18 分钟前
AI CRM中的数据分析:悟空AI CRM如何帮助企业优化运营
人工智能·信息可视化·数据分析
我是李武涯20 分钟前
PyTorch Dataloader工作原理 之 default collate_fn操作
pytorch·python·深度学习
AI视觉网奇34 分钟前
Python 检测运动模糊 源代码
人工智能·opencv·计算机视觉
东隆科技35 分钟前
PRIMES推出SFM 2D全扫描场分析仪革新航空航天LPBF激光增材制造
人工智能·制造
无风听海43 分钟前
神经网络之计算图repeat节点
人工智能·深度学习·神经网络
刘晓倩1 小时前
在PyCharm中创建项目并练习
人工智能
Kratzdisteln1 小时前
【Python】绘制椭圆眼睛跟随鼠标交互算法配图详解
python·数学·numpy·pillow·matplotlib·仿射变换
Dev7z1 小时前
阿尔茨海默病早期症状影像分类数据集
人工智能·分类·数据挖掘
神码小Z1 小时前
DeepSeek再开源3B-MoE-OCR模型,视觉压缩高达20倍,支持复杂图表解析等多模态能力!
人工智能
maxruan1 小时前
PyTorch学习
人工智能·pytorch·python·学习