with torch.no_grad()在Pytorch中的应用

with torch.no_grad()在Pytorch中的应用

参考:

https://blog.csdn.net/qq_24761287/article/details/129773333

https://blog.csdn.net/sazass/article/details/116668755

在学习Pytorch时,老遇到 with torch.no_grad(),搞不清其作用,现在详细了解一下。

1、with torch.no_grad()含义

torch.no_grad() 上下文管理器通常用于那些不需要计算梯度的操作,例如在模型评估或推断时。在这些情况下,关闭自动求导功能可以提高代码执行效率,因为不需要计算梯度的操作通常比需要计算梯度的操作更快。

bash 复制代码
with torch.no_grad():
    # some code that doesn't require gradients

2、with torch.no_grad()运用场景

简单来说,如果不需要在接下来步骤中用到所计算的式子的梯度,就可以使用with torch.no_grad()来提升运算速度。

2.1 只评估模型

在模型的评估模式下,对验证数据集进行前向传播并计算性能指标,而不计算或存储梯度信息。这有助于节省内存和提高代码执行效率。在此处能使用with torch.no_grad()的根本原因是我们不依赖于模型得到的结果去执行梯度下降操作,例如:

bash 复制代码
model.eval()
with torch.no_grad():
    for inputs, targets in validation_loader:
        outputs = model(inputs)
        # 计算指标,如准确率、损失等
2.2 此模型的计算结果不参与此模型的梯度下降

在SAC算法的更新过程中,需要用到策略policy网络的结果去更新Q网络的参数 ,在计算策略policy网络的结果时,该计算结果并不会用于更新policy网络 ,因此我们需要使用with torch.no_grad():对next_log_prob = self.policy_net.evaluate(next_state)进行修饰。

bash 复制代码
        predicted_q_value1 = self.soft_q_net1(state)
        predicted_q_value1 = predicted_q_value1.gather(1, action.unsqueeze(-1))
        predicted_q_value2 = self.soft_q_net2(state)
        predicted_q_value2 = predicted_q_value2.gather(1, action.unsqueeze(-1))
        log_prob = self.policy_net.evaluate(state)
        # with torch.no_grad()表示不带梯度,因为只是用policy_net得到next_log_prob,对更新Q网络不起作用
        with torch.no_grad():
            next_log_prob = self.policy_net.evaluate(next_state)
        # reward = reward_scale * (reward - reward.mean(dim=0)) / (reward.std(dim=0) + 1e-6) # normalize with batch mean and std; plus a small number to prevent numerical problem

        # Training Q Function
        self.alpha = self.log_alpha.exp()
        target_q_min = (next_log_prob.exp() * (torch.min(self.target_soft_q_net1(next_state), self.target_soft_q_net2(
            next_state)) - self.alpha * next_log_prob)).sum(dim=-1).unsqueeze(-1)
        target_q_value = reward + (1 - done) * gamma * target_q_min  # if done==1, only reward
        q_value_loss1 = self.soft_q_criterion1(predicted_q_value1,
                                               target_q_value.detach())  # detach: no gradients for the variable
        q_value_loss2 = self.soft_q_criterion2(predicted_q_value2, target_q_value.detach())

        self.soft_q_optimizer1.zero_grad()
        q_value_loss1.backward()
        self.soft_q_optimizer1.step()
        self.soft_q_optimizer2.zero_grad()
        q_value_loss2.backward()
        self.soft_q_optimizer2.step()
2.3 模型更新参数

当你在优化算法中更新模型参数时,不需要在参数更新步骤中计算梯度。在更新参数时使用 torch.no_grad() 可以防止出现错误,并确保计算过程正确。

bash 复制代码
def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

3、with torch.no_grad()本质作用

在该模块下,所有计算得出的tensor的requires_grad都自动设置为False。

即使一个tensor(命名为x)的requires_grad = True,在with torch.no_grad计算,由x得到的新tensor(命名为w-标量)requires_grad也为False,且grad_fn也为None,即不会对w求导。例子如下所示:

bash 复制代码
x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = True)
z = torch.randn(10, 5, requires_grad = True)
with torch.no_grad():
    w = x + y + z
    print(w.requires_grad)
    print(w.grad_fn)
print(w.requires_grad)

输出:
False
None
False

4、为什么要使用with torch.no_grad()

如果在这些情况下没有使用torch.no_grad() 会导致哪些错误?

  1. 额外的内存消耗:计算和存储梯度需要额外的内存。在不需要梯度的情况下仍然计算梯度会导致不必要的内存消耗。在内存有限的设备上,如GPU,这可能导致内存不足而无法执行计算。
  2. 降低计算速度:计算梯度会增加计算负担。如果在不需要梯度的情况下仍然计算梯度,会降低计算速度,从而增加模型评估和推理的时间。
  3. 可能的计算错误:在某些情况下,如在优化算法中更新参数时,如果不使用torch.no_grad(),可能导致错误。例如,如果你在需要梯度的张量上执行原地操作,PyTorch会抛出RuntimeError,因为这样的操作会破坏计算图和梯度计算。

虽然在某些情况下忘记使用 torch.no_grad() 可能不会立即导致错误,但为了确保计算效率和正确性,建议在不需要梯度计算的情况下使用 torch.no_grad()。

下面给出使用with torch.no_grad()修饰不需要求导语句和不使用的对比,可以看到在同样的实际内,使用修饰会带来更好的速度。

效果也是使用了with torch.no_grad()更好,但是这些都是参考,毕竟每次训练的收敛速度都不太一致:

相关推荐
m0_748232925 分钟前
DALL-M:基于大语言模型的上下文感知临床数据增强方法 ,补充
人工智能·语言模型·自然语言处理
袁袁袁袁满8 分钟前
100天精通Python(爬虫篇)——第113天:‌爬虫基础模块之urllib详细教程大全
开发语言·爬虫·python·网络爬虫·爬虫实战·urllib·urllib模块教程
szxinmai主板定制专家11 分钟前
【国产NI替代】基于FPGA的32通道(24bits)高精度终端采集核心板卡
大数据·人工智能·fpga开发
海棠AI实验室14 分钟前
AI的进阶之路:从机器学习到深度学习的演变(三)
人工智能·深度学习·机器学习
机器懒得学习25 分钟前
基于YOLOv5的智能水域监测系统:从目标检测到自动报告生成
人工智能·yolo·目标检测
老大白菜31 分钟前
Python 爬虫技术指南
python
QQ同步助手40 分钟前
如何正确使用人工智能:开启智慧学习与创新之旅
人工智能·学习·百度
AIGC大时代43 分钟前
如何使用ChatGPT辅助文献综述,以及如何进行优化?一篇说清楚
人工智能·深度学习·chatgpt·prompt·aigc
流浪的小新1 小时前
【AI】人工智能、LLM学习资源汇总
人工智能·学习
古希腊掌管学习的神2 小时前
[搜广推]王树森推荐系统——矩阵补充&最近邻查找
python·算法·机器学习·矩阵