相关阅读
Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482
在Pytorch中,detach是Tensor的一个重要方法,用于返回一个脱离了计算图的张量,它的语法如下所示。
python
Tensor.detach() → Tensor
在理解这个方法前,首先得知道几个概念,它们是Pytorch的基础。
required_grad属性
所有张量都拥有这个属性,它可以是True或False中的一个,代表这个张量需要计算梯度。如果设置为False,在反向传播的链式求导时该张量会被当作常数处理,而不是一个自变量。
该属性可以在创建张量时,通过required_grad参数来指定如例1所示。
python
# 例1
x = torch.tensor([2.0], requires_grad=True)
也可以使用requires_grad_()方法或直接修改属性的方式进行动态修改,如例2所示。但要注意的是,只能改变叶张量的required_grad属性。
python
# 例2
x.requires_grad_(False)
x.requires_grad = False
对于计算过程中创建的张量,其required_grad属性取决于其来源,如果新张量计算过程用到的原张量的required_grad属性都是False,则新张量的required_grad属性为False;否则,只要有某个原张量的required_grad属性为True,新张量的required_grad属性为True,如例3所示。
python
# 例3
import torch
x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.requires_grad) # 输出False
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.requires_grad) # 输出True
x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=True)
z = x * y
print(z.requires_grad) # 输出True
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = x * y
print(z.requires_grad) # 输出True
grad_fn属性
对于计算过程中创建的张量(非叶张量),如果其required_grad属性为True,则其grad_fn属性会记录生成该张量的操作,用于反向传播,如例4所示。
python
# 例4
import torch
x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=False)
z = x * y
print(z.grad_fn) # 输出:None
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=False)
z = x / y
print(z.grad_fn) # 输出:<DivBackward0 object at 0x7fa9cfc75fd0>
x = torch.tensor([2.0], requires_grad=False)
y = torch.tensor([2.0], requires_grad=True)
z = x + y
print(z.grad_fn) # 输出:<AddBackward0 object at 0x7fa9cfc75fd0>
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
z = x - y
print(z.grad_fn) # 输出:<SubBackward0 object at 0x7fa9cfc75fd0>
如果你想临时禁用梯度计算,可以使用torch.no_grad()下文管理器来包裹不需要梯度计算的代码块,这样新张量的required_grad属性一定为False(自然grad_fn属性为None),与原张量的required_grad属性无关,如例5所示。
python
# 例5
import torch
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
with torch.no_grad():
z = x + y
print(z.requires_grad) # 输出:False
print(z.grad_fn) # 输出:None
如果反向传播到达一个张量时,其required_grad属性为False(或者说grad_fn属性为None),则会报错,如例6所示。
python
# 例6
import torch
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
with torch.no_grad():
z = x + y
z.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
is_leaf属性
所有张量都拥有is_leaf属性,它可以是True或False中的一个,表示该张量是否为叶张量(leaf tensor),叶张量指的是那些并非计算而得的张量,比如权重张量、偏置张量、输入张量,如例7所示。
python
# 例7
import torch
x = torch.tensor([2.0], requires_grad=True)
linear = torch.nn.Linear(1, 1)
y = x * 2
print(x.is_leaf) # 输出:True
print(linear.weight.is_leaf) # 输出:True
print(linear.bias.is_leaf) # 输出:True
print(y.is_leaf) # 输出:False
retain_grad属性
所有张量都拥有retain_grad属性,它可以是True或False中的一个,用于指定是否在反向传播后保留该张量的梯度(默认情况下,为了节约内存,非叶张量的梯度在用于反向传播后会被删除)。使用retain_grad()方法可以设置一个张量的retain_grad属性为True,从而保留非叶张量的梯度(注意:如果required_grad属性为False,设置retain_grad属性为True是无意义的),如例8所示。
python
# 例8
import torch
x = torch.tensor([2.0], requires_grad=True)
y = x**2
z = y**2
t = z**2
z.retain_grad()
t.backward()
print(x.grad) # 输出:tensor([1024.])
print(y.grad) # UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:475.) return self._grad None
print(z.grad) # 输出:tensor([32.])
detach()方法
下面进入正题,当一个张量调用detach()方法时,会返回一个新的张量,该张量与调用detach()方法的张量共享底层存储(除grad外),但其required_grad属性为False,如例9所示。
python
# 例9
import torch
x = torch.tensor([2.0], requires_grad=True)
y = x.detach()
print(y.requires_grad) # 输出:False
z_1 = y**2
z_2 = x**2
print(z_1.requires_grad) # 输出:False
print(z_1.grad_fn) # 输出:None
#z_1.backward() RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
print(z_2.requires_grad) # 输出:True
print(z_2.grad_fn) # 输出:<PowBackward0 object at 0x7f44042ccb50>
z_2.backward()
print(id(x)) # 输出:139785294725264
print(id(y)) # 输出:139785274024560
print(x.storage().data_ptr()) # 输出:76321152
print(y.storage().data_ptr()) # 输出:76321152
其中张量z_1是由经过张量x的detach()方法返回的张量y计算而来,而张量y的required_grad属性为False,因此张量z_1的required_grad属性也为False,其grad_fn属性为None,因此不能反向传播。
张量x和张量y的id号不同,证明它们是不同的张量,但storage().data_ptr()方法返回的指针表示,它们共享底层存储。