【内涵】深度学习中的三种变量及pytorch中对应的三种tensor

程序是对现实世界/需求的映射,pytorch也不例外。在深度学习领域中,一般所需要的三种变量及pytorch中对应的三种tensor总结如下:

  1. 不需要反向传播来更新,也不需要保存在模型的文件参数中:这种对应的就是普通的tensor, 例如模型的图片输入tensor和模型的标签tensor。

  2. 不需要反向传播来更新,但是需要保存在模型的文件参数中用于推理的时候加载:这种就对应nn.Module.buffer,最经典的如batch noarm层中的mean和std,这在训练的时候不是反向传播更新而是计算出来的,模型训练完成后,会保存在参数文件中,模型被加载推理的时候,可以被取到。另外,就是如果之前做过目标检测任务的话,有一篇经典的文章gfl, 它也用到了这种变量:

    复制代码
         self.register_buffer('project',
                              torch.linspace(0, self.reg_max, self.reg_max + 1))

这样的话,就会有一个project变量被保存在参数文件中。

  1. 需要反向传播,也需要保存在模型参数文件中用于推理的时候加载:这种就对应nn.Parameter。当然所有层的weights, bias都是这样的变量。另外一个例子,比如说vit论文中的可学习的一个cls token也是这种变量

    self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))

这种也不用记忆,尤其是大模型时代。但是一般想好自己的需要(是否反向传播,是否保存至模型参数文件中),这种如何prompt,如何选择还是要知道,因此总结一下这个小点,作为自己的笔记。

相关推荐
2301_777599371 小时前
mysql如何进行数据库容量规划_评估磁盘空间增长趋势
jvm·数据库·python
aq55356001 小时前
PHP vs Python:30秒看懂核心区别
开发语言·python·php
xwz小王子1 小时前
多视角视频扩散策略:一种三维时空-觉察视频动作模型
人工智能·音视频
我是无敌小恐龙1 小时前
Java SE 零基础入门Day01 超详细笔记(开发前言+环境搭建+基础语法)
java·开发语言·人工智能·opencv·spring·机器学习
Ww.xh2 小时前
规避GCJ02偏移的坐标统一方案
人工智能
深圳市九鼎创展科技2 小时前
MT8883 vs RK3588 开发板全面对比:选型与场景落地指南
大数据·linux·人工智能·嵌入式硬件·ubuntu
CareyWYR2 小时前
AI Coding 订阅的集体退潮:从狂欢到收紧,中间只隔了一个季度
人工智能
NineData2 小时前
NineData 亮相香港国际创科展 InnoEX 2026,以 AI 加速布局全球市场
运维·数据库·人工智能·ninedata·新闻资讯·玖章算术