loss.backward() 和 梯度累积

loss.backward() 在做什么

这一行调用的是 PyTorch 的反向传播 机制------核心作用是:从 loss 出发,沿着前向构建的计算图反向走,对所有 requires_grad=True 的参数计算梯度(∂loss / ∂param)并累加到 param.grad

一、背景:前向时 PyTorch 在做什么

kd_algorithm.training_step(micro_batch) 执行学生模型前向、算 logits、算 loss 时,PyTorch 自动构建了一张计算图(computation graph)

  • 每个张量记录了"我是怎么从哪些张量算出来的"(grad_fn
  • 这张图把 loss 和所有可训练参数 θ 连接起来

例如简化后:

bash 复制代码
input → [Linear w₁] → h₁ → [Linear w₂] → logits → [CE loss] → loss
                ↑                  ↑                   ↑
              w₁.grad_fn        w₂.grad_fn         loss.grad_fn

二、loss.backward() 反向走这张图

1. 计算梯度(链式法则)

按拓扑逆序遍历计算图,对每个节点应用链式法则:

bash 复制代码
∂loss/∂logits  →  ∂loss/∂h₁  →  ∂loss/∂w₂  →  ∂loss/∂w₁ ...

每经过一个算子,就调用其 backward 实现(如 matmul 的 backward 是矩阵乘转置等),把上游传下来的梯度乘上局部雅可比,再传给更上游。

2. 把梯度写到 param.grad

最终对每个叶子参数 pp.requires_grad=True 且没有 grad_fn):

python 复制代码
p.grad += ∂loss/∂p     # 累加,不是覆盖

注意是 累加 ------这就是为什么需要 optimizer.zero_grad() 在每次 step 后把 grad 清零,否则梯度会一直叠加。

利用了"累加"特性来实现梯度累积 :连续多次 backward() 不调 zero_grad,梯度自然累加,最后一次累积满后才调 optimizer.step() + zero_grad()

3. 释放计算图

默认情况下,反向走完后 PyTorch 会释放保存的中间激活值 (用于反向计算的 saved tensors),节省显存。再次对同一张图 backward 会报错,除非传 retain_graph=True

三、在 FSDP 下的特殊行为

普通单卡 loss.backward() 只算梯度。但 FSDP(Fully Sharded Data Parallel),它在 backward 阶段会做更多事:

阶段 FSDP 行为
前向 各 rank 持有 param 的一个 shard;前向时 all-gather 完整参数 → 计算 → 释放完整参数
反向 反向再次 all-gather 参数计算梯度 → reduce-scatter:把每个参数的梯度 reduce(求和/平均),并 scatter 回该 param shard 所在的 rank
结果 每个 rank 只在自己负责的 param shard 上得到对应梯度(p.grad),梯度也被分片存储

所以 loss.backward() 在 FSDP 上下文里还隐含了跨 rank 通信 :reduce-scatter 把数据并行的梯度聚合并分片,相当于普通 DDP 的 all-reduce,但通信对象是 sharded 梯度。

四、整体作用与定位

backward()梯度计算阶段 ------它不更新参数 ,只把"这一步学生模型应该怎么调整"以梯度形式存到 p.grad 里,等优化器 step 时再用。

loss.backward() = 沿计算图反向应用链式法则,对所有可训练参数计算 ∂loss/∂p 并累加到 p.grad;在 FSDP 下还会自动完成跨 rank 的梯度 reduce-scatter,让每个 rank 拿到自己 param shard 对应的梯度。它是"算梯度",参数更新由后续 optimizer.step() 完成。

梯度累积(Gradient Accumulation)的作用与差异

一、什么是梯度累积

核心思想 :把一个"大 batch"拆成 N 个"小 batch(micro_batch)",连续做 N 次前向 + 反向但不更新参数 ,让梯度在 p.grad 上自然累加;累加 N 次后再做一次 optimizer.step()

python 复制代码
def backward(self, loss, model, optimizer, **kwargs):
    self.step = (self.step + 1) % self.accumulated_gradient
    loss = loss / self.accumulated_gradient   # ← 先除以 N
    loss.backward()                            # ← 梯度累加到 p.grad

