loss.backward() 在做什么
这一行调用的是 PyTorch 的反向传播 机制------核心作用是:从 loss 出发,沿着前向构建的计算图反向走,对所有 requires_grad=True 的参数计算梯度(∂loss / ∂param)并累加到 param.grad 上。
一、背景:前向时 PyTorch 在做什么
当 kd_algorithm.training_step(micro_batch) 执行学生模型前向、算 logits、算 loss 时,PyTorch 自动构建了一张计算图(computation graph):
- 每个张量记录了"我是怎么从哪些张量算出来的"(
grad_fn) - 这张图把
loss和所有可训练参数θ连接起来
例如简化后:
bash
input → [Linear w₁] → h₁ → [Linear w₂] → logits → [CE loss] → loss
↑ ↑ ↑
w₁.grad_fn w₂.grad_fn loss.grad_fn
二、loss.backward() 反向走这张图
1. 计算梯度(链式法则)
按拓扑逆序遍历计算图,对每个节点应用链式法则:
bash
∂loss/∂logits → ∂loss/∂h₁ → ∂loss/∂w₂ → ∂loss/∂w₁ ...
每经过一个算子,就调用其 backward 实现(如 matmul 的 backward 是矩阵乘转置等),把上游传下来的梯度乘上局部雅可比,再传给更上游。
2. 把梯度写到 param.grad
最终对每个叶子参数 p(p.requires_grad=True 且没有 grad_fn):
python
p.grad += ∂loss/∂p # 累加,不是覆盖
注意是 累加 ------这就是为什么需要 optimizer.zero_grad() 在每次 step 后把 grad 清零,否则梯度会一直叠加。
利用了"累加"特性来实现梯度累积 :连续多次 backward() 不调 zero_grad,梯度自然累加,最后一次累积满后才调 optimizer.step() + zero_grad()。
3. 释放计算图
默认情况下,反向走完后 PyTorch 会释放保存的中间激活值 (用于反向计算的 saved tensors),节省显存。再次对同一张图 backward 会报错,除非传 retain_graph=True。
三、在 FSDP 下的特殊行为
普通单卡 loss.backward() 只算梯度。但 FSDP(Fully Sharded Data Parallel),它在 backward 阶段会做更多事:
| 阶段 | FSDP 行为 |
|---|---|
| 前向 | 各 rank 持有 param 的一个 shard;前向时 all-gather 完整参数 → 计算 → 释放完整参数 |
| 反向 | 反向再次 all-gather 参数计算梯度 → reduce-scatter:把每个参数的梯度 reduce(求和/平均),并 scatter 回该 param shard 所在的 rank |
| 结果 | 每个 rank 只在自己负责的 param shard 上得到对应梯度(p.grad),梯度也被分片存储 |
所以 loss.backward() 在 FSDP 上下文里还隐含了跨 rank 通信 :reduce-scatter 把数据并行的梯度聚合并分片,相当于普通 DDP 的 all-reduce,但通信对象是 sharded 梯度。
四、整体作用与定位
backward() 是梯度计算阶段 ------它不更新参数 ,只把"这一步学生模型应该怎么调整"以梯度形式存到 p.grad 里,等优化器 step 时再用。
loss.backward() = 沿计算图反向应用链式法则,对所有可训练参数计算 ∂loss/∂p 并累加到 p.grad;在 FSDP 下还会自动完成跨 rank 的梯度 reduce-scatter,让每个 rank 拿到自己 param shard 对应的梯度。它是"算梯度",参数更新由后续 optimizer.step() 完成。
梯度累积(Gradient Accumulation)的作用与差异
一、什么是梯度累积
核心思想 :把一个"大 batch"拆成 N 个"小 batch(micro_batch)",连续做 N 次前向 + 反向但不更新参数 ,让梯度在 p.grad 上自然累加;累加 N 次后再做一次 optimizer.step()。
python
def backward(self, loss, model, optimizer, **kwargs):
self.step = (self.step + 1) % self.accumulated_gradient
loss = loss / self.accumulated_gradient # ← 先除以 N
loss.backward() # ← 梯度累加到 p.grad
def optimizer_step(self, optimizer, model, scheduler, **kwargs):
if self.step == 0: # ← 累满 N 次才真正更新
...
optimizer.step()
optimizer.zero_grad()
scheduler.step()
累积 N 次 backward → 1 次 optimizer.step。
二、为什么要累积------主要作用
1. 在显存受限下实现"大 batch 训练"(最核心动机)
显存放不下 batch_size=64 怎么办?拆成 8 次 batch_size=8 累积,效果约等于 batch_size=64:
| 方案 | 单步显存峰值 | 等效 batch | 显存友好 |
|---|---|---|---|
| 一次 batch=64 | 高(容易 OOM) | 64 | ✗ |
| 累积 8×8 | 低(只到 batch=8) | 64 | ✓ |
LLM 训练里"大 batch"对收敛稳定性、梯度信噪比、scaling law 都有帮助,但显存往往是瓶颈,累积是绕过显存限制拿到大有效 batch 的标准手段。
2. 匹配分布式并行的语义
每个 prompt batch 生成一批 rollout,被切成多个 micro_batch 喂给学生 actor。让 actor 在内部把这 N 个 micro_batch 累积成一次更新,意味着"一次 rollout → 一次梯度更新"------保持 on-policy 的语义清晰,避免一份 rollout 数据用多次。
三、累积 vs 不累积的具体区别
数学上:梯度等价(在常见条件下)
假设 loss 是样本平均(reduction="mean"):
- 不累积,batch=N:
bash
L = (1/N) Σᵢ Lᵢ
∂L/∂θ = (1/N) Σᵢ ∂Lᵢ/∂θ
- 累积 N 次,每次 micro_batch=1 (带
loss /= N归一化):
bash
每次 backward 累加: (1/N) ∂Lᵢ/∂θ
N 次后 p.grad = Σᵢ (1/N) ∂Lᵢ/∂θ = (1/N) Σᵢ ∂Lᵢ/∂θ ✓
例如配置 train_batch_size=128、micro_train_batch_size=4、world_size=8:
python
num_micro_batches = self.args.train.train_batch_size // self.args.train.micro_train_batch_size
- 数据并行下每 GPU 一次 forward 跑 4 条样本
- 累积步数 = 128 / 4 / 8 = 4(每张卡反向 4 次后再 step 一次)
- 等效全局 batch = 4 × 4 × 8 = 128
梯度累积 = 把"大 batch 一次更新"拆成"小 batch 多次累加 + 最后一次更新",用计算时间换显存空间,数学上(在样本平均损失 + 无 BN 的前提下)与不累积等价。区别主要在显存峰值、吞吐、BN 统计、通信次数和更新频率上;LLM 训练里因为用 LayerNorm,二者基本等价,所以累积是大模型训练框架的默认手段。
梯度累积为什么能节省显存
显存峰值由"单次前向 + 反向时同时驻留的中间激活值(activations)"决定,而梯度累积让每一步只跑一个小 micro_batch,激活值只按小 batch 算 → 峰值显存大幅下降。累加的只是参数级别的梯度(p.grad),它的大小与 batch 无关。
一、训练时显存的几大组成
一次训练步显存占用大致为:
| 组成 | 大小是否随 batch 增长 | 说明 |
|---|---|---|
模型参数 θ |
否 | 与模型大小有关,与 batch 无关 |
梯度 p.grad |
否 | 形状与参数相同,与 batch 无关 |
| 优化器状态(Adam: m, v) | 否 | ~2× 参数大小,与 batch 无关 |
| 前向激活值(activations) | 是,随 batch 线性增长 | 反向计算梯度时必须保留的中间张量 |
| 临时 buffer / workspace | 部分 | 通信缓冲、cuDNN workspace 等 |
LLM 训练里最大的显存开销往往是 activations------因为每层 transformer 的输入、attention 中间结果(Q/K/V、attention scores)都要保存以便反向传播。它的大小约为:
python
activations ≈ batch_size × seq_len × hidden_size × num_layers × const
显存峰值就是发生在反向开始前------所有层的 activations 都还驻留在 GPU 上的那一刻。
二、不累积 vs 累积的显存对比
不累积(一次大 batch=N)
bash
forward (batch=N) → 保存 ~N 倍激活值 ─┐
├─ 显存峰值 ∝ N
backward ← 消耗激活算梯度 ─────┘
optimizer.step
激活值正比于 N,N 大就 OOM。
累积(micro_batch=N/k,累积 k 次)
bash
循环 k 次:
forward (batch=N/k) → 保存 ~(N/k) 倍激活值 ─┐
├─ 每轮峰值 ∝ N/k
backward ← 消耗激活算梯度 ─────────┘
✅ 激活值 backward 后立即释放
✅ p.grad 累加(大小不变)
optimizer.step(累积 k 次后才执行一次)
每轮反向结束后 PyTorch 自动释放 该轮的激活值(默认 retain_graph=False),所以下一轮重新分配的激活值占用同一片显存,不会叠加。
→ 峰值显存 ∝ N/k,是不累积方案的 1/k。
三、为什么"累加梯度"自己不会占额外显存
python
loss.backward() # 在已有的 p.grad 上做 += (in-place)
p.grad是和p形状一致的张量,第一次 backward 就会分配好- 后续 backward 是 in-place 累加到这块已有显存上
- 累积 1 次 vs 累积 100 次,
p.grad占用一模一样
所以"累积"的代价是时间 (多跑 k 次 forward+backward),收益是激活值显存峰值降到 1/k------而梯度本身不变大。
四、一个直观的数字例子
假设:
- 模型 7B 参数,bf16 → 参数 14 GB ,梯度 14 GB ,Adam 状态 28 GB(fp32 m+v)
- 单卡 80GB 显存,去掉以上 56 GB,剩 ~24 GB 给 activations + buffer
- batch_size=64、seq_len=4096 时 activations 估 ~80 GB → OOM
- batch_size=8、seq_len=4096 时 activations 估 ~10 GB → 装得下
不累积要 OOM;累积 8 次 micro_batch=8:
- 每轮峰值 ≈ 56 + 10 = 66 GB ✓ 不 OOM
- 8 次反向后
p.grad仍是 14 GB(不变) - 累积满 8 次再
optimizer.step(),等效 batch=64
五、与 FSDP / 梯度检查点的关系
显存优化技术分三类------梯度累积只是其中之一,常常组合使用:
| 技术 | 削减的部分 | 代价 |
|---|---|---|
| 梯度累积 | activations(按 1/k) | 时间(多次 forward) |
| gradient checkpointing(重计算) | activations(按 ~1/√L 或更多) | 时间(反向时重做部分前向) |
| FSDP(参数/梯度/优化器状态分片) | params + grads + optim states(按 1/world_size) | 通信开销 |
六、为什么"梯度求和"不会爆?
可能有人会担心:累积 k 次梯度,p.grad 会越来越大?
不会 ------p.grad 是数值上累加,张量大小恒定不变(形状和 dtype 不变)。变化的只是里面存的数值,从 0 累加到最终的累计梯度,不消耗额外显存空间。
显存峰值主要被 forward 期间保留的 activations 撑大,而 activations 与 batch 大小成正比;梯度累积通过把大 batch 拆成多个小 micro_batch 串行跑,每轮的 activations 在 backward 后立刻释放,使峰值降到 1/k;累加发生在已分配好的 p.grad 上(in-place),不引入额外显存------所以"用时间换空间"成立。
optimizer.step() 的作用
核心作用 :用 loss.backward() 累积在 p.grad 中的梯度,按优化算法(如 AdamW、SGD)的更新规则实际修改模型参数 p.data 。这是训练循环里唯一真正改动模型权重的一步。
一、做了什么
简化的伪代码(以 SGD 为例):
python
for group in optimizer.param_groups:
lr = group["lr"]
for p in group["params"]:
if p.grad is None:
continue
p.data -= lr * p.grad # 实际更新参数
不同优化器的更新规则不同:
| 优化器 | 更新公式(简化) |
|---|---|
| SGD | p ← p - lr · g |
| SGD+Momentum | v ← μv + g; p ← p - lr · v |
| Adam / AdamW | 维护一阶/二阶矩估计 m, v;p ← p - lr · m̂ / (√v̂ + ε) - lr · wd · p |
本项目用的是 AdamW(fsdp_strategy.py 中通过 create_optimizer 创建),所以 optimizer.step() 在内部:
- 读取
p.grad - 更新一阶矩
m和二阶矩v(保存在optimizer.state[p]) - 做 bias correction
- 计算更新量并写回
p.data - 应用 weight decay(AdamW 把 weight decay 单独从梯度中解耦)
二、在训练循环中的位置
python
def optimizer_step(self, optimizer, model, scheduler, **kwargs):
if self.step == 0: # 累积满才更新
if self.max_norm > 0.0:
...clip_grad_norm_(model.parameters(), self.max_norm) # ① 梯度裁剪
optimizer.step() # ② 真正更新参数 ← 你问的这行
optimizer.zero_grad() # ③ 清零梯度,准备下一轮累积
if scheduler:
scheduler.step() # ④ 更新学习率
完整训练步顺序:
bash
forward → loss → backward → [grad_norm 监控] → clip_grad_norm_ → optimizer.step() → zero_grad → scheduler.step()
↑ ↑ ↑
算梯度 限制梯度大小 用梯度更新参数
注意三个关键点:
if self.step == 0守卫 :梯度累积下,只有累满accumulated_gradient次 backward 才执行 step------这就是"多次累加 + 一次更新"的实现- 顺序:clip 在 step 之前------必须先把梯度限制住,再用梯度更新
- step 后立即
zero_grad:清空p.grad,否则下一轮 backward 还会在旧梯度上累加,相当于把已用过的梯度重复用
三、与 FSDP 的关系
在 FSDP 下,每个 rank 只持有参数的一个 shard,对应的 p.grad 也是分片的。optimizer.step() 在 FSDP 中:
- 每个 rank 只更新自己的 param shard------不需要跨 rank 通信,因为梯度已在 backward 阶段通过 reduce-scatter 分到了对应 shard 所在 rank
- 优化器状态(Adam 的 m, v)也是分片存储的,与 param shard 一一对应------这是 FSDP 相对 DDP 节省显存的关键之一(DDP 下每个 rank 都要存完整的 m, v)
所以 optimizer.step() 在 FSDP 上是个"本地操作",性能开销与单卡基本一致。
四、step() vs backward() 的角色对比
| 操作 | 改变什么 | 输出 |
|---|---|---|
loss.backward() |
计算并累加 p.grad |
不改 p.data |
optimizer.step() |
修改 p.data(真正更新模型);更新优化器内部状态(m, v 等) |
不改 p.grad |
optimizer.zero_grad() |
把 p.grad 清零 |
准备下一轮累积 |
可以说:backward 是"算账",step 是"扣款" 。前者只是在 .grad 里写下"应该这样调整",后者才把这个调整真正落到模型权重上。
optimizer.step() 是按优化算法(这里是 AdamW)读取累积在 p.grad 里的梯度并真正更新模型参数 p.data 的步骤------这是整个训练循环里唯一改动模型权重的地方;前面的 backward 只是算梯度,没有 step 模型不会变。