with torch.no_grad:是截断梯度记录的,新生成的数据的都不记录梯度,但是今天产生了一点小疑惑,如果存在多层函数嵌入,是不是函数内所有的数据都不记录梯度,验证了一下,确实是的。
python
import torch
x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = True)
z = torch.randn(10, 5, requires_grad = True)
def add(x,y,z):
w = x + y + z
print(w.requires_grad)
print(w.grad_fn)
def add2(x,y,z):
add(x,y,z)
with torch.no_grad():
add2(x,y,z)
add2(x,y,z)
"""
输出:
False
None
True
<AddBackward0 object at 0x00000250371BED68>
"""