def optimizer_step(self, optimizer, model, scheduler, **kwargs):
    if self.step == 0:                         # ← 累满 N 次才真正更新
        ...
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

累积 N 次 backward → 1 次 optimizer.step

二、为什么要累积------主要作用

1. 在显存受限下实现"大 batch 训练"(最核心动机)

显存放不下 batch_size=64 怎么办?拆成 8 次 batch_size=8 累积,效果约等于 batch_size=64:

方案 单步显存峰值 等效 batch 显存友好
一次 batch=64 高(容易 OOM) 64
累积 8×8 低(只到 batch=8) 64

LLM 训练里"大 batch"对收敛稳定性、梯度信噪比、scaling law 都有帮助,但显存往往是瓶颈,累积是绕过显存限制拿到大有效 batch 的标准手段。

2. 匹配分布式并行的语义

每个 prompt batch 生成一批 rollout,被切成多个 micro_batch 喂给学生 actor。让 actor 在内部把这 N 个 micro_batch 累积成一次更新,意味着"一次 rollout → 一次梯度更新"------保持 on-policy 的语义清晰,避免一份 rollout 数据用多次。

三、累积 vs 不累积的具体区别

数学上:梯度等价(在常见条件下)

假设 loss 是样本平均(reduction="mean"):

  • 不累积,batch=N
bash 复制代码
L = (1/N) Σᵢ Lᵢ
∂L/∂θ = (1/N) Σᵢ ∂Lᵢ/∂θ
  • 累积 N 次,每次 micro_batch=1 (带 loss /= N 归一化):
bash 复制代码
每次 backward 累加: (1/N) ∂Lᵢ/∂θ
N 次后 p.grad = Σᵢ (1/N) ∂Lᵢ/∂θ = (1/N) Σᵢ ∂Lᵢ/∂θ   ✓

例如配置 train_batch_size=128micro_train_batch_size=4world_size=8

python 复制代码
num_micro_batches = self.args.train.train_batch_size // self.args.train.micro_train_batch_size
  • 数据并行下每 GPU 一次 forward 跑 4 条样本
  • 累积步数 = 128 / 4 / 8 = 4(每张卡反向 4 次后再 step 一次)
  • 等效全局 batch = 4 × 4 × 8 = 128

梯度累积 = 把"大 batch 一次更新"拆成"小 batch 多次累加 + 最后一次更新",用计算时间换显存空间,数学上(在样本平均损失 + 无 BN 的前提下)与不累积等价。区别主要在显存峰值、吞吐、BN 统计、通信次数和更新频率上;LLM 训练里因为用 LayerNorm,二者基本等价,所以累积是大模型训练框架的默认手段。

梯度累积为什么能节省显存

显存峰值由"单次前向 + 反向时同时驻留的中间激活值(activations)"决定,而梯度累积让每一步只跑一个小 micro_batch,激活值只按小 batch 算 → 峰值显存大幅下降。累加的只是参数级别的梯度(p.grad),它的大小与 batch 无关。

一、训练时显存的几大组成

一次训练步显存占用大致为:

组成 大小是否随 batch 增长 说明
模型参数 θ 与模型大小有关,与 batch 无关
梯度 p.grad 形状与参数相同,与 batch 无关
优化器状态(Adam: m, v) ~2× 参数大小,与 batch 无关
前向激活值(activations) 是,随 batch 线性增长 反向计算梯度时必须保留的中间张量
临时 buffer / workspace 部分 通信缓冲、cuDNN workspace 等

LLM 训练里最大的显存开销往往是 activations------因为每层 transformer 的输入、attention 中间结果(Q/K/V、attention scores)都要保存以便反向传播。它的大小约为:

python 复制代码
activations ≈ batch_size × seq_len × hidden_size × num_layers × const

显存峰值就是发生在反向开始前------所有层的 activations 都还驻留在 GPU 上的那一刻。

二、不累积 vs 累积的显存对比

不累积(一次大 batch=N)
bash 复制代码
forward (batch=N)  →  保存 ~N 倍激活值 ─┐
                                          ├─ 显存峰值 ∝ N
backward          ←  消耗激活算梯度 ─────┘
optimizer.step

激活值正比于 N,N 大就 OOM

