【pytorch】register_buffer的使用

这篇文章讲解很清晰,以下内容仅做补充,探讨哪些对象需要手动注册,哪些会自动注册

在 PyTorch 中,哪些对象会自动注册为模型的一部分取决于它们的类型以及你如何定义它们。下面列出不需要手动注册、会自动注册的几种情况:

1. nn.Parameter

  • 自动注册 :任何你在 nn.Module 中定义为 nn.Parameter 的张量都会自动注册为模型的参数。它们会被视为模型的可训练参数 ,并且会被包含在模型的 state_dict() 中,也会参与反向传播和优化。

  • 如何使用

    python 复制代码
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 自动注册为模型参数
            self.weight = nn.Parameter(torch.randn(5, 5))
    
        def forward(self, x):
            return x * self.weight
  • 特点

    • 被自动注册为模型的参数,参与梯度计算和更新。
    • 会在 model.parameters() 中找到,并且会在模型保存、加载以及转移设备时自动管理。

2. nn.Module 子类(如 nn.Conv2d, nn.Linear 等)

  • 自动注册 :所有继承自 nn.Module 的子类(如 nn.Conv2d, nn.Linear, nn.ReLU 等)会自动注册为模型的一部分。它们包含的权重和偏置会自动注册为模型参数,并参与训练和保存。

  • 如何使用

    python 复制代码
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 自动注册为模型的一部分
            self.conv = nn.Conv2d(1, 1, 3, 1, 1)
    
        def forward(self, x):
            return self.conv(x)
  • 特点

    • 子模块中的所有 nn.Parameter 自动成为模型的一部分。
    • 例如 nn.Conv2d 的权重和偏置会在 state_dict() 中找到。
    • 可以通过 model.parameters() 获取所有可训练参数。

3. nn.Module 内的属性如果是其他 nn.Module 的实例

  • 当你将另一个 nn.Module 对象作为属性放在自定义模型中时,这个子模块的参数会自动注册为主模块的一部分。

  • 如何使用

    python 复制代码
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 自动注册子模块
            self.linear1 = nn.Linear(10, 5)
            self.linear2 = nn.Linear(5, 1)
    
        def forward(self, x):
            x = self.linear1(x)
            return self.linear2(x)
  • 特点

    • 子模块会自动注册,子模块的参数也会作为主模块的一部分。
    • 子模块也可以递归地包含其他模块,这些都会自动注册。

4. buffers 使用 register_buffer 显式注册

  • 不会自动注册 :对于不需要训练的张量或常量(例如 BN 层中的均值、方差、位置编码等),需要使用 register_buffer 手动注册。这些张量不会参与梯度更新,但会随着模型保存、加载以及转移到设备。

  • 如何使用

    python 复制代码
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 手动注册 buffer
            self.register_buffer('my_buffer', torch.randn(5, 5))
    
        def forward(self, x):
            return x + self.my_buffer
  • 特点

    • 不参与梯度计算,但在 state_dict 中可见。
    • 可以通过 .to(device) 自动转移到指定设备。

5. 不会自动注册的对象

  • 普通的 Python 对象、torch.Tensor 或者 list/dict 类型不会自动注册为模型的一部分。你需要手动使用 register_buffer 或者 nn.Parameter 来使其成为模型的成员,否则这些对象不会在模型的 state_dict() 中出现,也不会随着模型迁移到 GPU/CPU。

  • 例子

    python 复制代码
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 不会自动注册
            self.tensor = torch.randn(5, 5)  # 这个张量不会自动成为模型的部分
    
        def forward(self, x):
            return x + self.tensor
  • 解决方法 :如果你希望 tensor 也能成为模型的一部分,使用 register_buffer

    python 复制代码
    self.register_buffer('my_buffer', self.tensor)

总结:

  • 自动注册

    • nn.Parameter : 任何 nn.Parameter 类型的属性会自动成为模型的参数。
    • nn.Module 子类 : 任何包含在 nn.Module 中的子模块(如 nn.Conv2d)会自动注册为模型的一部分。
  • 需要手动注册

    • 非可训练的常量(buffers :需要使用 register_buffer 来显式注册,它们不参与梯度计算,但会保存、加载以及转移到设备。
相关推荐
z are6 分钟前
包含 Python 与 Jupyter的Anaconda的下载安装
开发语言·python·jupyter
aWty_9 分钟前
机器学习--神经网络
人工智能·神经网络·机器学习
傻啦嘿哟21 分钟前
在Excel中通过Python运行公式和函数实现数据计算
开发语言·python·excel
使者大牙21 分钟前
【PyTorch单点知识】深入了解 nn.ModuleList和 nn.ParameterList模块:灵活构建动态网络结构
人工智能·pytorch·python·深度学习
吃什么芹菜卷27 分钟前
机器学习:逻辑回归--过采样
人工智能·笔记·机器学习·逻辑回归
qh0526wy36 分钟前
量化交易的个人见解
python
wydxry42 分钟前
目标检测-小目标检测方法
人工智能·目标检测·计算机视觉
qq_426938481 小时前
git-fork操作指南
git·python
matrixlzp1 小时前
Python 解析 JSON 数据
python
小李很执着1 小时前
【深度智能】:迈向高级时代的人工智能全景指南
人工智能·python·深度学习·算法·机器学习·语言模型