梯度反向传播过程是如何处理repeat函数的

举个例子

python 复制代码
import torch

# 假设有一个简单的计算图
x = torch.tensor(2.0, requires_grad=True)
c=x.repeat(2,1)
y = c ** 2


print(x,"\n",c,"\n",y)
torch.sum(y)

z = torch.sum(y)

# 计算梯度
grads = torch.autograd.grad(z, x)
# 打印梯度
print("Gradient of z w.r.t. x:", grads)

输出

python 复制代码
tensor(2., requires_grad=True) 
 tensor([[2.],
        [2.]], grad_fn=<RepeatBackward0>) 
 tensor([[4.],
        [4.]], grad_fn=<PowBackward0>)
Gradient of z w.r.t. x: (tensor(8.),)

在你的代码中,首先创建了一个张量 x,然后使用 repeat 函数将其在第0维(行)上重复2次,形成一个形状为 (2, 1) 的张量 c。然后,计算了 c 的平方,得到张量 y。最后,对 y 求和,得到标量张量 z。

因为 z 是一个标量,所以可以对它对 x 求梯度。现在让我们来分析一下为什么计算出的梯度是8:

首先,我们有 y = c ** 2,其中 c 是重复了两次的 x。所以 y 的值为 [4.0, 4.0]。然后,我们对 y 求和,得到 z = 8.0。

接下来,我们要计算 z 对 x 的梯度。由于 z 是标量,所以 torch.autograd.grad 函数的返回值是一个包含一个元素的元组,即 (grad_x,)。因此,grads 的值是一个包含一个张量的元组。在这种情况下,梯度的计算是通过链式法则完成的,即 dz/dx = dz/dy * dy/dc * dc/dx。在这里,dz/dy 是1,因为 z 是 y 的总和,dy/dc 是2,因为 y 中每个元素对 c 的导数都是2,dc/dx 是2,因为 c 是通过将 x 重复两次得到的,所以 x 对 c 的导数是2。因此, dz/dx = 1 * 2 * 2 = 4 * 2 = 8。

因此,计算出的梯度是8。

太神奇了,dc/dx的结果就是重复的次数!,那其实repeat函数的效果相对于放大了repeat对象的学习率,放大倍数就是repeat的次数,所以慎用repeat呀!

相关推荐
星期天要睡觉13 分钟前
机器学习深度学习 所需数据的清洗实战案例 (结构清晰、万字解析、完整代码)包括机器学习方法预测缺失值的实践
人工智能·深度学习·机器学习·数据挖掘
让心淡泊14423 分钟前
DAY 50 预训练模型+CBAM模块
python
renhongxia134 分钟前
大模型微调RAG、LORA、强化学习
人工智能·深度学习·算法·语言模型
dundunmm1 小时前
【论文阅读】SIMBA: single-cell embedding along with features(1)
论文阅读·深度学习·神经网络·embedding·生物信息·单细胞·多组学
BYSJMG1 小时前
计算机大数据毕业设计推荐:基于Spark的气候疾病传播可视化分析系统【Hadoop、python、spark】
大数据·hadoop·python·信息可视化·spark·django·课程设计
抠头专注python环境配置2 小时前
OCR库pytesseract安装保姆级教程
python·ocr·conda
山烛2 小时前
矿物分类系统开发笔记(二):模型训练[删除空缺行]
人工智能·笔记·python·机器学习·分类·数据挖掘
大得3693 小时前
django生成迁移文件,执行生成到数据库
后端·python·django
大志说编程3 小时前
LangChain框架入门17: 手把手教你创建LLM工具
python·langchain·ai编程
R-G-B3 小时前
【P38 6】OpenCV Python——图片的运算(算术运算、逻辑运算)加法add、subtract减法、乘法multiply、除法divide
人工智能·python·opencv·图片的运算·图片加法add·图片subtract减法·图片乘法multiply