关于模型学习策略

Warm-up + Cosine Decay 学习率策略

适用场景:Transformer/ViT、AdamW、大 batch、混合精度、分布式训练;以及所有"训练开局容易炸、后期难收敛"的任务。


1. 为什么学习率策略这么关键?

训练深度网络,本质是在高维空间里做优化。学习率 lrlrlr 就像"每一步迈多大"。

  • lrlrlr 太大:像开车不踩刹车,容易冲出路面(loss 爆炸、NaN、发散)
  • lrlrlr 太小:像推石头推不动,训练很慢,还可能卡在不好位置

更现实的是:同一场训练,不同阶段需要不同的"步幅"

  • 初期:参数随机、梯度噪声大,最危险 → 需要"先稳住"
  • 中期:需要更大步探索 → 需要"敢走"
  • 后期:需要小步精修 → 需要"慢慢收敛"

Warm-up + Cosine Decay 就是一个非常工程化的答案:
前面预热稳住,中后期余弦平滑衰减,收敛更稳、调参更省心。


2. 这套策略在做什么?

把训练总过程(按 step)分两段:

  1. Warm-up(预热) :lrlrlr 从很小逐渐升到 lrmaxlr_{max}lrmax
  2. Cosine Decay(余弦衰减) :lrlrlr 从 lrmaxlr_{max}lrmax 平滑降到 lrminlr_{min}lrmin

你可以把它理解成:

  • Warm-up:起步缓慢加速,避免"冷启动"就猛踩油门
  • Cosine:先慢降、再更慢降、最后贴地走,让收敛过程非常平滑

3. 为什么需要 Warm-up?

训练最开始,经常出现这些"初期不稳定因素":

  • 参数随机初始化,输出分布乱 → 梯度方向不稳定
  • Adam/AdamW 的一阶矩、二阶矩统计还没"热起来"
  • 混合精度下初期数值更脆弱
  • 大 batch / 分布式下等效更新更激进

如果上来就用一个较大的 lrlrlr,很容易出现:

  • loss 突然飙升
  • 梯度爆炸
  • NaN / Inf
  • 训练抖动很久才稳定,甚至直接发散

Warm-up 的核心价值:让模型从"危险冷启动"过渡到"稳定可学习"状态。


4. 为什么 Cosine Decay 好用?

常见衰减方式有 Step Decay、Exponential、Cosine 等。

Step Decay(每隔几轮乘个 0.1)的问题是:

学习率会"突然跳",loss 往往也会跟着抖一下。

Cosine 的特点是:全程连续、可导、非常平滑 。训练表现通常更稳定。

并且它的节奏很符合优化直觉:

  • 前期学习率保持较高更久 → 更充分探索
  • 后期学习率极其平缓 → 便于精细收敛

5. 数学解释

下面全部按 step(iteration) 来定义,更精确、更不容易踩坑。

设:

  • 总训练步数为 TTT
  • warm-up 步数为 TwT_wTw
  • 最大学习率 lrmaxlr_{max}lrmax
  • 最小学习率 lrminlr_{min}lrmin
  • warm-up 起始学习率 lrstartlr_{start}lrstart(常取 0 或者 0.1⋅lrmax0.1\cdot lr_{max}0.1⋅lrmax)

5.1 Warm-up:线性预热(最常用)

当 0≤t<Tw0 \le t < T_w0≤t<Tw 时:

lr(t)=lrstart+(lrmax−lrstart)⋅tTw lr(t)=lr_{start}+(lr_{max}-lr_{start})\cdot \frac{t}{T_w} lr(t)=lrstart+(lrmax−lrstart)⋅Twt

解释一下:

  • t=0t=0t=0 时 lr(0)=lrstartlr(0)=lr_{start}lr(0)=lrstart
  • t=Twt=T_wt=Tw 附近 lr≈lrmaxlr \approx lr_{max}lr≈lrmax

5.2 Cosine Decay:余弦衰减到 lrminlr_{min}lrmin

从 warm-up 结束开始进入余弦阶段。令

  • τ=t−Tw\tau = t - T_wτ=t−Tw
  • Tc=T−TwT_c = T - T_wTc=T−Tw(余弦阶段总步数)

当 Tw≤t≤TT_w \le t \le TTw≤t≤T 时:

lr(t)=lrmin+12(lrmax−lrmin)(1+cos⁡(π⋅τTc)) lr(t)=lr_{min}+\frac{1}{2}(lr_{max}-lr_{min})\left(1+\cos\left(\pi \cdot \frac{\tau}{T_c}\right)\right) lr(t)=lrmin+21(lrmax−lrmin)(1+cos(π⋅Tcτ))

边界非常漂亮(建议你在博客里强调):

  • 当 t=Twt=T_wt=Tw:τ=0\tau=0τ=0,cos⁡(0)=1\cos(0)=1cos(0)=1,所以 lr=lrmaxlr=lr_{max}lr=lrmax
  • 当 t=Tt=Tt=T:τ=Tc\tau=T_cτ=Tc,cos⁡(π)=−1\cos(\pi)=-1cos(π)=−1,所以 lr=lrminlr=lr_{min}lr=lrmin

6. 最重要的工程细节:到底用 epoch 还是 step?

强烈建议:按 step 调度学习率

因为你的真实优化更新发生在每次 optimizer.step(),而不是每个 epoch。并且下列因素会让"每个 epoch 的 step 数"不再直观:

  • drop_last
  • 分布式(global batch 变化)
  • 梯度累积(很多 step 才更新一次)
  • 动态 padding / 变长数据
  • 数据集最后一个 batch 不满

6.1 计算总步数 TTT