累积(micro_batch=N/k,累积 k 次)
bash 复制代码
循环 k 次:
  forward (batch=N/k)  →  保存 ~(N/k) 倍激活值 ─┐
                                                  ├─ 每轮峰值 ∝ N/k
  backward            ←  消耗激活算梯度 ─────────┘
  ✅ 激活值 backward 后立即释放
  ✅ p.grad 累加(大小不变)
optimizer.step(累积 k 次后才执行一次)

每轮反向结束后 PyTorch 自动释放 该轮的激活值(默认 retain_graph=False),所以下一轮重新分配的激活值占用同一片显存,不会叠加

峰值显存 ∝ N/k,是不累积方案的 1/k

三、为什么"累加梯度"自己不会占额外显存

python 复制代码
loss.backward()   # 在已有的 p.grad 上做 += (in-place)
  • p.grad 是和 p 形状一致的张量,第一次 backward 就会分配好
  • 后续 backward 是 in-place 累加到这块已有显存上
  • 累积 1 次 vs 累积 100 次,p.grad 占用一模一样

所以"累积"的代价是时间 (多跑 k 次 forward+backward),收益是激活值显存峰值降到 1/k------而梯度本身不变大。

四、一个直观的数字例子

假设:

  • 模型 7B 参数,bf16 → 参数 14 GB梯度 14 GB ,Adam 状态 28 GB(fp32 m+v)
  • 单卡 80GB 显存,去掉以上 56 GB,剩 ~24 GB 给 activations + buffer
  • batch_size=64、seq_len=4096 时 activations 估 ~80 GB → OOM
  • batch_size=8、seq_len=4096 时 activations 估 ~10 GB → 装得下

不累积要 OOM;累积 8 次 micro_batch=8

  • 每轮峰值 ≈ 56 + 10 = 66 GB ✓ 不 OOM
  • 8 次反向后 p.grad 仍是 14 GB(不变)
  • 累积满 8 次再 optimizer.step(),等效 batch=64

五、与 FSDP / 梯度检查点的关系

显存优化技术分三类------梯度累积只是其中之一,常常组合使用:

技术 削减的部分 代价
梯度累积 activations(按 1/k) 时间(多次 forward)
gradient checkpointing(重计算) activations(按 ~1/√L 或更多) 时间(反向时重做部分前向)
FSDP(参数/梯度/优化器状态分片) params + grads + optim states(按 1/world_size) 通信开销

六、为什么"梯度求和"不会爆?

可能有人会担心:累积 k 次梯度,p.grad 会越来越大?

不会 ------p.grad 是数值上累加,张量大小恒定不变(形状和 dtype 不变)。变化的只是里面存的数值,从 0 累加到最终的累计梯度,不消耗额外显存空间。

显存峰值主要被 forward 期间保留的 activations 撑大,而 activations 与 batch 大小成正比;梯度累积通过把大 batch 拆成多个小 micro_batch 串行跑,每轮的 activations 在 backward 后立刻释放,使峰值降到 1/k;累加发生在已分配好的 p.grad 上(in-place),不引入额外显存------所以"用时间换空间"成立。

optimizer.step() 的作用

核心作用 :用 loss.backward() 累积在 p.grad 中的梯度,按优化算法(如 AdamW、SGD)的更新规则实际修改模型参数 p.data 。这是训练循环里唯一真正改动模型权重的一步。

一、做了什么

简化的伪代码(以 SGD 为例):

python 复制代码
for group in optimizer.param_groups:
    lr = group["lr"]
    for p in group["params"]:
        if p.grad is None:
            continue
        p.data -= lr * p.grad     # 实际更新参数

不同优化器的更新规则不同:

优化器 更新公式(简化)
SGD p ← p - lr · g
SGD+Momentum v ← μv + g; p ← p - lr · v
Adam / AdamW 维护一阶/二阶矩估计 m, v;p ← p - lr · m̂ / (√v̂ + ε) - lr · wd · p

