这篇文章讲解很清晰,以下内容仅做补充,探讨哪些对象需要手动注册,哪些会自动注册。
在 PyTorch 中,哪些对象会自动注册为模型的一部分取决于它们的类型以及你如何定义它们。下面列出不需要手动注册、会自动注册的几种情况:
1. nn.Parameter
-
自动注册 :任何你在
nn.Module
中定义为nn.Parameter
的张量都会自动注册为模型的参数。它们会被视为模型的可训练参数 ,并且会被包含在模型的state_dict()
中,也会参与反向传播和优化。 -
如何使用:
pythonclass MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() # 自动注册为模型参数 self.weight = nn.Parameter(torch.randn(5, 5)) def forward(self, x): return x * self.weight
-
特点:
- 被自动注册为模型的参数,参与梯度计算和更新。
- 会在
model.parameters()
中找到,并且会在模型保存、加载以及转移设备时自动管理。
2. nn.Module
子类(如 nn.Conv2d
, nn.Linear
等)
-
自动注册 :所有继承自
nn.Module
的子类(如nn.Conv2d
,nn.Linear
,nn.ReLU
等)会自动注册为模型的一部分。它们包含的权重和偏置会自动注册为模型参数,并参与训练和保存。 -
如何使用:
pythonclass MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() # 自动注册为模型的一部分 self.conv = nn.Conv2d(1, 1, 3, 1, 1) def forward(self, x): return self.conv(x)
-
特点:
- 子模块中的所有
nn.Parameter
自动成为模型的一部分。 - 例如
nn.Conv2d
的权重和偏置会在state_dict()
中找到。 - 可以通过
model.parameters()
获取所有可训练参数。
- 子模块中的所有
3. 在 nn.Module
内的属性如果是其他 nn.Module
的实例
-
当你将另一个
nn.Module
对象作为属性放在自定义模型中时,这个子模块的参数会自动注册为主模块的一部分。 -
如何使用:
pythonclass MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() # 自动注册子模块 self.linear1 = nn.Linear(10, 5) self.linear2 = nn.Linear(5, 1) def forward(self, x): x = self.linear1(x) return self.linear2(x)
-
特点:
- 子模块会自动注册,子模块的参数也会作为主模块的一部分。
- 子模块也可以递归地包含其他模块,这些都会自动注册。
4. buffers
使用 register_buffer
显式注册
-
不会自动注册 :对于不需要训练的张量或常量(例如 BN 层中的均值、方差、位置编码等),需要使用
register_buffer
手动注册。这些张量不会参与梯度更新,但会随着模型保存、加载以及转移到设备。 -
如何使用:
pythonclass MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() # 手动注册 buffer self.register_buffer('my_buffer', torch.randn(5, 5)) def forward(self, x): return x + self.my_buffer
-
特点:
- 不参与梯度计算,但在
state_dict
中可见。 - 可以通过
.to(device)
自动转移到指定设备。
- 不参与梯度计算,但在
5. 不会自动注册的对象
-
普通的 Python 对象、
torch.Tensor
或者list/dict
类型不会自动注册为模型的一部分。你需要手动使用register_buffer
或者nn.Parameter
来使其成为模型的成员,否则这些对象不会在模型的state_dict()
中出现,也不会随着模型迁移到 GPU/CPU。 -
例子:
pythonclass MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() # 不会自动注册 self.tensor = torch.randn(5, 5) # 这个张量不会自动成为模型的部分 def forward(self, x): return x + self.tensor
-
解决方法 :如果你希望
tensor
也能成为模型的一部分,使用register_buffer
:pythonself.register_buffer('my_buffer', self.tensor)
总结:
-
自动注册:
nn.Parameter
: 任何nn.Parameter
类型的属性会自动成为模型的参数。nn.Module
子类 : 任何包含在nn.Module
中的子模块(如nn.Conv2d
)会自动注册为模型的一部分。
-
需要手动注册:
- 非可训练的常量(
buffers
) :需要使用register_buffer
来显式注册,它们不参与梯度计算,但会保存、加载以及转移到设备。
- 非可训练的常量(