pytorch梯度更新方法

一、方法1 autograd

python 复制代码
import torch
from torch import autograd

'''
demo1
'''
x = torch.tensor(1.)
a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)
c = torch.tensor(3., requires_grad=True)

y = a ** 2 * x + b * x + c

print('before:', a.grad, b.grad, c.grad)
grads = autograd.grad(y, [a, b, c])
print('after :', grads[0], grads[1], grads[2])

二、方法2 backward

python 复制代码
import torch
A = torch.tensor(2., requires_grad=True)
B = torch.tensor(.5, requires_grad=True)
E = torch.tensor(1., requires_grad=True)
C = A * B
D = C.exp()
F = D + E
print(F)        # tensor(3.7183, grad_fn=<AddBackward0>) 打印计算结果,可以看到F的grad_fn指向AddBackward,即产生F的运算
print([x.is_leaf for x in [A, B, C, D, E, F]])  # [True, True, False, False, True, False] 打印是否为叶节点,由用户创建,且requires_grad设为True的节点为叶节点
print([x.grad_fn for x in [F, D, C, A]])    # [<AddBackward0 object at 0x7f972de8c7b8>, <ExpBackward object at 0x7f972de8c278>, <MulBackward0 object at 0x7f972de8c2b0>, None]  每个变量的grad_fn指向产生其算子的backward function,叶节点的grad_fn为空
print(F.grad_fn.next_functions) # ((<ExpBackward object at 0x7f972de8c390>, 0), (<AccumulateGrad object at 0x7f972de8c5f8>, 0)) 由于F = D + E, 因此F.grad_fn.next_functions也存在两项,分别对应于D, E两个变量,每个元组中的第一项对应于相应变量的grad_fn,第二项指示相应变量是产生其op的第几个输出。E作为叶节点,其上没有grad_fn,但有梯度累积函数,即AccumulateGrad(由于反传时多出可能产生梯度,需要进行累加)
F.backward(retain_graph=True)   # 进行梯度反传
print(A.grad, B.grad, E.grad)   # tensor(1.3591) tensor(5.4366) tensor(1.) 算得每个变量梯度,与求导得到的相符
print(C.grad, D.grad)   # None None 为节约空间,梯度反传完成后,中间节点的梯度并不会保留

三、总结

深度学习的核心是梯度下降算法,即:

y = wx + b
w(i) = w(i-1) - lr
▽w

b(i) = b(i-1) - lr*▽b

相关推荐
啊森要自信2 分钟前
【GUI自动化测试】控件、鼠标键盘操作与多场景自动化
c语言·开发语言·python·adb·ipython
YJlio3 分钟前
《Sysinternals实战指南》16.5 Ctrl2Cap 工具详解:把 Caps Lock 变成 Ctrl 的键盘改造与回退方法
linux·运维·服务器·网络·python·学习·计算机外设
某林2123 分钟前
从底层硬件死锁到 QoS 通信底层的全链路复盘
python·ros2·qos
Jutick4 分钟前
WebSocket 连接没断,行情却停了:如何给实时数据流加双层 watchdog?
python
石头城的小石头5 分钟前
【从0到1的鼠标位置显示记录器,基于python环境pycharm下编译实施,最终打包为exe,欢迎交流】
python·目标跟踪·pycharm·计算机外设·鼠标
用户8356290780519 分钟前
Python 操作 Word 修订跟踪(Track Changes)
后端·python
意图共鸣14 分钟前
意图共鸣科技《AI记忆链商业化白皮书3.0》假设场景解析:从母亲到消防员,专属AI如何重塑记忆与传承
人工智能·科技·架构
电商API_1800790524715 分钟前
Python 实现闲鱼商品列表批量采集,接口异常重试机制搭建
大数据·开发语言·数据库·爬虫·python
放下华子我只抽RuiKe523 分钟前
FastAPI 全栈后端(四):认证与授权
开发语言·前端·javascript·python·深度学习·react.js·fastapi
ai产品老杨23 分钟前
解耦安防碎片化:基于 Docker 与边缘计算的 AI 视频管理平台架构演进(附 GB28181/RTSP 统一接入与源码交付实践)
人工智能·docker·边缘计算