torch中关于张量是否是叶子结点,张量梯度是否会被计算,张量梯度是否会被保存的感悟

先上结论:

1、叶子结点定义:

(1)不依赖其它任何结点的张量

(2)依赖其它张量,但其依赖的所有张量的require_grad=False

判断方法:查看is_leaf属性

2、张量梯度是否会被计算:

require_grad=True,且依赖其的张量不全为require_grad=False,该张量梯度会被计算

判断方法:backward之后查看张量的.grad属性(中间变量满足上述要求的梯度肯定也会被计算,只是backward之后会被释放掉,无法查看。中间变量其下面的叶子结点梯度被计算,根据链式法则,侧面也可证明中间变量的梯度肯定被计算了,因此本文只采用叶子结点说明该规律)

3、张量梯度是否会被保存(前提: 张量梯度可以被计算):

(1)是叶子结点

(2)是非叶子结点,但retain_grad=True

判断方法: backward之后查看张量的.grad属性

然后看例子:

两个例子先证明第一条:叶子结点定义

python 复制代码
import torch

# 两个例子先证明第一条:叶子结点定义
a = torch.tensor(1.,requires_grad=True)
b = torch.tensor(1.,requires_grad=False)
d = a+b
d.backward()

print(a.is_leaf) # True  (不依赖其它任何张量的结点)
print(b.is_leaf) # True  (不依赖其它任何张量的结点)
print(d.is_leaf) # False (依赖其它张量a和b,但张量a的require_grad=True,所以d不是叶子结点)


c = torch.tensor(1., requires_grad=False)
e = b+c

print(b.is_leaf)    # True  (不依赖其它任何张量的结点)
print(c.is_leaf)    # True  (不依赖其它任何张量的结点)
print(e.is_leaf)    # True (依赖其它张量b和c,但张量a和c的require_grad=False,所以e仍然是叶子结点)
# e.backward()会报错,因为所有结点的require_grad=False,因此不需要求梯度

两个例子再证明第二条:张量梯度是否会被计算

python 复制代码
import torch

# 两个例子再证明第二条:张量梯度是否会被计算
# require_grad=True的张量,且依赖其的张量不全为require_grad=False梯度会被计算
a = torch.tensor(1.,requires_grad=True)   # a require_grad=True,因此a的梯度会被计算
b = torch.tensor(1.,requires_grad=False)  # b require_grad=False,因此b的梯度不会被计算
d = a+b                                   # d require_grad=True,因此d的梯度会被计算
c = torch.tensor(1., requires_grad=True)  # c require_grad=True,因此c的梯度会被计算
e = a+c                                   # e require_grad=True,因此e的梯度会被计算
d = d.detach()
print(d.requires_grad)                    # d require_grad=False,因此d的梯度不会被计算

f = d+e 
f.backward()
print(a.grad)                             # a有梯度

# require_grad=True的张量,但依赖其的张量全为require_grad=False,梯度不会被计算
a = torch.tensor(1.,requires_grad=True)   # a require_grad=True, 因此a的梯度会被计算
b = torch.tensor(1.,requires_grad=False)  # b require_grad=False, 因此b的梯度不会被计算
d = a+b                                   # d require_grad=True, 因此d的梯度会被计算
c = torch.tensor(1., requires_grad=True)  # c require_grad=True,因此c的梯度会被计算
d = d.detach()                            # d require_grad=False, 因此d的梯度不会被计算
f = d+c                                   # f require_grad=True,因此f的梯度会被计算
f.backward()
print(a.grad)                             # a没有梯度

再证明第三条: 张量梯度是否会被保存,前提是张量的梯度能被计算(既满足第二条)

python 复制代码
import torch

# 再证明第三条: 张量梯度是否会被保存,前提是张量的梯度能被计算(既满足第二条)
# (1)叶子结点的梯度会被保存
a = torch.tensor(1.,requires_grad=True)     # a的require_grad= True,且a是一个叶子结点(a.is_leaf=True),所以backward之后a的梯度会被保存。(满足条件3)
b = torch.tensor(1.,requires_grad=False)    # b是一个叶子结点(b.is_leaf=True),但b的require_grad= False,所以backward之后b的梯度不会被保存.(不满足条件3)
d = a+b                                     # d的require_grad= True, 但d不是一个叶子结点,所以backward之后b的梯度不会被保存。(不满足条件3)
d.backward()
print(a.is_leaf)
print(a.grad)                               # a的梯度被保存


# 感悟: 神经网络各层里面的参数require_grad=True(属于Parameter类型,其初始化的时候默认require_grad=True),并且如果上层不会被断开(满足梯度可以被计算条件)。且神经网络里面各层的参数都是叶子结点(从计算图可以得知满足1里面第一条),因此满足梯度保存条件第一条。因此其梯度一定会被保存。满足了以上两条,因此backward的时候其梯度一定会被计算并且保存,从而step的时候才能用于梯度更新)

# (2)非叶子结点,但retain_grad=True的张量梯度也会被保存。
a = torch.tensor(1.,requires_grad=True)     # a的require_grad= True,且a是一个叶子结点(a.is_leaf=True),所以backward之后a的梯度会被保存。(满足条件3)
b = torch.tensor(1.,requires_grad=False)    # b是一个叶子结点(b.is_leaf=True),但b的require_grad= False,所以backward之后b的梯度不会被保存.(不满足条件3)
d = a+b                                     # d的require_grad= True, 但d不是一个叶子结点,所以backward之后b的梯度不会被保存。(不满足条件3)
print(d.is_leaf)                            # False
d.retain_grad()                             # retain_grad=True
d.backward()
print(d.is_leaf)                            # False
print(d.grad)                               # d的梯度被保存
相关推荐
IT古董31 分钟前
第四章:大模型(LLM)】06.langchain原理-(3)LangChain Prompt 用法
java·人工智能·python
TGITCIC1 小时前
AI Search进化论:从RAG到DeepSearch的智能体演变全过程
人工智能·ai大模型·ai智能体·ai搜索·大模型ai·deepsearch·ai search
lucky_lyovo5 小时前
自然语言处理NLP---预训练模型与 BERT
人工智能·自然语言处理·bert
fantasy_arch5 小时前
pytorch例子计算两张图相似度
人工智能·pytorch·python
No0d1es6 小时前
电子学会青少年软件编程(C/C++)5级等级考试真题试卷(2024年6月)
c语言·c++·算法·青少年编程·电子学会·五级
AndrewHZ7 小时前
【3D重建技术】如何基于遥感图像和DEM等数据进行城市级高精度三维重建?
图像处理·人工智能·深度学习·3d·dem·遥感图像·3d重建
飞哥数智坊7 小时前
Coze实战第18讲:Coze+计划任务,我终于实现了企微资讯简报的定时推送
人工智能·coze·trae
Code_流苏7 小时前
AI热点周报(8.10~8.16):AI界“冰火两重天“,GPT-5陷入热议,DeepSeek R2模型训练受阻?
人工智能·gpt·gpt5·deepseek r2·ai热点·本周周报
赴3357 小时前
矿物分类案列 (一)六种方法对数据的填充
人工智能·python·机器学习·分类·数据挖掘·sklearn·矿物分类
大模型真好玩7 小时前
一文深度解析OpenAI近期发布系列大模型:意欲一统大模型江湖?
人工智能·python·mcp