torch中关于张量是否是叶子结点,张量梯度是否会被计算,张量梯度是否会被保存的感悟

先上结论:

1、叶子结点定义:

(1)不依赖其它任何结点的张量

(2)依赖其它张量,但其依赖的所有张量的require_grad=False

判断方法:查看is_leaf属性

2、张量梯度是否会被计算:

require_grad=True,且依赖其的张量不全为require_grad=False,该张量梯度会被计算

判断方法:backward之后查看张量的.grad属性(中间变量满足上述要求的梯度肯定也会被计算,只是backward之后会被释放掉,无法查看。中间变量其下面的叶子结点梯度被计算,根据链式法则,侧面也可证明中间变量的梯度肯定被计算了,因此本文只采用叶子结点说明该规律)

3、张量梯度是否会被保存(前提: 张量梯度可以被计算):

(1)是叶子结点

(2)是非叶子结点,但retain_grad=True

判断方法: backward之后查看张量的.grad属性

然后看例子:

两个例子先证明第一条:叶子结点定义

python 复制代码
import torch

# 两个例子先证明第一条:叶子结点定义
a = torch.tensor(1.,requires_grad=True)
b = torch.tensor(1.,requires_grad=False)
d = a+b
d.backward()

print(a.is_leaf) # True  (不依赖其它任何张量的结点)
print(b.is_leaf) # True  (不依赖其它任何张量的结点)
print(d.is_leaf) # False (依赖其它张量a和b,但张量a的require_grad=True,所以d不是叶子结点)


c = torch.tensor(1., requires_grad=False)
e = b+c

print(b.is_leaf)    # True  (不依赖其它任何张量的结点)
print(c.is_leaf)    # True  (不依赖其它任何张量的结点)
print(e.is_leaf)    # True (依赖其它张量b和c,但张量a和c的require_grad=False,所以e仍然是叶子结点)
# e.backward()会报错,因为所有结点的require_grad=False,因此不需要求梯度

两个例子再证明第二条:张量梯度是否会被计算

python 复制代码
import torch

# 两个例子再证明第二条:张量梯度是否会被计算
# require_grad=True的张量,且依赖其的张量不全为require_grad=False梯度会被计算
a = torch.tensor(1.,requires_grad=True)   # a require_grad=True,因此a的梯度会被计算
b = torch.tensor(1.,requires_grad=False)  # b require_grad=False,因此b的梯度不会被计算
d = a+b                                   # d require_grad=True,因此d的梯度会被计算
c = torch.tensor(1., requires_grad=True)  # c require_grad=True,因此c的梯度会被计算
e = a+c                                   # e require_grad=True,因此e的梯度会被计算
d = d.detach()
print(d.requires_grad)                    # d require_grad=False,因此d的梯度不会被计算

f = d+e 
f.backward()
print(a.grad)                             # a有梯度

# require_grad=True的张量,但依赖其的张量全为require_grad=False,梯度不会被计算
a = torch.tensor(1.,requires_grad=True)   # a require_grad=True, 因此a的梯度会被计算
b = torch.tensor(1.,requires_grad=False)  # b require_grad=False, 因此b的梯度不会被计算
d = a+b                                   # d require_grad=True, 因此d的梯度会被计算
c = torch.tensor(1., requires_grad=True)  # c require_grad=True,因此c的梯度会被计算
d = d.detach()                            # d require_grad=False, 因此d的梯度不会被计算
f = d+c                                   # f require_grad=True,因此f的梯度会被计算
f.backward()
print(a.grad)                             # a没有梯度

再证明第三条: 张量梯度是否会被保存,前提是张量的梯度能被计算(既满足第二条)

python 复制代码
import torch

# 再证明第三条: 张量梯度是否会被保存,前提是张量的梯度能被计算(既满足第二条)
# (1)叶子结点的梯度会被保存
a = torch.tensor(1.,requires_grad=True)     # a的require_grad= True,且a是一个叶子结点(a.is_leaf=True),所以backward之后a的梯度会被保存。(满足条件3)
b = torch.tensor(1.,requires_grad=False)    # b是一个叶子结点(b.is_leaf=True),但b的require_grad= False,所以backward之后b的梯度不会被保存.(不满足条件3)
d = a+b                                     # d的require_grad= True, 但d不是一个叶子结点,所以backward之后b的梯度不会被保存。(不满足条件3)
d.backward()
print(a.is_leaf)
print(a.grad)                               # a的梯度被保存


# 感悟: 神经网络各层里面的参数require_grad=True(属于Parameter类型,其初始化的时候默认require_grad=True),并且如果上层不会被断开(满足梯度可以被计算条件)。且神经网络里面各层的参数都是叶子结点(从计算图可以得知满足1里面第一条),因此满足梯度保存条件第一条。因此其梯度一定会被保存。满足了以上两条,因此backward的时候其梯度一定会被计算并且保存,从而step的时候才能用于梯度更新)

# (2)非叶子结点,但retain_grad=True的张量梯度也会被保存。
a = torch.tensor(1.,requires_grad=True)     # a的require_grad= True,且a是一个叶子结点(a.is_leaf=True),所以backward之后a的梯度会被保存。(满足条件3)
b = torch.tensor(1.,requires_grad=False)    # b是一个叶子结点(b.is_leaf=True),但b的require_grad= False,所以backward之后b的梯度不会被保存.(不满足条件3)
d = a+b                                     # d的require_grad= True, 但d不是一个叶子结点,所以backward之后b的梯度不会被保存。(不满足条件3)
print(d.is_leaf)                            # False
d.retain_grad()                             # retain_grad=True
d.backward()
print(d.is_leaf)                            # False
print(d.grad)                               # d的梯度被保存
相关推荐
科技小花4 小时前
全球化深水区,数据治理成为企业出海 “核心竞争力”
大数据·数据库·人工智能·数据治理·数据中台·全球化
zhuiyisuifeng5 小时前
2026前瞻:GPTimage2镜像官网或将颠覆视觉创作
人工智能·gpt
徐健峰5 小时前
GPT-image-2 热门玩法实战(一):AI 看手相 — 一张手掌照片生成专业手相分析图
人工智能·gpt
weixin_370976355 小时前
AI的终极赛跑:进入AGI,还是泡沫破灭?
大数据·人工智能·agi
Slow菜鸟5 小时前
AI学习篇(五) | awesome-design-md 使用说明
人工智能·学习
超级码力6665 小时前
【Latex文件架构】Latex文件架构模板
算法·数学建模·信息可视化
穿条秋裤到处跑6 小时前
每日一道leetcode(2026.04.29):二维网格图中探测环
算法·leetcode·职场和发展
冬奇Lab6 小时前
RAG 系列(五):Embedding 模型——语义理解的核心
人工智能·llm·aigc
深小乐6 小时前
AI 周刊【2026.04.27-05.03】:Anthropic 9000亿美元估值、英伟达死磕智能体、中央重磅定调AI
人工智能
码点滴6 小时前
什么时候用 DeepSeek V4,而不是 GPT-5/Claude/Gemini?
人工智能·gpt·架构·大模型·deepseek