pytorch register_buffer介绍

在 PyTorch 中,register_buffernn.Module 类的一个方法,用于注册一个 buffer,即模型中需要持久保存但不参与梯度更新的张量。这些 buffer 常用于存储模型中的常数或其他固定值(如位置编码、均值、方差等),这些值在前向传播中会被用到但不会在训练中被优化更新。

register_buffer 的作用

  1. 保存和加载模型状态 :通过 register_buffer 注册的张量会被包含在模型的 state_dict 中,这样它们会在模型保存时一起存储,在加载时恢复,保持模型完整性。

  2. 设备迁移register_buffer 注册的张量会自动随模型一起移动到指定设备。例如,使用 model.to(device) 时,buffer 张量会被移动到 device,无需手动将它们转移到 CPU 或 GPU。

  3. 不参与反向传播和梯度更新 :buffer 并不是 nn.Parameter,因此它不会参与反向传播,也不会被优化器更新。这对于存储常量值尤其适用。

使用方法

register_buffer 的语法如下:

复制代码
register_buffer(name, tensor)
  • name:字符串,表示 buffer 的名称。该名称会在模型 state_dict 中作为键。
  • tensor:一个 torch.Tensor,表示要注册为 buffer 的张量。通常这个张量的 requires_grad 属性为 False

示例

例如,在实现位置编码时,我们可以将其注册为一个 buffer:

复制代码
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        # 初始化位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)

        # 将位置编码注册为 buffer
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

在这个例子中:

  • self.register_buffer('pe', pe)pe 注册为 buffer。这样,pe 在模型保存和加载时会自动包含在内。
  • pe 不会被优化器更新,不会参与反向传播,因此适合存储这种常量张量。
  • 使用 model.to(device) 时,pe 会自动迁移到正确设备。

使用 register_buffer 的场景

register_buffer 常用于以下场景:

  • 存储固定的模型参数:例如 BatchNorm 层的均值和方差。
  • 存储计算所需的固定值:如位置编码、固定掩码或固定的权重。
  • 用于设备无关性:在定义网络结构时,可以使用 buffer 来确保模型在 GPU 和 CPU 之间自由切换,不会遗漏关键的张量。

注意事项

  • 不要将 buffer 误用为训练参数 。如果某个张量需要被训练或优化,那么它应该是 nn.Parameter,而非 buffer。
  • 命名冲突:buffer 的名字不能和模型已有的属性或方法重名,否则会导致错误。

使用 register_buffer 可以使模型结构更清晰、更易于维护,同时减少手动迁移张量的工作量。

相关推荐
火山引擎开发者社区5 小时前
技术速递|使用 GitHub Copilot CLI 构建 Emoji 列表生成器
人工智能
weelinking5 小时前
【产品】12_接入数据库——让数据永久保存
jvm·数据库·python·react.js·数据挖掘·前端框架·产品经理
codefan※6 小时前
干掉“幻觉“实战:如何构建企业级知识图谱增强 RAG
人工智能·知识图谱
wukangjupingbb6 小时前
传统基于药物 SMILES 序列和蛋白质氨基酸序列的 DTI(Drug-Target Interaction)预测方法的缺陷
人工智能
沪漂阿龙6 小时前
Codex 额度重置周期变化:AI 编程免费试玩时代正在结束
人工智能
程序大视界6 小时前
【Python系列课程】Python正则表达式(下):环视、命名分组与日志实战
开发语言·python·正则表达式
TickDB6 小时前
美股行情 API 接入避坑:REST 快照、WebSocket 推送、盘前盘后数据的边界
人工智能·python·websocket·行情数据 api
装不满的克莱因瓶6 小时前
深入理解卷积神经网络(CNN)——从原理到代码实践
人工智能·神经网络·cnn
完成大叔6 小时前
模块二,Agent知识图谱的工具链思考
人工智能
lauo6 小时前
ibbot手机发布:搭载poplang技术 + token节点经济,革新AI手机体验
人工智能·智能手机