pytorch梯度更新方法

一、方法1 autograd

python 复制代码
import torch
from torch import autograd

'''
demo1
'''
x = torch.tensor(1.)
a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)
c = torch.tensor(3., requires_grad=True)

y = a ** 2 * x + b * x + c

print('before:', a.grad, b.grad, c.grad)
grads = autograd.grad(y, [a, b, c])
print('after :', grads[0], grads[1], grads[2])

二、方法2 backward

python 复制代码
import torch
A = torch.tensor(2., requires_grad=True)
B = torch.tensor(.5, requires_grad=True)
E = torch.tensor(1., requires_grad=True)
C = A * B
D = C.exp()
F = D + E
print(F)        # tensor(3.7183, grad_fn=<AddBackward0>) 打印计算结果,可以看到F的grad_fn指向AddBackward,即产生F的运算
print([x.is_leaf for x in [A, B, C, D, E, F]])  # [True, True, False, False, True, False] 打印是否为叶节点,由用户创建,且requires_grad设为True的节点为叶节点
print([x.grad_fn for x in [F, D, C, A]])    # [<AddBackward0 object at 0x7f972de8c7b8>, <ExpBackward object at 0x7f972de8c278>, <MulBackward0 object at 0x7f972de8c2b0>, None]  每个变量的grad_fn指向产生其算子的backward function,叶节点的grad_fn为空
print(F.grad_fn.next_functions) # ((<ExpBackward object at 0x7f972de8c390>, 0), (<AccumulateGrad object at 0x7f972de8c5f8>, 0)) 由于F = D + E, 因此F.grad_fn.next_functions也存在两项,分别对应于D, E两个变量,每个元组中的第一项对应于相应变量的grad_fn,第二项指示相应变量是产生其op的第几个输出。E作为叶节点,其上没有grad_fn,但有梯度累积函数,即AccumulateGrad(由于反传时多出可能产生梯度,需要进行累加)
F.backward(retain_graph=True)   # 进行梯度反传
print(A.grad, B.grad, E.grad)   # tensor(1.3591) tensor(5.4366) tensor(1.) 算得每个变量梯度,与求导得到的相符
print(C.grad, D.grad)   # None None 为节约空间,梯度反传完成后,中间节点的梯度并不会保留

三、总结

深度学习的核心是梯度下降算法,即:

y = wx + b
w(i) = w(i-1) - lr
▽w

b(i) = b(i-1) - lr*▽b

相关推荐
TGITCIC2 小时前
能源AI天团:多智能体如何破解行业复杂任务
人工智能·能源·新能源·ai agent·大模型ai·ai能源·能源大模型
我爱计算机视觉3 小时前
ICCV 2025 | VideoOrion: 将视频中的物体动态编码进大语言模型,理解视频涨点10%以上!
人工智能·语言模型·自然语言处理
Lululaurel3 小时前
深度模型瘦身术:从100MB到5MB的工业级压缩实战
pytorch·python·机器学习·模型压缩·模型优化·边缘部署
WWZZ20254 小时前
ORB_SLAM2原理及代码解析:Tracking::CreateInitialMapMonocular() 函数
人工智能·opencv·算法·计算机视觉·机器人·slam·感知
WWZZ20254 小时前
ORB_SLAM2原理及代码解析:Tracking::MonocularInitialization() 函数
人工智能·opencv·算法·计算机视觉·机器人·感知·单目相机
那雨倾城4 小时前
PiscCode:基于OpenCV的前景物体检测
图像处理·python·opencv·计算机视觉
eve杭5 小时前
解锁数据主权与极致性能:AI本地部署的全面指南
大数据·人工智能·5g·ai
数字时代全景窗5 小时前
商业航天与数字经济(一):从4G、5G得与失,看6G时代商业航天如何成为新经济引擎?
大数据·人工智能·5g
一粒马豆6 小时前
flask_socketio+pyautogui实现的具有加密传输功能的极简远程桌面
python·flask·pyautogui·远程桌面·flask_socketio
F_D_Z6 小时前
【一文理解】下采样与上采样区别
人工智能·深度学习·计算机视觉