Pytorch基础:Tensor的detach方法

相关阅读

Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


在Pytorch中,detach是Tensor的一个重要方法,用于返回一个脱离了计算图的张量,它的语法如下所示。

python 复制代码
Tensor.detach() → Tensor

在理解这个方法前,首先得知道几个概念,它们是Pytorch的基础。

required_grad属性

所有张量都拥有这个属性,它可以是True或False中的一个,代表这个张量需要计算梯度。如果设置为False,在反向传播的链式求导时该张量会被当作常数处理,而不是一个自变量。

该属性可以在创建张量时,通过required_grad参数来指定如例1所示。

python 复制代码
# 例1
x = torch.tensor([2.0], requires_grad=True)

也可以使用requires_grad_()方法或直接修改属性的方式进行动态修改,如例2所示。但要注意的是,只能改变叶张量的required_grad属性。

python 复制代码
# 例2
x.requires_grad_(False)
x.requires_grad = False

对于计算过程中创建的张量,其required_grad属性取决于其来源,如果新张量计算过程用到的原张量的required_grad属性都是False,则新张量的required_grad属性为False;否则,只要有某个原张量的required_grad属性为True,新张量的required_grad属性为True,如例3所示。

python 复制代码
# 例3
import torch

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.requires_grad) # 输出False

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.requires_grad) # 输出True

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=True)
z = x * y
print(z.requires_grad) # 输出True

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = x * y
print(z.requires_grad) # 输出True

grad_fn属性

对于计算过程中创建的张量(非叶张量),如果其required_grad属性为True,则其grad_fn属性会记录生成该张量的操作,用于反向传播,如例4所示。

python 复制代码
# 例4
import torch

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.grad_fn) # 输出:None

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=False)
z = x / y
print(z.grad_fn) # 输出:<DivBackward0 object at 0x7fa9cfc75fd0>

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=True)
z = x + y
print(z.grad_fn) # 输出:<AddBackward0 object at 0x7fa9cfc75fd0>

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = x - y
print(z.grad_fn) # 输出:<SubBackward0 object at 0x7fa9cfc75fd0>

如果你想临时禁用梯度计算,可以使用torch.no_grad()下文管理器来包裹不需要梯度计算的代码块,这样新张量的required_grad属性一定为False(自然grad_fn属性为None),与原张量的required_grad属性无关,如例5所示。

python 复制代码
# 例5
import torch
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
with torch.no_grad():
    z = x + y
print(z.requires_grad) # 输出:False
print(z.grad_fn)       # 输出:None

如果反向传播到达一个张量时,其required_grad属性为False(或者说grad_fn属性为None),则会报错,如例6所示。

python 复制代码
# 例6
import torch
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
with torch.no_grad():
    z = x + y
z.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

is_leaf属性

所有张量都拥有is_leaf属性,它可以是True或False中的一个,表示该张量是否为叶张量(leaf tensor),叶张量指的是那些并非计算而得的张量,比如权重张量、偏置张量、输入张量,如例7所示。

python 复制代码
# 例7
import torch
x = torch.tensor([2.0], requires_grad=True)
linear = torch.nn.Linear(1, 1)
y = x * 2
print(x.is_leaf)              # 输出:True
print(linear.weight.is_leaf)  # 输出:True
print(linear.bias.is_leaf)    # 输出:True
print(y.is_leaf)              # 输出:False

retain_grad属性

所有张量都拥有retain_grad属性,它可以是True或False中的一个,用于指定是否在反向传播后保留该张量的梯度(默认情况下,为了节约内存,非叶张量的梯度在用于反向传播后会被删除)。使用retain_grad()方法可以设置一个张量的retain_grad属性为True,从而保留非叶张量的梯度(注意:如果required_grad属性为False,设置retain_grad属性为True是无意义的),如例8所示。

python 复制代码
# 例8
import torch
x = torch.tensor([2.0], requires_grad=True)
y = x**2
z = y**2
t = z**2
z.retain_grad()
t.backward()
print(x.grad) # 输出:tensor([1024.])
print(y.grad) # UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:475.) return self._grad None
print(z.grad) # 输出:tensor([32.])

detach()方法

下面进入正题,当一个张量调用detach()方法时,会返回一个新的张量,该张量与调用detach()方法的张量共享底层存储(除grad外),但其required_grad属性为False,如例9所示。

python 复制代码
# 例9
import torch
x = torch.tensor([2.0], requires_grad=True)
y = x.detach()
print(y.requires_grad)   # 输出:False
z_1 = y**2
z_2 = x**2
print(z_1.requires_grad) # 输出:False
print(z_1.grad_fn)       # 输出:None
#z_1.backward()          RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

print(z_2.requires_grad) # 输出:True
print(z_2.grad_fn)       # 输出:<PowBackward0 object at 0x7f44042ccb50>
z_2.backward()           

print(id(x))             # 输出:139785294725264
print(id(y))             # 输出:139785274024560
print(x.storage().data_ptr()) # 输出:76321152
print(y.storage().data_ptr()) # 输出:76321152       

其中张量z_1是由经过张量x的detach()方法返回的张量y计算而来,而张量y的required_grad属性为False,因此张量z_1的required_grad属性也为False,其grad_fn属性为None,因此不能反向传播。

张量x和张量y的id号不同,证明它们是不同的张量,但storage().data_ptr()方法返回的指针表示,它们共享底层存储。

相关推荐
清风ꦿ几秒前
neo4j 图表数据导入到 TuGraph
python·neo4j·knowledge graph
深度学习lover2 小时前
[项目代码] YOLOv8 遥感航拍飞机和船舶识别 [目标检测]
python·yolo·目标检测·计算机视觉·遥感航拍飞机和船舶识别
云起无垠3 小时前
【论文速读】| FirmRCA:面向 ARM 嵌入式固件的后模糊测试分析,并实现高效的基于事件的故障定位
人工智能·自动化
水木流年追梦3 小时前
【python因果库实战10】为何需要因果分析
开发语言·python
m0_675988234 小时前
Leetcode2545:根据第 K 场考试的分数排序
python·算法·leetcode
Leweslyh5 小时前
物理信息神经网络(PINN)八课时教案
人工智能·深度学习·神经网络·物理信息神经网络
love you joyfully5 小时前
目标检测与R-CNN——pytorch与paddle实现目标检测与R-CNN
人工智能·pytorch·目标检测·cnn·paddle
该醒醒了~5 小时前
PaddlePaddle推理模型利用Paddle2ONNX转换成onnx模型
人工智能·paddlepaddle
小树苗1936 小时前
DePIN潜力项目Spheron解读:激活闲置硬件,赋能Web3与AI
人工智能·web3
凡人的AI工具箱6 小时前
每天40分玩转Django:Django测试
数据库·人工智能·后端·python·django·sqlite