设:

  • 训练轮数 EEE
  • 每个 epoch 的 iteration 数为 III
  • 梯度累积步数为 AAA(每 AAA 次反传才 optimizer.step() 一次)

那么优化器更新次数(也就是 scheduler 的 step 次数)大致为:

T≈E⋅⌈IA⌉ T \approx E \cdot \left\lceil \frac{I}{A} \right\rceil T≈E⋅⌈AI⌉

建议:在代码里让计数器只在 optimizer.step() 时加 1,这样就不会算错。


7. 参数怎么选?

7.1 warm-up 占比 Tw/TT_w/TTw/T

经验范围(不是死规则):

  • 一般 CNN:11%\sim5%1
  • Transformer/ViT、大 batch:55%\sim10%5
  • 或固定步数:500、1000、2000(看训练总步数数量级)

7.2 lrminlr_{min}lrmin 怎么定?

常见选择:

  • lrmin=0lr_{min}=0lrmin=0
  • 或 lrmin=0.01⋅lrmaxlr_{min}=0.01\cdot lr_{max}lrmin=0.01⋅lrmax
  • 或固定如 1e−61e{-6}1e−6、3e−63e{-6}3e−6(AdamW 很常见)

7.3 lrstartlr_{start}lrstart 怎么定?

常见:

  • lrstart=0lr_{start}=0lrstart=0
  • 或 lrstart=0.1⋅lrmaxlr_{start}=0.1\cdot lr_{max}lrstart=0.1⋅lrmax(避免一开始完全为 0)

8. 一份"写不错"的 PyTorch 实现

这里给两种写法:手写函数官方 scheduler 拼装

更推荐手写函数:更容易理解,也更不容易踩 step/epoch 的坑。


8.1 写法 A:手写 Warm-up + Cosine(清晰、可控、最好讲)

python 复制代码
import math

def warmup_cosine_lr(step, total_steps, warmup_steps, lr_max, lr_min=0.0, lr_start=0.0):
    """
    step: 当前第几个 optimizer step(从0开始)
    total_steps: 总 optimizer step 数(T)
    warmup_steps: warm-up 的 optimizer step 数(Tw)
    """
    if warmup_steps < 1:
        warmup_steps = 1
    if total_steps <= warmup_steps:
        total_steps = warmup_steps + 1

    if step < warmup_steps:
        # Linear Warmup
        return lr_start + (lr_max - lr_start) * step / warmup_steps

    # Cosine Decay
    tau = step - warmup_steps
    Tc = total_steps - warmup_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * tau / Tc))

训练循环(重点:只在 optimizer.step() 时推进 step):

python 复制代码
global_step = 0
for epoch in range(num_epochs):
    for it, batch in enumerate(dataloader):
        loss = model(batch)
        loss.backward()

        if (it + 1) % grad_accum == 0:
            lr = warmup_cosine_lr(
                step=global_step,
                total_steps=total_steps,
                warmup_steps=warmup_steps,
                lr_max=lr_max,
                lr_min=lr_min,
                lr_start=lr_start
            )
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1

如果做了梯度累积,那么 warmup/cosine 的 step 应该是"优化器更新次数",而不是"反传次数"。


8.2 写法 B:官方 scheduler 组合

python 复制代码
import torch
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

optimizer = torch.optim.AdamW(model.parameters(), lr=lr_max)

warmup = LinearLR(
    optimizer,
    start_factor=lr_start / lr_max if lr_max > 0 else 0.0,
    end_factor=1.0,
    total_iters=warmup_steps
)

cosine = CosineAnnealingLR(
    optimizer,
    T_max=total_steps - warmup_steps,
    eta_min=lr_min
)

scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[warmup_steps]
)

global_step = 0
for epoch in range(num_epochs):
    for it, batch in enumerate(dataloader):
        loss = model(batch)
        loss.backward()

        if (it + 1) % grad_accum == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            scheduler.step()  # 注意:按 optimizer step 调用
            global_step += 1

总结:这套策略为什么值得默认优先?

Warm-up + Cosine Decay 是一种非常"现代训练管线友好"的学习率策略:

  • Warm-up:显著减少训练初期发散与抖动
  • Cosine:平滑衰减,后期更容易稳定收敛
  • 对 AdamW / Transformer / 大 batch 特别友好
  • 参数少、调参成本低,容易作为 baseline
相关推荐
牛客企业服务2 小时前
2026年AI面试布局:破解规模化招聘的效率困局
人工智能·面试·职场和发展
gorgeous(๑>؂<๑)2 小时前
【北理工-AAAI26】MODA:首个无人机多光谱目标检测数据集
人工智能·目标检测·计算机视觉·目标跟踪·无人机
嵌入式的飞鱼2 小时前
SD NAND 焊接避坑指南:LGA-8 封装手工焊接技巧与常见错误
人工智能·stm32·单片机·嵌入式硬件·tf卡
serve the people2 小时前
tensorflow 零基础吃透:RaggedTensor 与其他张量类型的转换
人工智能·tensorflow·neo4j
serve the people3 小时前
tensorflow 核心解析:tf.RaggedTensorSpec 作用与参数说明
人工智能·python·tensorflow
yzx9910133 小时前
当AI握住方向盘:智能驾驶如何重新定义出行未来
人工智能
Sui_Network3 小时前
备受期待的 POP 射击游戏 XOCIETY 正式在 Epic Games Store 开启体验
人工智能·游戏·rpc·区块链·量子计算·graphql
漫长的~以后3 小时前
GPT-5.2深度拆解:多档位自适应架构如何重塑AI推理效率
人工智能·gpt·架构
爱笑的眼睛114 小时前
自动机器学习组件的深度解析:超越AutoML框架的底层架构
java·人工智能·python·ai