【pytorch】register_buffer的使用

这篇文章讲解很清晰,以下内容仅做补充,探讨哪些对象需要手动注册,哪些会自动注册

在 PyTorch 中,哪些对象会自动注册为模型的一部分取决于它们的类型以及你如何定义它们。下面列出不需要手动注册、会自动注册的几种情况:

1. nn.Parameter

  • 自动注册 :任何你在 nn.Module 中定义为 nn.Parameter 的张量都会自动注册为模型的参数。它们会被视为模型的可训练参数 ,并且会被包含在模型的 state_dict() 中,也会参与反向传播和优化。

  • 如何使用

    python 复制代码
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 自动注册为模型参数
            self.weight = nn.Parameter(torch.randn(5, 5))
    
        def forward(self, x):
            return x * self.weight
  • 特点

    • 被自动注册为模型的参数,参与梯度计算和更新。
    • 会在 model.parameters() 中找到,并且会在模型保存、加载以及转移设备时自动管理。

2. nn.Module 子类(如 nn.Conv2d, nn.Linear 等)

  • 自动注册 :所有继承自 nn.Module 的子类(如 nn.Conv2d, nn.Linear, nn.ReLU 等)会自动注册为模型的一部分。它们包含的权重和偏置会自动注册为模型参数,并参与训练和保存。

  • 如何使用

    python 复制代码
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 自动注册为模型的一部分
            self.conv = nn.Conv2d(1, 1, 3, 1, 1)
    
        def forward(self, x):
            return self.conv(x)
  • 特点

    • 子模块中的所有 nn.Parameter 自动成为模型的一部分。
    • 例如 nn.Conv2d 的权重和偏置会在 state_dict() 中找到。
    • 可以通过 model.parameters() 获取所有可训练参数。

3. nn.Module 内的属性如果是其他 nn.Module 的实例

  • 当你将另一个 nn.Module 对象作为属性放在自定义模型中时,这个子模块的参数会自动注册为主模块的一部分。

  • 如何使用

    python 复制代码
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 自动注册子模块
            self.linear1 = nn.Linear(10, 5)
            self.linear2 = nn.Linear(5, 1)
    
        def forward(self, x):
            x = self.linear1(x)
            return self.linear2(x)
  • 特点

    • 子模块会自动注册,子模块的参数也会作为主模块的一部分。
    • 子模块也可以递归地包含其他模块,这些都会自动注册。

4. buffers 使用 register_buffer 显式注册

  • 不会自动注册 :对于不需要训练的张量或常量(例如 BN 层中的均值、方差、位置编码等),需要使用 register_buffer 手动注册。这些张量不会参与梯度更新,但会随着模型保存、加载以及转移到设备。

  • 如何使用

    python 复制代码
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 手动注册 buffer
            self.register_buffer('my_buffer', torch.randn(5, 5))
    
        def forward(self, x):
            return x + self.my_buffer
  • 特点

    • 不参与梯度计算,但在 state_dict 中可见。
    • 可以通过 .to(device) 自动转移到指定设备。

5. 不会自动注册的对象

  • 普通的 Python 对象、torch.Tensor 或者 list/dict 类型不会自动注册为模型的一部分。你需要手动使用 register_buffer 或者 nn.Parameter 来使其成为模型的成员,否则这些对象不会在模型的 state_dict() 中出现,也不会随着模型迁移到 GPU/CPU。

  • 例子

    python 复制代码
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            # 不会自动注册
            self.tensor = torch.randn(5, 5)  # 这个张量不会自动成为模型的部分
    
        def forward(self, x):
            return x + self.tensor
  • 解决方法 :如果你希望 tensor 也能成为模型的一部分,使用 register_buffer

    python 复制代码
    self.register_buffer('my_buffer', self.tensor)

总结:

  • 自动注册

    • nn.Parameter : 任何 nn.Parameter 类型的属性会自动成为模型的参数。
    • nn.Module 子类 : 任何包含在 nn.Module 中的子模块(如 nn.Conv2d)会自动注册为模型的一部分。
  • 需要手动注册

    • 非可训练的常量(buffers :需要使用 register_buffer 来显式注册,它们不参与梯度计算,但会保存、加载以及转移到设备。
相关推荐
说私域3 分钟前
基于开源 AI 智能名片、S2B2C 商城小程序的用户获取成本优化分析
人工智能·小程序
东胜物联23 分钟前
探寻5G工业网关市场,5G工业网关品牌解析
人工智能·嵌入式硬件·5g
南宫理的日知录32 分钟前
99、Python并发编程:多线程的问题、临界资源以及同步机制
开发语言·python·学习·编程学习
皓74133 分钟前
服饰电商行业知识管理的创新实践与知识中台的重要性
大数据·人工智能·科技·数据分析·零售
coberup41 分钟前
django Forbidden (403)错误解决方法
python·django·403错误
985小水博一枚呀1 小时前
【深度学习滑坡制图|论文解读3】基于融合CNN-Transformer网络和深度迁移学习的遥感影像滑坡制图方法
人工智能·深度学习·神经网络·cnn·transformer
龙哥说跨境1 小时前
如何利用指纹浏览器爬虫绕过Cloudflare的防护?
服务器·网络·python·网络爬虫
AltmanChan1 小时前
大语言模型安全威胁
人工智能·安全·语言模型
985小水博一枚呀1 小时前
【深度学习滑坡制图|论文解读2】基于融合CNN-Transformer网络和深度迁移学习的遥感影像滑坡制图方法
人工智能·深度学习·神经网络·cnn·transformer·迁移学习
数据与后端架构提升之路1 小时前
从神经元到神经网络:深度学习的进化之旅
人工智能·神经网络·学习