pytorch梯度更新方法

一、方法1 autograd

python 复制代码
import torch
from torch import autograd

'''
demo1
'''
x = torch.tensor(1.)
a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)
c = torch.tensor(3., requires_grad=True)

y = a ** 2 * x + b * x + c

print('before:', a.grad, b.grad, c.grad)
grads = autograd.grad(y, [a, b, c])
print('after :', grads[0], grads[1], grads[2])

二、方法2 backward

python 复制代码
import torch
A = torch.tensor(2., requires_grad=True)
B = torch.tensor(.5, requires_grad=True)
E = torch.tensor(1., requires_grad=True)
C = A * B
D = C.exp()
F = D + E
print(F)        # tensor(3.7183, grad_fn=<AddBackward0>) 打印计算结果,可以看到F的grad_fn指向AddBackward,即产生F的运算
print([x.is_leaf for x in [A, B, C, D, E, F]])  # [True, True, False, False, True, False] 打印是否为叶节点,由用户创建,且requires_grad设为True的节点为叶节点
print([x.grad_fn for x in [F, D, C, A]])    # [<AddBackward0 object at 0x7f972de8c7b8>, <ExpBackward object at 0x7f972de8c278>, <MulBackward0 object at 0x7f972de8c2b0>, None]  每个变量的grad_fn指向产生其算子的backward function,叶节点的grad_fn为空
print(F.grad_fn.next_functions) # ((<ExpBackward object at 0x7f972de8c390>, 0), (<AccumulateGrad object at 0x7f972de8c5f8>, 0)) 由于F = D + E, 因此F.grad_fn.next_functions也存在两项,分别对应于D, E两个变量,每个元组中的第一项对应于相应变量的grad_fn,第二项指示相应变量是产生其op的第几个输出。E作为叶节点,其上没有grad_fn,但有梯度累积函数,即AccumulateGrad(由于反传时多出可能产生梯度,需要进行累加)
F.backward(retain_graph=True)   # 进行梯度反传
print(A.grad, B.grad, E.grad)   # tensor(1.3591) tensor(5.4366) tensor(1.) 算得每个变量梯度,与求导得到的相符
print(C.grad, D.grad)   # None None 为节约空间,梯度反传完成后,中间节点的梯度并不会保留

三、总结

深度学习的核心是梯度下降算法,即:

y = wx + b
w(i) = w(i-1) - lr
▽w

b(i) = b(i-1) - lr*▽b

相关推荐
沪漂阿龙13 小时前
OpenAI Agents SDK 深度解析(三):执行层——Agent 的“幕后指挥部”
人工智能·深度学习
还是奇怪13 小时前
AI 提示词工程入门:用好的语言与模型高效对话
大数据·人工智能·语言模型·自然语言处理·transformer
健忘的萝卜13 小时前
Clawdbot 爆红硅谷,也把 AI Agent 和 Mac mini 推上风口
人工智能·macos·agent·数字员工·clawbot
迁旭13 小时前
claude code 提示词
人工智能·语言模型·gpt-3·知识图谱
不知名的老吴13 小时前
深度探索:直接预测多个token可行吗?
人工智能·回归
数智工坊13 小时前
【SAM-DETR论文阅读】:基于语义对齐匹配的DETR极速收敛检测框架
网络·论文阅读·人工智能·深度学习·transformer
小康小小涵13 小时前
基于ESP32S3实现无人机RID模块底层源码编译
linux·开发语言·python
风落无尘13 小时前
LangChain 完全入门指南:从基础到实战(附面试题)
人工智能·langchain
IT_陈寒13 小时前
Vue的这个响应式陷阱,我debug了一整天才爬出来
前端·人工智能·后端
zz_lzh13 小时前
arm版AI牛马:armbian(rk3588)设备部署openclaw
arm开发·人工智能·arm