Similarities and Differences Between register_buffer and torch.nn.Parameter in PyTorch

Explain the similarities and differences between register_buffer and nn.Parameter.

Similarities

| Aspect | Description |
| --- | --- |
| Tracked in state_dict | Both are added to the state_dict and saved with the model (buffers only when persistent=True, which is the default). |
| Bound to the Module | Both migrate automatically when the model is moved with .cuda(), .cpu(), .float(), etc. |
| Part of nn.Module | Both are accessible as module attributes, e.g. self.x. |
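
A minimal sketch of these three similarities (the Demo module below is my own example, not part of the test code later in this post):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(3))      # parameter
        self.register_buffer("b", torch.zeros(3))  # buffer

m = Demo()
print(list(m.state_dict().keys()))  # ['w', 'b'] -- both end up in state_dict
m = m.double()                      # both follow dtype/device conversions
print(m.w.dtype, m.b.dtype)         # torch.float64 torch.float64
print(m.b + 1)                      # both readable as plain attributes
```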

Differences

| Aspect | torch.nn.Parameter | register_buffer |
| --- | --- | --- |
| Trainable parameter | ✅ Yes, treated as a parameter to optimize (included in model.parameters()) | ❌ No, never updated by the optimizer |
| Gradient computation | requires_grad=True by default, participates in backpropagation | requires_grad=False by default, excluded from backpropagation |
| Typical use | Learnable weights, biases, and other trained parameters | Constants or state such as means, variances, masks, and positional encodings, e.g. the running mean/var in BatchNorm |
| Registration | self.w = nn.Parameter(tensor) or self.register_parameter("w", nn.Parameter(...)) | self.register_buffer("buf", tensor) |
| Listed by parameters() | ✅ Yes | ❌ No |
| Registered by plain assignment | ✅ Yes, direct attribute assignment is enough | ❌ No, it must go through register_buffer(); a plain tensor attribute never reaches the state_dict |
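
The same hypothetical Demo module can show the key differences: only the parameter appears in parameters(), only it receives gradients, and the buffer is never touched by training:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(3))
        self.register_buffer("b", torch.zeros(3))

    def forward(self, x):
        return (x * self.w + self.b).sum()

m = Demo()
print([n for n, _ in m.named_parameters()])  # ['w'] -- the buffer is absent
print([n for n, _ in m.named_buffers()])     # ['b']

opt = torch.optim.SGD(m.parameters(), lr=0.1)  # the optimizer only ever sees w
m(torch.randn(3)).backward()
opt.step()
print(m.w.grad is not None)  # True -- w received a gradient and was updated
print(m.b)                   # still zeros: never touched by training
```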

Usage recommendations

| Scenario | Recommended |
| --- | --- |
| Needs to be optimized | nn.Parameter |
| Recorded or used in computation, but not optimized | register_buffer |
| Internal state of a custom module (e.g. BatchNorm) | register_buffer |
| Positional encodings, attention masks | register_buffer |
| Needed in the saved model but never trained | register_buffer |
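
As a concrete case from the table, a causal attention mask is a typical buffer: it must follow the model across devices but is never trained, and it can even be kept out of the checkpoint with persistent=False since it is cheap to rebuild. A minimal sketch (the CausalMask module and its sizes are my own invention):

```python
import torch
import torch.nn as nn

class CausalMask(nn.Module):
    def __init__(self, max_len: int = 8):
        super().__init__()
        # Lower-triangular mask: a constant that moves with the module.
        mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        self.register_buffer("mask", mask, persistent=False)  # excluded from state_dict

    def forward(self, scores):
        L = scores.size(-1)
        return scores.masked_fill(~self.mask[:L, :L], float("-inf"))

m = CausalMask()
print(m.state_dict().keys())  # empty: persistent=False keeps the mask out of the checkpoint
print(m(torch.zeros(4, 4)))   # upper triangle filled with -inf
```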

Below is a small test I wrote. Run ToyModel, ToyModel2, and ToyModel3 in turn, save each one and load it back, and the difference between these two mechanisms should become very clear. The three classes are assumed to live in models.py, which the two scripts afterwards import from.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyModel(nn.Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()

        self.a1 = 1  # plain instance attributes: not part of state_dict, so not saved in the checkpoint
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        return out


class ToyModel2(nn.Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()

        self.a1 = 1  # plain instance attributes: not saved in the checkpoint
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)

        self.init_weights()

        self.b1 = nn.Parameter(torch.randn(outChannels))  # model parameter: requires_grad=True, saved in the checkpoint

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        out += self.b1
        return out


class ToyModel3(nn.Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()

        self.a1 = 1  # plain instance attributes: not saved in the checkpoint
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)

        self.init_weights()

        self.b1 = nn.Parameter(torch.randn(outChannels))
        # Registered buffer: requires_grad=False, saved in the checkpoint (persistent=True);
        # holds a constant used directly in computation, accessible as self.c1.
        self.register_buffer("c1", torch.ones_like(self.b1), persistent=True)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        out += self.b1
        out += self.c1
        return out
```

The save script: build ToyModel3, run a forward pass, list its parameters and buffers, and write the state_dict to disk:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from pathlib import Path

from models import ToyModel2, ToyModel, ToyModel3

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(lineno)s - %(message)s')


if __name__ == "__main__":
    savePath = Path("toymodel3.pth")
    logger = logging.getLogger(__name__)

    inp = torch.randn(3, 5)
    model = ToyModel3(inp.size(1), inp.size(1) * 2)
    pred = model(inp)

    logger.info(f"{pred.size()=}")

    for m in model.modules():
        logger.info(m)

    for name, param in model.named_parameters():
        logger.info(f"{name = }, {param.size() = }, {param.requires_grad=}")

    for name, buffer in model.named_buffers():
        logger.info(f"{name = }, {buffer.size() = }")

    torch.save(model.state_dict(), savePath)
```

The load script: rebuild the model, restore the checkpoint, and confirm that both the parameter b1 and the buffer c1 come back:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path

from models import ToyModel, ToyModel2, ToyModel3


if __name__ == "__main__":
    savePath = Path("toymodel3.pth")

    inp = torch.randn(3, 5)
    model = ToyModel3(inp.size(1), inp.size(1) * 2)
    ckpt = torch.load(savePath, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt)
    pred = model(inp)

    print(f"{pred.size()=}")

    for m in model.modules():
        print(m)

    for name, param in model.named_parameters():
        print(f"{name = }, {param.size() = }, {param.requires_grad=}")

    for name, buffer in model.named_buffers():
        print(f"{name = }, {buffer.size() = }")