pytorch梯度更新方法

一、方法1 autograd

python 复制代码
import torch
from torch import autograd

'''
demo1
'''
x = torch.tensor(1.)
a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)
c = torch.tensor(3., requires_grad=True)

y = a ** 2 * x + b * x + c

print('before:', a.grad, b.grad, c.grad)
grads = autograd.grad(y, [a, b, c])
print('after :', grads[0], grads[1], grads[2])

二、方法2 backward

python 复制代码
import torch
A = torch.tensor(2., requires_grad=True)
B = torch.tensor(.5, requires_grad=True)
E = torch.tensor(1., requires_grad=True)
C = A * B
D = C.exp()
F = D + E
print(F)        # tensor(3.7183, grad_fn=<AddBackward0>) 打印计算结果,可以看到F的grad_fn指向AddBackward,即产生F的运算
print([x.is_leaf for x in [A, B, C, D, E, F]])  # [True, True, False, False, True, False] 打印是否为叶节点,由用户创建,且requires_grad设为True的节点为叶节点
print([x.grad_fn for x in [F, D, C, A]])    # [<AddBackward0 object at 0x7f972de8c7b8>, <ExpBackward object at 0x7f972de8c278>, <MulBackward0 object at 0x7f972de8c2b0>, None]  每个变量的grad_fn指向产生其算子的backward function,叶节点的grad_fn为空
print(F.grad_fn.next_functions) # ((<ExpBackward object at 0x7f972de8c390>, 0), (<AccumulateGrad object at 0x7f972de8c5f8>, 0)) 由于F = D + E, 因此F.grad_fn.next_functions也存在两项,分别对应于D, E两个变量,每个元组中的第一项对应于相应变量的grad_fn,第二项指示相应变量是产生其op的第几个输出。E作为叶节点,其上没有grad_fn,但有梯度累积函数,即AccumulateGrad(由于反传时多出可能产生梯度,需要进行累加)
F.backward(retain_graph=True)   # 进行梯度反传
print(A.grad, B.grad, E.grad)   # tensor(1.3591) tensor(5.4366) tensor(1.) 算得每个变量梯度,与求导得到的相符
print(C.grad, D.grad)   # None None 为节约空间,梯度反传完成后,中间节点的梯度并不会保留

三、总结

深度学习的核心是梯度下降算法,即:

y = wx + b
w(i) = w(i-1) - lr
▽w

b(i) = b(i-1) - lr*▽b

相关推荐
聚客AI1 天前
🙈AI Agent的未来:工具调用将如何重塑智能应用?
人工智能·agent·mcp
悟能不能悟1 天前
if __name__=‘__main__‘的用处
python
Source.Liu1 天前
【Python基础】 15 Rust 与 Python 基本类型对比笔记
笔记·python·rust
前端世界1 天前
Python 正则表达式实战:用 Match 对象轻松解析拼接数据流
python·正则表达式·php
幂简集成1 天前
通义灵码 AI 程序员低代码 API 课程实战教程
android·人工智能·深度学习·神经网络·低代码·rxjava
Tadas-Gao1 天前
阿里云通义MoE全局均衡技术:突破专家负载失衡的革新之道
人工智能·架构·大模型·llm·云计算
xiaozhazha_1 天前
快鹭云业财一体化系统技术解析:低代码+AI如何破解数据孤岛难题
人工智能·低代码
DreamNotOver1 天前
基于Scikit-learn集成学习模型的情感分析研究与实现
python·scikit-learn·集成学习
pan0c231 天前
集成学习(随机森林算法、Adaboost算法)
人工智能·机器学习·集成学习
pan0c231 天前
集成学习 —— 梯度提升树GBDT、XGBoost
人工智能·机器学习·集成学习