说下register_buffer和Parameter的异同
相同点
| 方面 | 描述 |
| --- | --- |
| 追踪 | 都会被加入 state_dict(模型保存时会保存下来;buffer 默认 persistent=True,若设为 persistent=False 则不会保存)。 |
| 与 Module 的绑定 | 都会随着模型移动到 cuda / cpu / float() 等而自动迁移。 |
| 都是 nn.Module 的一部分 | 都可以通过模块属性访问,如 self.x。 |
不同点
| 方面 | torch.nn.Parameter | register_buffer |
| --- | --- | --- |
| 是否是可训练参数 | ✅ 是,会被视为模型需要优化的参数(model.parameters() 中包含) | ❌ 否,不会被优化器更新 |
| 梯度计算 | 默认 requires_grad=True,参与反向传播 | 默认 requires_grad=False,不参与反向传播 |
| 用途场景 | 模型的权重、偏置等需要学习的参数 | 均值、方差、mask、位置编码等常量或状态,如 BatchNorm 中的 running mean/var |
| 注册方式 | self.w = nn.Parameter(tensor) 或 self.register_parameter("w", nn.Parameter(...)) | self.register_buffer("buf", tensor) |
| 是否显示在 parameters() 中 | ✅ 会显示 | ❌ 不会显示 |
| 是否能直接赋值注册 | ✅ 可以直接赋值 | ❌ 必须通过 register_buffer() 注册,否则不会记录到 state_dict |
使用建议
| 情境 | 推荐使用 |
| --- | --- |
| 需要优化 | nn.Parameter |
| 只做记录或参与计算但不优化 | register_buffer |
| 实现自定义模块(如 BatchNorm)时的状态 | register_buffer |
| 使用位置编码、attention mask | register_buffer |
| 模型保存中需要但不训练 | register_buffer |
这里我自己写了一份测试代码,分别用 ToyModel、ToyModel2、ToyModel3 运行下面的保存和读取脚本,相信会对 nn.Parameter 和 register_buffer 的区别有很深刻的认识。
import torch
import torch.nn as nn
import torch.nn.functional as F
class ToyModel(nn.Module):
    """Baseline toy model: one linear layer plus two plain Python attributes.

    ``a1`` and ``a2`` are ordinary instance attributes (neither parameters
    nor buffers), so they are not stored in the checkpoint.
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()
        # Plain instance attributes; these are not saved in the checkpoint.
        self.a1, self.a2 = 1, 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()

    def init_weights(self):
        """Xavier-initialise the weight and zero the bias of each nn.Linear."""
        linear_layers = (
            sub for sub in self.modules() if isinstance(sub, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the output of the linear layer applied to ``x``."""
        return self.linear(x)
class ToyModel2(nn.Module):
    """Toy model that adds a learnable offset ``b1`` to the linear output."""

    def __init__(self, inChannels, outChannels):
        super().__init__()
        # Plain instance attributes; these are not saved in the checkpoint.
        self.a1, self.a2 = 1, 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()
        # Model parameter: requires_grad=True and saved in the checkpoint.
        offset = torch.randn(outChannels)
        self.b1 = nn.Parameter(offset)

    def init_weights(self):
        """Xavier-initialise the weight and zero the bias of each nn.Linear."""
        linear_layers = (
            sub for sub in self.modules() if isinstance(sub, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Apply the linear layer, then add ``b1`` in place."""
        return self.linear(x).add_(self.b1)
class ToyModel3(nn.Module):
    """Toy model with a learnable offset ``b1`` and a constant buffer ``c1``."""

    def __init__(self, inChannels, outChannels):
        super().__init__()
        # Plain instance attributes; these are not saved in the checkpoint.
        self.a1, self.a2 = 1, 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()
        offset = torch.randn(outChannels)
        self.b1 = nn.Parameter(offset)
        # Buffer: requires_grad=False and saved in the checkpoint. It holds a
        # constant used directly in the computation and is reachable as
        # self.c1.
        ones = torch.ones_like(self.b1)
        self.register_buffer("c1", ones, persistent=True)

    def init_weights(self):
        """Xavier-initialise the weight and zero the bias of each nn.Linear."""
        linear_layers = (
            sub for sub in self.modules() if isinstance(sub, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Apply the linear layer, then add ``b1`` and ``c1`` in place."""
        return self.linear(x).add_(self.b1).add_(self.c1)
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from pathlib import Path
from models import ToyModel2, ToyModel, ToyModel3
# Log at INFO level with timestamp, logger name, level, line number and message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(lineno)s - %(message)s',
    level=logging.INFO,
)

if __name__ == "__main__":
    logger = logging.getLogger(__name__)
    ckpt_file = Path("toymodel3.pth")

    # Random batch; the model maps its feature size to twice that size.
    batch = torch.randn(3, 5)
    feat_dim = batch.size(1)
    model = ToyModel3(feat_dim, feat_dim * 2)

    pred = model(batch)
    logger.info(f"{pred.size()=}")

    # Log every module, then each named parameter and each named buffer.
    for submodule in model.modules():
        logger.info(submodule)
    for name, param in model.named_parameters():
        logger.info(f"{name = }, {param.size() = }, {param.requires_grad=}")
    for name, buffer in model.named_buffers():
        logger.info(f"{name = }, {buffer.size() = }")

    # Save only the state_dict to the checkpoint file.
    torch.save(model.state_dict(), ckpt_file)
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from models import ToyModel, ToyModel2, ToyModel3
if __name__ == "__main__":
    ckpt_file = Path("toymodel3.pth")

    # Build a model with the same dimensions as the one that was saved.
    batch = torch.randn(3, 5)
    feat_dim = batch.size(1)
    model = ToyModel3(feat_dim, feat_dim * 2)

    # Load the saved state_dict onto the CPU and copy it into the new model.
    state = torch.load(ckpt_file, map_location="cpu", weights_only=True)
    model.load_state_dict(state)

    pred = model(batch)
    print(f"{pred.size()=}")

    # Print every module, then each named parameter and each named buffer.
    for submodule in model.modules():
        print(submodule)
    for name, param in model.named_parameters():
        print(f"{name = }, {param.size() = }, {param.requires_grad=}")
    for name, buffer in model.named_buffers():
        print(f"{name = }, {buffer.size() = }")