pytorch梯度更新方法

一、方法1 autograd

python 复制代码
import torch
from torch import autograd

'''
demo1
'''
x = torch.tensor(1.)
a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)
c = torch.tensor(3., requires_grad=True)

y = a ** 2 * x + b * x + c

print('before:', a.grad, b.grad, c.grad)
grads = autograd.grad(y, [a, b, c])
print('after :', grads[0], grads[1], grads[2])

二、方法2 backward

python 复制代码
import torch
A = torch.tensor(2., requires_grad=True)
B = torch.tensor(.5, requires_grad=True)
E = torch.tensor(1., requires_grad=True)
C = A * B
D = C.exp()
F = D + E
print(F)        # tensor(3.7183, grad_fn=<AddBackward0>) 打印计算结果,可以看到F的grad_fn指向AddBackward,即产生F的运算
print([x.is_leaf for x in [A, B, C, D, E, F]])  # [True, True, False, False, True, False] 打印是否为叶节点,由用户创建,且requires_grad设为True的节点为叶节点
print([x.grad_fn for x in [F, D, C, A]])    # [<AddBackward0 object at 0x7f972de8c7b8>, <ExpBackward object at 0x7f972de8c278>, <MulBackward0 object at 0x7f972de8c2b0>, None]  每个变量的grad_fn指向产生其算子的backward function,叶节点的grad_fn为空
print(F.grad_fn.next_functions) # ((<ExpBackward object at 0x7f972de8c390>, 0), (<AccumulateGrad object at 0x7f972de8c5f8>, 0)) 由于F = D + E, 因此F.grad_fn.next_functions也存在两项,分别对应于D, E两个变量,每个元组中的第一项对应于相应变量的grad_fn,第二项指示相应变量是产生其op的第几个输出。E作为叶节点,其上没有grad_fn,但有梯度累积函数,即AccumulateGrad(由于反传时多出可能产生梯度,需要进行累加)
F.backward(retain_graph=True)   # 进行梯度反传
print(A.grad, B.grad, E.grad)   # tensor(1.3591) tensor(5.4366) tensor(1.) 算得每个变量梯度,与求导得到的相符
print(C.grad, D.grad)   # None None 为节约空间,梯度反传完成后,中间节点的梯度并不会保留

三、总结

深度学习的核心是梯度下降算法,即:

y = wx + b
w(i) = w(i-1) - lr
▽w

b(i) = b(i-1) - lr*▽b

相关推荐
m***记2 分钟前
Python字符串操作:如何判断子串是否存在
linux·服务器·python
AI人工智能+17 分钟前
智能文本抽取:通过OCR、自然语言处理等多项技术,将非结构化文档转化为可读、可分析的数据资产
人工智能·nlp·ocr·文本抽取
这张生成的图像能检测吗17 分钟前
(论文速读)Anyattack: 面向视觉语言模型的大规模自监督对抗性攻击
人工智能·语言模型·clip·视觉语言模型·对抗攻击
gorgeous(๑>؂<๑)23 分钟前
【DeepSeek-OCR系列第一篇】Language Modelling with Pixels【ICLR23】
人工智能·语言模型·自然语言处理·ocr
开放知识图谱24 分钟前
论文浅尝 | LightPROF:一种轻量级推理框架,用于大型语言模型在知识图谱上的应用(AAAI2025)
人工智能·语言模型·自然语言处理·知识图谱
vlln27 分钟前
【论文速读】LLM+AL: 用符号逻辑校准语言模型的规划能力
人工智能·语言模型·自然语言处理
Antonio91540 分钟前
【图像处理】图像错切变换
图像处理·人工智能
小白银子43 分钟前
零基础从头教学Linux(Day 56)
linux·运维·python
文火冰糖的硅基工坊44 分钟前
[人工智能-大模型-85]:大模型应用层 - AI/AR眼镜:华为智能眼镜、苹果智能眼镜、Google Glass智能眼镜的软硬件技术架构
人工智能·华为·ar
B站计算机毕业设计之家1 小时前
计算机视觉:python手写数字识别系统 手写数字检测 CNN算法 卷积神经网络 OpenCV和Keras模型 大数据毕业设计(建议收藏)✅
python·神经网络·opencv·计算机视觉·cnn·手写数字·数字识别