本项目用的是 AdamW(fsdp_strategy.py 中通过 create_optimizer 创建),所以 optimizer.step() 在内部:

  1. 读取 p.grad
  2. 更新一阶矩 m 和二阶矩 v(保存在 optimizer.state[p]
  3. 做 bias correction
  4. 计算更新量并写回 p.data
  5. 应用 weight decay(AdamW 把 weight decay 单独从梯度中解耦)

二、在训练循环中的位置

python 复制代码
def optimizer_step(self, optimizer, model, scheduler, **kwargs):
    if self.step == 0:                                  # 累积满才更新
        if self.max_norm > 0.0:
            ...clip_grad_norm_(model.parameters(), self.max_norm)   # ① 梯度裁剪
        
        optimizer.step()                                # ② 真正更新参数 ← 你问的这行
        optimizer.zero_grad()                           # ③ 清零梯度,准备下一轮累积
        if scheduler:
            scheduler.step()                            # ④ 更新学习率

完整训练步顺序:

bash 复制代码
forward → loss → backward → [grad_norm 监控] → clip_grad_norm_ → optimizer.step() → zero_grad → scheduler.step()
                              ↑                       ↑                ↑
                          算梯度                  限制梯度大小       用梯度更新参数

注意三个关键点

  1. if self.step == 0 守卫 :梯度累积下,只有累满 accumulated_gradient 次 backward 才执行 step------这就是"多次累加 + 一次更新"的实现
  2. 顺序:clip 在 step 之前------必须先把梯度限制住,再用梯度更新
  3. step 后立即 zero_grad :清空 p.grad,否则下一轮 backward 还会在旧梯度上累加,相当于把已用过的梯度重复用

三、与 FSDP 的关系

在 FSDP 下,每个 rank 只持有参数的一个 shard,对应的 p.grad 也是分片的。optimizer.step() 在 FSDP 中:

  • 每个 rank 只更新自己的 param shard------不需要跨 rank 通信,因为梯度已在 backward 阶段通过 reduce-scatter 分到了对应 shard 所在 rank
  • 优化器状态(Adam 的 m, v)也是分片存储的,与 param shard 一一对应------这是 FSDP 相对 DDP 节省显存的关键之一(DDP 下每个 rank 都要存完整的 m, v)

所以 optimizer.step() 在 FSDP 上是个"本地操作",性能开销与单卡基本一致。

四、step() vs backward() 的角色对比

操作 改变什么 输出
loss.backward() 计算并累加 p.grad 不改 p.data
optimizer.step() 修改 p.data(真正更新模型);更新优化器内部状态(m, v 等) 不改 p.grad
optimizer.zero_grad() p.grad 清零 准备下一轮累积

可以说:backward 是"算账",step 是"扣款" 。前者只是在 .grad 里写下"应该这样调整",后者才把这个调整真正落到模型权重上。

optimizer.step() 是按优化算法(这里是 AdamW)读取累积在 p.grad 里的梯度并真正更新模型参数 p.data 的步骤------这是整个训练循环里唯一改动模型权重的地方;前面的 backward 只是算梯度,没有 step 模型不会变。

相关推荐
>ᴗoಣ2 小时前
COSER: Coordinating LLM-Based Persona Simulation of Established Roles
人工智能·深度学习
云和数据.ChenGuang2 小时前
openEuler下NLP模型的部署和推理
人工智能·深度学习·机器学习·自然语言处理·数据挖掘·边缘计算
人工智能培训3 小时前
数字孪生建模常用方式有哪些?
人工智能·深度学习·机器学习·容器·知识图谱
轻刀快马3 小时前
跨越“拟人”的最后一道天堑:大模型强化学习(RLHF/RLAIF)底层原理解析
人工智能·深度学习·机器学习
大江东去浪淘尽千古风流人物3 小时前
【KV-Tracker】Transformer 实时位姿跟踪:KV-Cache 加速多视图几何网络达 27FPS
网络·深度学习·transformer·slam·位姿估计·kv-cache
lqqjuly4 小时前
推荐系统技术解析(Recommendation Systems)
深度学习·推荐算法
老鱼说AI4 小时前
统计学习方法第八章:Boosting
人工智能·深度学习·神经网络·机器学习·学习方法·集成学习·boosting
钓了猫的鱼儿4 小时前
基于深度学习+AI的无人机森林火灾目标检测与预警系统(Python源码+数据集+UI可视化界面+YOLOv11训练结果)
人工智能·深度学习·无人机
无负今日_tq5 小时前
【无标题】
人工智能·深度学习·条纹