Pytorch基础:Tensor的detach方法

相关阅读

Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


在Pytorch中,detach是Tensor的一个重要方法,用于返回一个脱离了计算图的张量,它的语法如下所示。

python 复制代码
Tensor.detach() → Tensor

在理解这个方法前,首先得知道几个概念,它们是Pytorch的基础。

required_grad属性

所有张量都拥有这个属性,它可以是True或False中的一个,代表这个张量需要计算梯度。如果设置为False,在反向传播的链式求导时该张量会被当作常数处理,而不是一个自变量。

该属性可以在创建张量时,通过required_grad参数来指定如例1所示。

python 复制代码
# 例1
x = torch.tensor([2.0], requires_grad=True)

也可以使用requires_grad_()方法或直接修改属性的方式进行动态修改,如例2所示。但要注意的是,只能改变叶张量的required_grad属性。

python 复制代码
# 例2
x.requires_grad_(False)
x.requires_grad = False

对于计算过程中创建的张量,其required_grad属性取决于其来源,如果新张量计算过程用到的原张量的required_grad属性都是False,则新张量的required_grad属性为False;否则,只要有某个原张量的required_grad属性为True,新张量的required_grad属性为True,如例3所示。

python 复制代码
# 例3
import torch

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.requires_grad) # 输出False

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.requires_grad) # 输出True

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=True)
z = x * y
print(z.requires_grad) # 输出True

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = x * y
print(z.requires_grad) # 输出True

grad_fn属性

对于计算过程中创建的张量(非叶张量),如果其required_grad属性为True,则其grad_fn属性会记录生成该张量的操作,用于反向传播,如例4所示。

python 复制代码
# 例4
import torch

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.grad_fn) # 输出:None

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=False)
z = x / y
print(z.grad_fn) # 输出:<DivBackward0 object at 0x7fa9cfc75fd0>

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=True)
z = x + y
print(z.grad_fn) # 输出:<AddBackward0 object at 0x7fa9cfc75fd0>

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = x - y
print(z.grad_fn) # 输出:<SubBackward0 object at 0x7fa9cfc75fd0>

如果你想临时禁用梯度计算,可以使用torch.no_grad()下文管理器来包裹不需要梯度计算的代码块,这样新张量的required_grad属性一定为False(自然grad_fn属性为None),与原张量的required_grad属性无关,如例5所示。

python 复制代码
# 例5
import torch
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
with torch.no_grad():
    z = x + y
print(z.requires_grad) # 输出:False
print(z.grad_fn)       # 输出:None

如果反向传播到达一个张量时,其required_grad属性为False(或者说grad_fn属性为None),则会报错,如例6所示。

python 复制代码
# 例6
import torch
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
with torch.no_grad():
    z = x + y
z.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

is_leaf属性

所有张量都拥有is_leaf属性,它可以是True或False中的一个,表示该张量是否为叶张量(leaf tensor),叶张量指的是那些并非计算而得的张量,比如权重张量、偏置张量、输入张量,如例7所示。

python 复制代码
# 例7
import torch
x = torch.tensor([2.0], requires_grad=True)
linear = torch.nn.Linear(1, 1)
y = x * 2
print(x.is_leaf)              # 输出:True
print(linear.weight.is_leaf)  # 输出:True
print(linear.bias.is_leaf)    # 输出:True
print(y.is_leaf)              # 输出:False

retain_grad属性

所有张量都拥有retain_grad属性,它可以是True或False中的一个,用于指定是否在反向传播后保留该张量的梯度(默认情况下,为了节约内存,非叶张量的梯度在用于反向传播后会被删除)。使用retain_grad()方法可以设置一个张量的retain_grad属性为True,从而保留非叶张量的梯度(注意:如果required_grad属性为False,设置retain_grad属性为True是无意义的),如例8所示。

python 复制代码
# 例8
import torch
x = torch.tensor([2.0], requires_grad=True)
y = x**2
z = y**2
t = z**2
z.retain_grad()
t.backward()
print(x.grad) # 输出:tensor([1024.])
print(y.grad) # UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:475.) return self._grad None
print(z.grad) # 输出:tensor([32.])

detach()方法

下面进入正题,当一个张量调用detach()方法时,会返回一个新的张量,该张量与调用detach()方法的张量共享底层存储(除grad外),但其required_grad属性为False,如例9所示。

python 复制代码
# 例9
import torch
x = torch.tensor([2.0], requires_grad=True)
y = x.detach()
print(y.requires_grad)   # 输出:False
z_1 = y**2
z_2 = x**2
print(z_1.requires_grad) # 输出:False
print(z_1.grad_fn)       # 输出:None
#z_1.backward()          RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

print(z_2.requires_grad) # 输出:True
print(z_2.grad_fn)       # 输出:<PowBackward0 object at 0x7f44042ccb50>
z_2.backward()           

print(id(x))             # 输出:139785294725264
print(id(y))             # 输出:139785274024560
print(x.storage().data_ptr()) # 输出:76321152
print(y.storage().data_ptr()) # 输出:76321152       

其中张量z_1是由经过张量x的detach()方法返回的张量y计算而来,而张量y的required_grad属性为False,因此张量z_1的required_grad属性也为False,其grad_fn属性为None,因此不能反向传播。

张量x和张量y的id号不同,证明它们是不同的张量,但storage().data_ptr()方法返回的指针表示,它们共享底层存储。

相关推荐
梦魇梦狸º32 分钟前
mac 配置 python 环境变量
chrome·python·macos
查理零世1 小时前
算法竞赛之差分进阶——等差数列差分 python
python·算法·差分
查士丁尼·绵3 小时前
面试-字符串1
python
好评笔记3 小时前
AIGC视频生成模型:Stability AI的SVD(Stable Video Diffusion)模型
论文阅读·人工智能·深度学习·机器学习·计算机视觉·面试·aigc
算家云3 小时前
TangoFlux 本地部署实用教程:开启无限音频创意脑洞
人工智能·aigc·模型搭建·算家云、·应用社区·tangoflux
小兜全糖(xdqt)4 小时前
python中单例模式
开发语言·python·单例模式
Python数据分析与机器学习4 小时前
python高级加密算法AES对信息进行加密和解密
开发语言·python
AI街潜水的八角4 小时前
工业缺陷检测实战——基于深度学习YOLOv10神经网络PCB缺陷检测系统
pytorch·深度学习·yolo
noravinsc4 小时前
python md5加密
前端·javascript·python
唯余木叶下弦声4 小时前
PySpark之金融数据分析(Spark RDD、SQL练习题)
大数据·python·sql·数据分析·spark·pyspark