Pytorch基础:Tensor的detach方法

相关阅读

Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


在Pytorch中,detach是Tensor的一个重要方法,用于返回一个脱离了计算图的张量,它的语法如下所示。

python 复制代码
Tensor.detach() → Tensor

在理解这个方法前,首先得知道几个概念,它们是Pytorch的基础。

required_grad属性

所有张量都拥有这个属性,它可以是True或False中的一个,代表这个张量需要计算梯度。如果设置为False,在反向传播的链式求导时该张量会被当作常数处理,而不是一个自变量。

该属性可以在创建张量时,通过required_grad参数来指定如例1所示。

python 复制代码
# 例1
x = torch.tensor([2.0], requires_grad=True)

也可以使用requires_grad_()方法或直接修改属性的方式进行动态修改,如例2所示。但要注意的是,只能改变叶张量的required_grad属性。

python 复制代码
# 例2
x.requires_grad_(False)
x.requires_grad = False

对于计算过程中创建的张量,其required_grad属性取决于其来源,如果新张量计算过程用到的原张量的required_grad属性都是False,则新张量的required_grad属性为False;否则,只要有某个原张量的required_grad属性为True,新张量的required_grad属性为True,如例3所示。

python 复制代码
# 例3
import torch

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.requires_grad) # 输出False

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.requires_grad) # 输出True

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=True)
z = x * y
print(z.requires_grad) # 输出True

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = x * y
print(z.requires_grad) # 输出True

grad_fn属性

对于计算过程中创建的张量(非叶张量),如果其required_grad属性为True,则其grad_fn属性会记录生成该张量的操作,用于反向传播,如例4所示。

python 复制代码
# 例4
import torch

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.grad_fn) # 输出:None

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=False)
z = x / y
print(z.grad_fn) # 输出:<DivBackward0 object at 0x7fa9cfc75fd0>

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=True)
z = x + y
print(z.grad_fn) # 输出:<AddBackward0 object at 0x7fa9cfc75fd0>

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = x - y
print(z.grad_fn) # 输出:<SubBackward0 object at 0x7fa9cfc75fd0>

如果你想临时禁用梯度计算,可以使用torch.no_grad()下文管理器来包裹不需要梯度计算的代码块,这样新张量的required_grad属性一定为False(自然grad_fn属性为None),与原张量的required_grad属性无关,如例5所示。

python 复制代码
# 例5
import torch
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
with torch.no_grad():
    z = x + y
print(z.requires_grad) # 输出:False
print(z.grad_fn)       # 输出:None

如果反向传播到达一个张量时,其required_grad属性为False(或者说grad_fn属性为None),则会报错,如例6所示。

python 复制代码
# 例6
import torch
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
with torch.no_grad():
    z = x + y
z.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

is_leaf属性

所有张量都拥有is_leaf属性,它可以是True或False中的一个,表示该张量是否为叶张量(leaf tensor),叶张量指的是那些并非计算而得的张量,比如权重张量、偏置张量、输入张量,如例7所示。

python 复制代码
# 例7
import torch
x = torch.tensor([2.0], requires_grad=True)
linear = torch.nn.Linear(1, 1)
y = x * 2
print(x.is_leaf)              # 输出:True
print(linear.weight.is_leaf)  # 输出:True
print(linear.bias.is_leaf)    # 输出:True
print(y.is_leaf)              # 输出:False

retain_grad属性

所有张量都拥有retain_grad属性,它可以是True或False中的一个,用于指定是否在反向传播后保留该张量的梯度(默认情况下,为了节约内存,非叶张量的梯度在用于反向传播后会被删除)。使用retain_grad()方法可以设置一个张量的retain_grad属性为True,从而保留非叶张量的梯度(注意:如果required_grad属性为False,设置retain_grad属性为True是无意义的),如例8所示。

python 复制代码
# 例8
import torch
x = torch.tensor([2.0], requires_grad=True)
y = x**2
z = y**2
t = z**2
z.retain_grad()
t.backward()
print(x.grad) # 输出:tensor([1024.])
print(y.grad) # UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:475.) return self._grad None
print(z.grad) # 输出:tensor([32.])

detach()方法

下面进入正题,当一个张量调用detach()方法时,会返回一个新的张量,该张量与调用detach()方法的张量共享底层存储(除grad外),但其required_grad属性为False,如例9所示。

python 复制代码
# 例9
import torch
x = torch.tensor([2.0], requires_grad=True)
y = x.detach()
print(y.requires_grad)   # 输出:False
z_1 = y**2
z_2 = x**2
print(z_1.requires_grad) # 输出:False
print(z_1.grad_fn)       # 输出:None
#z_1.backward()          RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

print(z_2.requires_grad) # 输出:True
print(z_2.grad_fn)       # 输出:<PowBackward0 object at 0x7f44042ccb50>
z_2.backward()           

print(id(x))             # 输出:139785294725264
print(id(y))             # 输出:139785274024560
print(x.storage().data_ptr()) # 输出:76321152
print(y.storage().data_ptr()) # 输出:76321152       

其中张量z_1是由经过张量x的detach()方法返回的张量y计算而来,而张量y的required_grad属性为False,因此张量z_1的required_grad属性也为False,其grad_fn属性为None,因此不能反向传播。

张量x和张量y的id号不同,证明它们是不同的张量,但storage().data_ptr()方法返回的指针表示,它们共享底层存储。

相关推荐
爱打代码的小林几秒前
基于 OpenCV 与 Dlib 的人脸替换
人工智能·opencv·计算机视觉
无忧智库1 分钟前
某市“十五五“知识产权大数据监管平台与全链条保护系统建设方案深度解读(WORD)
大数据·人工智能
顾北121 分钟前
AI对话应用接口开发全解析:同步接口+SSE流式+智能体+前端对接
前端·人工智能
综合热讯5 分钟前
股票融资融券交易时间限制一览与制度说明
大数据·人工智能·区块链
AEIC学术交流中心5 分钟前
【快速EI检索 | ICPS出版】2026年计算机技术与可持续发展国际学术会议(CTSD 2026)
人工智能·计算机网络
玄同7658 分钟前
Python Random 模块深度解析:从基础 API 到 AI / 大模型工程化实践
人工智能·笔记·python·学习·算法·语言模型·llm
风指引着方向9 分钟前
昇腾 AI 开发生产力工具:CANN CLI 的高级使用与自动化脚本编写
运维·人工智能·自动化
算法狗210 分钟前
大模型面试题:1B的模型和1T的数据大概要训练多久
人工智能·深度学习·机器学习·语言模型
AIFarmer12 分钟前
在EV3上运行Python语言——环境设置
python·ev3