Pytorch的梯度控制

在之前的实验中遇到一些问题,因为之前计算资源有限,我就想着微调其中一部分参数做,于是我误打误撞使用了with torch.no_grad,可是发现梯度传递不了,于是写下此文来记录梯度控制的两个方法与区别。

在PyTorch中,控制梯度计算对于模型训练和微调至关重要。这里区分两个常用方法:

1. tensor.requires_grad = False

  • 目标: 单个张量(通常是模型参数 nn.Parameter)。
  • 行为:
    • "参数冻结" :这个张量本身不会计算梯度 (.gradNone)。
    • "参数不更新" :优化器不会更新这个张量。
    • "梯度可穿透" :如果它参与的运算的输入是 requires_grad=True 的,梯度仍然会通过这个运算传递给输入。它不阻碍梯度流向更早的可训练层。
  • 场景:
    • 微调:冻结预训练模型的某些层,只训练其他层。
    • 例子:pretrained_layer.weight.requires_grad = False

2. with torch.no_grad():

  • 目标: 一个代码块 (with 语句块内部)。

  • 行为:

    • "全局梯度关闭" (块内):块内所有新创建的张量默认 requires_grad=False
    • "不记录计算图" :块内的运算不被追踪,不构建反向传播所需的计算图。
    • "梯度截断" :梯度流到这个块的边界就会停止,无法通过块内的操作继续反向传播
  • 场景:

    • 模型评估/推理 (Inference/Evaluation):不需要梯度,节省内存和计算。
    • 执行不需要梯度的任何计算。
    • 例子:
    python 复制代码
     with torch.no_grad():
         outputs = model(inputs)
         # ...其他评估代码

核心区别速记:

特性 requires_grad=False with torch.no_grad():
谁不更新? 这个参数自己 (块内)没人更新
梯度能过吗? 能过! 不能过! (被截断)
影响范围? 单个张量 整个代码块

一句话总结:

  • 想让某个参数不更新但梯度能流过 ,用 requires_grad=False
  • 想让一段代码完全不计算梯度也不让梯度流过 ,用 with torch.no_grad()

搞清楚这两者的区别,能在PyTorch中更灵活地控制模型的训练过程!

相关推荐
Giorno3723 分钟前
用 LLM 做数据提取踩了 6 个坑,我加了 6 层防御——15000 张发票的实战总结
人工智能
沉浸式学习ing4 分钟前
播客和视频怎么变成知识库里的笔记?音视频转结构化笔记完整方案
人工智能·笔记·gpt·学习·ai·音视频·notion
kexnjdcncnxjs7 分钟前
SQL批量删除不同条件的记录_使用IN子句简化删除逻辑
jvm·数据库·python
Soari8 分钟前
终结 Vibe Coding(Harness Engineering)!深度拆解 ralph:以交付所有 PRD 为生命周期的自主 AI Agent 闭环
自动化测试·人工智能·软件工程·aiagent·ralph·harnesseng·prd驱动
yezannnnnn9 分钟前
ToAgent:下一个被颠覆的不是某个行业,是"App"这个概念本身
人工智能
2303_8212873810 分钟前
如何安装Oracle 12c Cloud Control_OMS服务端组件与Agent部署
jvm·数据库·python
Be reborn10 分钟前
用例不是孤立执行的:依赖、变量池与 storage_state 设计
python·自动化·pytest
m0_6091604912 分钟前
React Flow 边缘错位与消失问题的根源分析与 Hooks 重构方案
jvm·数据库·python
Marvel__Dead12 分钟前
微调 Gemma 4 识别腾讯天御全系列验证码【解决方案-一个模型识别 滑块|文字点选|图标点选|空间点选】
人工智能·爬虫·python·验证码识别·ai 大模型
Agent手记13 分钟前
成品发货全流程自动化,落地实操与错发漏发规避方案 | 2026企业级Agent端到端落地指南
运维·人工智能·ai·自动化