self.register_buffer方法使用解析(pytorch)

self.register_buffer就是pytorch框架用来保存不更新参数的方法。

列子如下:

c 复制代码
self.register_buffer("position_emb", torch.randn((5, 3)))

第一个参数position_emb传入一个字符串,表示这组参数的名字,第二个就是tensor形式的参数torch.randn((5, 3),并一次初始化后保存于模型,不会有梯度传播给它,能被模型的model.state_dict()记录下来,可以理解为模型的常数。当然,你想保留固定值,使用如下代码:

c 复制代码
self.register_buffer("position_emb", torch.tensorrt([[2,5],[8,9]]))

进一步探讨训练对该参数是否有影响,答案是:没影响。具体可看下面实现的列子代码:

c 复制代码
import torch
from torch.nn import Embedding

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.emb = Embedding(5, 3)
        self.register_buffer("position_emb", torch.randn((5, 3)))
    def forward(self,vec):
        input = torch.tensor([0, 1, 2, 3, 4])
        emb_vec1 = self.emb(input)
        emb_vec1=emb_vec1+self.position_emb
        output = torch.einsum('ik, kj -> ij', emb_vec1, vec)
        return output
def simple_train():
    model = Model()
    vec = torch.randn((3, 1))
    label = torch.Tensor(5, 1).fill_(3)
    loss_fun = torch.nn.MSELoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.015)
    print('初始化后position_emb参数:\n',model.position_emb)
    for iter_num in range(100):
        output = model(vec)
        loss = loss_fun(output, label)
        opt.zero_grad()
        loss.backward(retain_graph=True)
        opt.step()
    print('训练后position_emb参数:\n', model.position_emb)

if __name__ == '__main__':
   simple_train()  # 训练与保存权重

实现结果如下:

相关推荐
埃菲尔铁塔_CV算法25 分钟前
深度学习神经网络创新点方向
人工智能·深度学习·神经网络
艾思科蓝-何老师【H8053】43 分钟前
【ACM出版】第四届信号处理与通信技术国际学术会议(SPCT 2024)
人工智能·信号处理·论文发表·香港中文大学
秀儿还能再秀1 小时前
机器学习——简单线性回归、逻辑回归
笔记·python·学习·机器学习
weixin_452600691 小时前
《青牛科技 GC6125:驱动芯片中的璀璨之星,点亮 IPcamera 和云台控制(替代 BU24025/ROHM)》
人工智能·科技·单片机·嵌入式硬件·新能源充电桩·智能充电枪
学术搬运工1 小时前
【珠海科技学院主办,暨南大学协办 | IEEE出版 | EI检索稳定 】2024年健康大数据与智能医疗国际会议(ICHIH 2024)
大数据·图像处理·人工智能·科技·机器学习·自然语言处理
右恩1 小时前
AI大模型重塑软件开发:流程革新与未来展望
人工智能
图片转成excel表格2 小时前
WPS Office Excel 转 PDF 后图片丢失的解决方法
人工智能·科技·深度学习
阿_旭2 小时前
如何使用OpenCV和Python进行相机校准
python·opencv·相机校准·畸变校准
幸运的星竹2 小时前
使用pytest+openpyxl做接口自动化遇到的问题
python·自动化·pytest
ApiHug2 小时前
ApiSmart x Qwen2.5-Coder 开源旗舰编程模型媲美 GPT-4o, ApiSmart 实测!
人工智能·spring boot·spring·ai编程·apihug