Pytorch基础:Tensor的detach方法

相关阅读

Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


在Pytorch中,detach是Tensor的一个重要方法,用于返回一个脱离了计算图的张量,它的语法如下所示。

python 复制代码
Tensor.detach() → Tensor

在理解这个方法前,首先得知道几个概念,它们是Pytorch的基础。

required_grad属性

所有张量都拥有这个属性,它可以是True或False中的一个,代表这个张量需要计算梯度。如果设置为False,在反向传播的链式求导时该张量会被当作常数处理,而不是一个自变量。

该属性可以在创建张量时,通过required_grad参数来指定如例1所示。

python 复制代码
# 例1
x = torch.tensor([2.0], requires_grad=True)

也可以使用requires_grad_()方法或直接修改属性的方式进行动态修改,如例2所示。但要注意的是,只能改变叶张量的required_grad属性。

python 复制代码
# 例2
x.requires_grad_(False)
x.requires_grad = False

对于计算过程中创建的张量,其required_grad属性取决于其来源,如果新张量计算过程用到的原张量的required_grad属性都是False,则新张量的required_grad属性为False;否则,只要有某个原张量的required_grad属性为True,新张量的required_grad属性为True,如例3所示。

python 复制代码
# 例3
import torch

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.requires_grad) # 输出False

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.requires_grad) # 输出True

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=True)
z = x * y
print(z.requires_grad) # 输出True

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = x * y
print(z.requires_grad) # 输出True

grad_fn属性

对于计算过程中创建的张量(非叶张量),如果其required_grad属性为True,则其grad_fn属性会记录生成该张量的操作,用于反向传播,如例4所示。

python 复制代码
# 例4
import torch

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.grad_fn) # 输出:None

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=False)
z = x / y
print(z.grad_fn) # 输出:<DivBackward0 object at 0x7fa9cfc75fd0>

x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=True)
z = x + y
print(z.grad_fn) # 输出:<AddBackward0 object at 0x7fa9cfc75fd0>

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = x - y
print(z.grad_fn) # 输出:<SubBackward0 object at 0x7fa9cfc75fd0>

如果你想临时禁用梯度计算,可以使用torch.no_grad()下文管理器来包裹不需要梯度计算的代码块,这样新张量的required_grad属性一定为False(自然grad_fn属性为None),与原张量的required_grad属性无关,如例5所示。

python 复制代码
# 例5
import torch
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
with torch.no_grad():
    z = x + y
print(z.requires_grad) # 输出:False
print(z.grad_fn)       # 输出:None

如果反向传播到达一个张量时,其required_grad属性为False(或者说grad_fn属性为None),则会报错,如例6所示。

python 复制代码
# 例6
import torch
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
with torch.no_grad():
    z = x + y
z.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

is_leaf属性

所有张量都拥有is_leaf属性,它可以是True或False中的一个,表示该张量是否为叶张量(leaf tensor),叶张量指的是那些并非计算而得的张量,比如权重张量、偏置张量、输入张量,如例7所示。

python 复制代码
# 例7
import torch
x = torch.tensor([2.0], requires_grad=True)
linear = torch.nn.Linear(1, 1)
y = x * 2
print(x.is_leaf)              # 输出:True
print(linear.weight.is_leaf)  # 输出:True
print(linear.bias.is_leaf)    # 输出:True
print(y.is_leaf)              # 输出:False

retain_grad属性

所有张量都拥有retain_grad属性,它可以是True或False中的一个,用于指定是否在反向传播后保留该张量的梯度(默认情况下,为了节约内存,非叶张量的梯度在用于反向传播后会被删除)。使用retain_grad()方法可以设置一个张量的retain_grad属性为True,从而保留非叶张量的梯度(注意:如果required_grad属性为False,设置retain_grad属性为True是无意义的),如例8所示。

python 复制代码
# 例8
import torch
x = torch.tensor([2.0], requires_grad=True)
y = x**2
z = y**2
t = z**2
z.retain_grad()
t.backward()
print(x.grad) # 输出:tensor([1024.])
print(y.grad) # UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:475.) return self._grad None
print(z.grad) # 输出:tensor([32.])

detach()方法

下面进入正题,当一个张量调用detach()方法时,会返回一个新的张量,该张量与调用detach()方法的张量共享底层存储(除grad外),但其required_grad属性为False,如例9所示。

python 复制代码
# 例9
import torch
x = torch.tensor([2.0], requires_grad=True)
y = x.detach()
print(y.requires_grad)   # 输出:False
z_1 = y**2
z_2 = x**2
print(z_1.requires_grad) # 输出:False
print(z_1.grad_fn)       # 输出:None
#z_1.backward()          RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

print(z_2.requires_grad) # 输出:True
print(z_2.grad_fn)       # 输出:<PowBackward0 object at 0x7f44042ccb50>
z_2.backward()           

print(id(x))             # 输出:139785294725264
print(id(y))             # 输出:139785274024560
print(x.storage().data_ptr()) # 输出:76321152
print(y.storage().data_ptr()) # 输出:76321152       

其中张量z_1是由经过张量x的detach()方法返回的张量y计算而来,而张量y的required_grad属性为False,因此张量z_1的required_grad属性也为False,其grad_fn属性为None,因此不能反向传播。

张量x和张量y的id号不同,证明它们是不同的张量,但storage().data_ptr()方法返回的指针表示,它们共享底层存储。

相关推荐
charley.layabox4 小时前
8月1日ChinaJoy酒会 | 游戏出海高端私享局 | 平台 × 发行 × 投资 × 研发精英畅饮畅聊
人工智能·游戏
DFRobot智位机器人5 小时前
AIOT开发选型:行空板 K10 与 M10 适用场景与选型深度解析
人工智能
想成为风筝7 小时前
从零开始学习深度学习—水果分类之PyQt5App
人工智能·深度学习·计算机视觉·pyqt
F_D_Z7 小时前
MMaDA:多模态大型扩散语言模型
人工智能·语言模型·自然语言处理
江沉晚呤时7 小时前
在 C# 中调用 Python 脚本:实现跨语言功能集成
python·microsoft·c#·.net·.netcore·.net core
大知闲闲哟7 小时前
深度学习G2周:人脸图像生成(DCGAN)
人工智能·深度学习
飞哥数智坊8 小时前
Coze实战第15讲:钱都去哪儿了?Coze+飞书搭建自动记账系统
人工智能·coze
wenzhangli78 小时前
低代码引擎核心技术:OneCode常用动作事件速查手册及注解驱动开发详解
人工智能·低代码·云原生
电脑能手8 小时前
如何远程访问在WSL运行的Jupyter Notebook
ide·python·jupyter
Edward-tan8 小时前
CCPD 车牌数据集提取标注,并转为标准 YOLO 格式
python