基于PyTorch的深度学习3——非标量反向传播

以下内容由大语言模型生成:

张量,边代表从输入到输出的操作。当我们对一个标量调用backward()函数时,PyTorch能够直接计算该标量关于参与运算的所有张量的梯度。这是因为导数的概念最开始是针对标量值定义的。

为什么需要传入gradient参数?

当你有一个非标量(即形状不为1的张量)并希望对其求导时,实际上你是在尝试计算这个张量关于某些变量的雅可比矩阵(Jacobian Matrix)。然而,PyTorch并不直接支持这种操作,因为它复杂化了自动微分的流程。为了简化这个问题,PyTorch要求在这种情况下提供一个额外的gradient参数,这个参数同样是一个张量,并且其形状必须与调用backward()的那个张量相同。这个gradient实际上扮演了一个"权重"的角色,用于将多维导数的问题转化为标量导数的问题。

如何理解这个过程?

考虑你的例子:假设你有一个损失向量loss=(y_1, y_2, ..., y_m),你想要根据它来更新一些模型参数。由于loss不是一个标量,你不能直接对它调用backward()。这时,你可以引入一个向量v=(v_1, v_2, ..., v_m),然后将lossv进行点乘得到一个新的标量loss*v^T(这里v^T表示v的转置,虽然在实际代码中我们不会这样写,这只是为了表达数学概念)。这个新生成的标量可以被用来调用backward()方法,从而触发梯度的计算。

具体来说,这样做实际上是计算了loss的雅可比矩阵与v的乘积。换句话说,原本你需要计算的是雅可比矩阵,但现在通过点乘转换后,你只需计算一个标量关于所需变量的梯度。这使得PyTorch的自动微分机制能够处理这种情况,而不需要直接支持张量对张量的求导。

1)定义叶子节点及计算节点。

python 复制代码
import torch

#定义叶节点张量x,形状为1x2
x=torch.tensor([[2.3]],dtype=torch.float,requires_grad=True)

#初始化Jacobian矩阵
J=torch.zeros(2,2)

#初始化目标张量,形状为1x2
y=torch.zeros(1,2)

#定义y与x之间的映射关系:
#y1=x1**2+3*x2,y2=x2**2+2*x1
y[0,0]=x[0,0]**2+3*x[0,1]
y[0,1]=x[0,1]**2+2*x[0,0]

2)手工计算y对x的梯度

y对x的梯度是一个雅可比矩阵,可以通过手动计算值

python 复制代码
#生成y1对x的梯度
y.backward(torch.Tensor([[1, 0]]),retain_graph=True)
##gradient的作用:传入的gradient张量扮演了一个权重的角色,它决定了每个元素在最终梯度计算中的重要
##本质上,这是将雅可比矩阵乘以这个gradient向量,从而将多维导数的问题简化为一维标量导数的问题。

J[0]=x.grad

#梯度是累加的,故需要对x的梯度清零
x.grad = torch.zeros_like(x.grad)

#生成y2对x的梯度
y.backward(torch.Tensor([[0, 1]]))

J[1]=x.grad
#显示jacobian矩阵的值
print(J)
相关推荐
王国强20094 分钟前
现代循环神经网络3-深度循环神经网络
深度学习
轻松Ai享生活6 分钟前
如何使用 Python 调用 DeepSeek-R1 API?
人工智能
ππ记录15 分钟前
Manus全球首个通用Agent,Manus AI:Agent应用的ChatGPT时刻
人工智能·chatgpt·manus详细介绍·manus介绍·manus详细应用·manus教程·manus详情介绍
Wis4e33 分钟前
基于PyTorch的深度学习5——神经网络工具箱
pytorch·深度学习·神经网络
Timbo2号34 分钟前
OpenAI API调用实战:跨境支付验证的技术解决方案
人工智能·github
WPG大大通35 分钟前
驱动 AI 边缘计算新时代!高性能 i.MX 95 应用平台引领未来
人工智能·ai·汽车·边缘计算·方案·工业·大大通
Json_36 分钟前
Cursor 编辑器使用文档,一款AI 代码编辑器, 人人都可以写代码!!!Tab,Tab,再来一次 Tab
人工智能·openai·cursor
闲人编程39 分钟前
调试与性能优化技巧
人工智能·pytorch·深度学习
凉拌三丝1 小时前
Llama Index案例实战(三)状态的设置与读取
人工智能·ai 编程
YoseZang1 小时前
【机器学习和深度学习】分类问题通用评价指标:精确率、召回率、准确率和混淆矩阵
深度学习·机器学习·分类算法