pytorch梯度更新方法

一、方法1 autograd

python 复制代码
import torch
from torch import autograd

'''
demo1
'''
x = torch.tensor(1.)
a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)
c = torch.tensor(3., requires_grad=True)

y = a ** 2 * x + b * x + c

print('before:', a.grad, b.grad, c.grad)
grads = autograd.grad(y, [a, b, c])
print('after :', grads[0], grads[1], grads[2])

二、方法2 backward

python 复制代码
import torch
A = torch.tensor(2., requires_grad=True)
B = torch.tensor(.5, requires_grad=True)
E = torch.tensor(1., requires_grad=True)
C = A * B
D = C.exp()
F = D + E
print(F)        # tensor(3.7183, grad_fn=<AddBackward0>) 打印计算结果,可以看到F的grad_fn指向AddBackward,即产生F的运算
print([x.is_leaf for x in [A, B, C, D, E, F]])  # [True, True, False, False, True, False] 打印是否为叶节点,由用户创建,且requires_grad设为True的节点为叶节点
print([x.grad_fn for x in [F, D, C, A]])    # [<AddBackward0 object at 0x7f972de8c7b8>, <ExpBackward object at 0x7f972de8c278>, <MulBackward0 object at 0x7f972de8c2b0>, None]  每个变量的grad_fn指向产生其算子的backward function,叶节点的grad_fn为空
print(F.grad_fn.next_functions) # ((<ExpBackward object at 0x7f972de8c390>, 0), (<AccumulateGrad object at 0x7f972de8c5f8>, 0)) 由于F = D + E, 因此F.grad_fn.next_functions也存在两项,分别对应于D, E两个变量,每个元组中的第一项对应于相应变量的grad_fn,第二项指示相应变量是产生其op的第几个输出。E作为叶节点,其上没有grad_fn,但有梯度累积函数,即AccumulateGrad(由于反传时多出可能产生梯度,需要进行累加)
F.backward(retain_graph=True)   # 进行梯度反传
print(A.grad, B.grad, E.grad)   # tensor(1.3591) tensor(5.4366) tensor(1.) 算得每个变量梯度,与求导得到的相符
print(C.grad, D.grad)   # None None 为节约空间,梯度反传完成后,中间节点的梯度并不会保留

三、总结

深度学习的核心是梯度下降算法,即:

y = wx + b
w(i) = w(i-1) - lr
▽w

b(i) = b(i-1) - lr*▽b

相关推荐
Moshow郑锴1 小时前
人工智能中的(特征选择)数据过滤方法和包裹方法
人工智能
TY-20252 小时前
【CV 目标检测】Fast RCNN模型①——与R-CNN区别
人工智能·目标检测·目标跟踪·cnn
CareyWYR2 小时前
苹果芯片Mac使用Docker部署MinerU api服务
人工智能
失散133 小时前
自然语言处理——02 文本预处理(下)
人工智能·自然语言处理
wyiyiyi3 小时前
【Web后端】Django、flask及其场景——以构建系统原型为例
前端·数据库·后端·python·django·flask
mit6.8243 小时前
[1Prompt1Story] 滑动窗口机制 | 图像生成管线 | VAE变分自编码器 | UNet去噪神经网络
人工智能·python
sinat_286945193 小时前
AI应用安全 - Prompt注入攻击
人工智能·安全·prompt
没有bug.的程序员3 小时前
JVM 总览与运行原理:深入Java虚拟机的核心引擎
java·jvm·python·虚拟机
甄超锋4 小时前
Java ArrayList的介绍及用法
java·windows·spring boot·python·spring·spring cloud·tomcat
迈火4 小时前
ComfyUI-3D-Pack:3D创作的AI神器
人工智能·gpt·3d·ai·stable diffusion·aigc·midjourney