torch中关于张量是否是叶子结点,张量梯度是否会被计算,张量梯度是否会被保存的感悟

先上结论:

1、叶子结点定义:

(1)不依赖其它任何结点的张量

(2)依赖其它张量,但其依赖的所有张量的require_grad=False

判断方法:查看is_leaf属性

2、张量梯度是否会被计算:

require_grad=True,且依赖其的张量不全为require_grad=False,该张量梯度会被计算

判断方法:backward之后查看张量的.grad属性(中间变量满足上述要求的梯度肯定也会被计算,只是backward之后会被释放掉,无法查看。中间变量其下面的叶子结点梯度被计算,根据链式法则,侧面也可证明中间变量的梯度肯定被计算了,因此本文只采用叶子结点说明该规律)

3、张量梯度是否会被保存(前提: 张量梯度可以被计算):

(1)是叶子结点

(2)是非叶子结点,但retain_grad=True

判断方法: backward之后查看张量的.grad属性

然后看例子:

两个例子先证明第一条:叶子结点定义

python 复制代码
import torch

# 两个例子先证明第一条:叶子结点定义
a = torch.tensor(1.,requires_grad=True)
b = torch.tensor(1.,requires_grad=False)
d = a+b
d.backward()

print(a.is_leaf) # True  (不依赖其它任何张量的结点)
print(b.is_leaf) # True  (不依赖其它任何张量的结点)
print(d.is_leaf) # False (依赖其它张量a和b,但张量a的require_grad=True,所以d不是叶子结点)


c = torch.tensor(1., requires_grad=False)
e = b+c

print(b.is_leaf)    # True  (不依赖其它任何张量的结点)
print(c.is_leaf)    # True  (不依赖其它任何张量的结点)
print(e.is_leaf)    # True (依赖其它张量b和c,但张量a和c的require_grad=False,所以e仍然是叶子结点)
# e.backward()会报错,因为所有结点的require_grad=False,因此不需要求梯度

两个例子再证明第二条:张量梯度是否会被计算

python 复制代码
import torch

# 两个例子再证明第二条:张量梯度是否会被计算
# require_grad=True的张量,且依赖其的张量不全为require_grad=False梯度会被计算
a = torch.tensor(1.,requires_grad=True)   # a require_grad=True,因此a的梯度会被计算
b = torch.tensor(1.,requires_grad=False)  # b require_grad=False,因此b的梯度不会被计算
d = a+b                                   # d require_grad=True,因此d的梯度会被计算
c = torch.tensor(1., requires_grad=True)  # c require_grad=True,因此c的梯度会被计算
e = a+c                                   # e require_grad=True,因此e的梯度会被计算
d = d.detach()
print(d.requires_grad)                    # d require_grad=False,因此d的梯度不会被计算

f = d+e 
f.backward()
print(a.grad)                             # a有梯度

# require_grad=True的张量,但依赖其的张量全为require_grad=False,梯度不会被计算
a = torch.tensor(1.,requires_grad=True)   # a require_grad=True, 因此a的梯度会被计算
b = torch.tensor(1.,requires_grad=False)  # b require_grad=False, 因此b的梯度不会被计算
d = a+b                                   # d require_grad=True, 因此d的梯度会被计算
c = torch.tensor(1., requires_grad=True)  # c require_grad=True,因此c的梯度会被计算
d = d.detach()                            # d require_grad=False, 因此d的梯度不会被计算
f = d+c                                   # f require_grad=True,因此f的梯度会被计算
f.backward()
print(a.grad)                             # a没有梯度

再证明第三条: 张量梯度是否会被保存,前提是张量的梯度能被计算(既满足第二条)

python 复制代码
import torch

# 再证明第三条: 张量梯度是否会被保存,前提是张量的梯度能被计算(既满足第二条)
# (1)叶子结点的梯度会被保存
a = torch.tensor(1.,requires_grad=True)     # a的require_grad= True,且a是一个叶子结点(a.is_leaf=True),所以backward之后a的梯度会被保存。(满足条件3)
b = torch.tensor(1.,requires_grad=False)    # b是一个叶子结点(b.is_leaf=True),但b的require_grad= False,所以backward之后b的梯度不会被保存.(不满足条件3)
d = a+b                                     # d的require_grad= True, 但d不是一个叶子结点,所以backward之后b的梯度不会被保存。(不满足条件3)
d.backward()
print(a.is_leaf)
print(a.grad)                               # a的梯度被保存


# 感悟: 神经网络各层里面的参数require_grad=True(属于Parameter类型,其初始化的时候默认require_grad=True),并且如果上层不会被断开(满足梯度可以被计算条件)。且神经网络里面各层的参数都是叶子结点(从计算图可以得知满足1里面第一条),因此满足梯度保存条件第一条。因此其梯度一定会被保存。满足了以上两条,因此backward的时候其梯度一定会被计算并且保存,从而step的时候才能用于梯度更新)

# (2)非叶子结点,但retain_grad=True的张量梯度也会被保存。
a = torch.tensor(1.,requires_grad=True)     # a的require_grad= True,且a是一个叶子结点(a.is_leaf=True),所以backward之后a的梯度会被保存。(满足条件3)
b = torch.tensor(1.,requires_grad=False)    # b是一个叶子结点(b.is_leaf=True),但b的require_grad= False,所以backward之后b的梯度不会被保存.(不满足条件3)
d = a+b                                     # d的require_grad= True, 但d不是一个叶子结点,所以backward之后b的梯度不会被保存。(不满足条件3)
print(d.is_leaf)                            # False
d.retain_grad()                             # retain_grad=True
d.backward()
print(d.is_leaf)                            # False
print(d.grad)                               # d的梯度被保存
相关推荐
Debroon1 分钟前
RuleAlign 规则对齐框架:将医生的诊断规则形式化并注入模型,无需额外人工标注的自动对齐方法
人工智能
小码农<^_^>6 分钟前
优选算法精品课--滑动窗口算法(一)
算法
羊小猪~~8 分钟前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
AI小杨9 分钟前
【车道线检测】一、传统车道线检测:基于霍夫变换的车道线检测史诗级详细教程
人工智能·opencv·计算机视觉·霍夫变换·车道线检测
晨曦_子画13 分钟前
编程语言之战:AI 之后的 Kotlin 与 Java
android·java·开发语言·人工智能·kotlin
道可云15 分钟前
道可云人工智能&元宇宙每日资讯|2024国际虚拟现实创新大会将在青岛举办
大数据·人工智能·3d·机器人·ar·vr
人工智能培训咨询叶梓24 分钟前
探索开放资源上指令微调语言模型的现状
人工智能·语言模型·自然语言处理·性能优化·调优·大模型微调·指令微调
zzZ_CMing24 分钟前
大语言模型训练的全过程:预训练、微调、RLHF
人工智能·自然语言处理·aigc
newxtc26 分钟前
【旷视科技-注册/登录安全分析报告】
人工智能·科技·安全·ddddocr
成都古河云27 分钟前
智慧场馆:安全、节能与智能化管理的未来
大数据·运维·人工智能·安全·智慧城市