PyTorch StepLR:等间隔学习率衰减的原理与实战

文章目录

  • 1、基本介绍
  • [2、StepLR - API 介绍](#2、StepLR - API 介绍)
  • [3、`StepLR` 的工作原理](#3、StepLR 的工作原理)
  • [4、scheduler 属性 / 方法](#4、scheduler 属性 / 方法)
  • [5、代码 & 学习率趋势图:](#5、代码 & 学习率趋势图:)

1、基本介绍

"等间隔学习率衰减"(Step Decay 或 Fixed Step Decay)是深度学习中一种常用的学习率调度(learning rate scheduling)策略,属于分段常数衰减(Piecewise Constant Decay)的一种形式。

基本思想

在训练过程中,每隔固定的训练轮数(epoch)或迭代步数(step),将学习率乘以一个衰减因子(通常小于1,如 0.1 或 0.5),从而逐步降低学习率。

公式表示

设初始学习率为 η 0 \eta_0 η0,衰减因子为 γ \gamma γ( 0 < γ < 1 0 < \gamma < 1 0<γ<1),每 T T T 个 epoch 衰减一次,则在第 t t t 个 epoch 的学习率为:

η ( t ) = η 0 ⋅ γ ⌊ t T ⌋ \eta(t) = \eta_0 \cdot \gamma^{\left\lfloor \frac{t}{T} \right\rfloor} η(t)=η0⋅γ⌊Tt⌋

其中 ⌊ ⋅ ⌋ \left\lfloor \cdot \right\rfloor ⌊⋅⌋ 表示向下取整。

举个例子

  • 初始学习率: η 0 = 0.1 \eta_0 = 0.1 η0=0.1
  • 衰减因子: γ = 0.5 \gamma = 0.5 γ=0.5
  • 衰减间隔:每 10 个 epoch

那么:

  • Epoch 0--9:学习率 = 0.1
  • Epoch 10--19:学习率 = 0.05
  • Epoch 20--29:学习率 = 0.025
  • ...以此类推

优点

  • 简单易实现:只需设置几个超参数(初始学习率、衰减间隔、衰减因子)。
  • 有效稳定训练后期:随着训练进行,模型接近收敛时,较小的学习率有助于更精细地调整参数,避免震荡。

缺点

  • 需要手动调参:衰减间隔和衰减因子的选择依赖经验或试错。
  • 不够自适应:不像余弦退火(Cosine Annealing)或 ReduceLROnPlateau 那样根据损失变化动态调整。

实现示例(PyTorch)

python 复制代码
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(num_epochs):
    train(...)
    
    # 每个 epoch 后调用
    scheduler.step()  # 详情在后面: 调度器本身不自动计数,它依赖你每次调用 step() 来推进内部的 epoch 计数器。

总结

"等间隔学习率衰减"是一种简单但实用的学习率调度方法,通过定期按固定比例降低学习率,帮助模型在训练后期更稳定地收敛。虽然不如一些自适应方法灵活,但在很多任务中表现良好,尤其适合初学者或作为基线策略使用。

2、StepLR - API 介绍

📌 torch.optim.lr_scheduler.StepLR

StepLR = Step + LR(learning rate)

这是 PyTorch 中用于实现等间隔学习率衰减(Step Decay)的标准调度器。它会在训练过程中,每隔固定的 epoch 数,将学习率乘以一个衰减因子。

🔧 构造函数

python 复制代码
torch.optim.lr_scheduler.StepLR(    # 【 scheduler   n.调度器 】
    optimizer,
    step_size,
    gamma=0.1,
    last_epoch=-1,
    verbose=False
)

📘 参数详解

  1. optimizer(必需)
  • 类型torch.optim.Optimizer(如 SGD, Adam 等)
  • 作用:绑定要调整学习率的优化器。
  • 说明 :调度器通过修改 optimizer.param_groups 中的 'lr' 字段来动态改变学习率。必须在创建调度器前先定义好优化器。
  1. step_size(必需)
  • 类型int
  • 作用 :指定学习率衰减的间隔周期(单位:epoch)。
  • 说明
    • 每经过 step_size 个 epoch,学习率就会被乘以 gamma
    • 例如 step_size=10 表示:第 0--9 轮用初始学习率,第 10 轮开始衰减,第 20 轮再次衰减,依此类推。
    • 注意:这个"步长"是以 epoch 为单位,不是 batch 或 iteration。
  1. gamma(可选,默认 0.1
  • 类型float
  • 作用 :学习率的衰减系数
  • 说明
    • 每次衰减时,新的学习率 = 当前学习率 × gamma
    • 常见取值:0.1(每次变为 1/10)、0.5(每次减半)。
    • 必须满足 0 < gamma <= 1。若 gamma=1,则学习率永不衰减。
  1. last_epoch(可选,默认 -1
  • 类型int
  • 作用:指定调度器的起始 epoch 编号。
  • 说明
    • 默认为 -1,表示训练从 epoch 0 开始。
    • 如果你在恢复训练(checkpoint),可以设为上次训练结束时的 epoch 编号(如 last_epoch=49),这样调度器会自动计算当前应处的学习率阶段。
    • 调度器内部会根据 last_epoch 初始化状态,确保学习率与训练进度对齐。
  1. verbose(可选,默认 False
  • 类型bool

  • 作用:是否在每次学习率更新时打印日志。

  • 说明

    • 若设为 True,每次调用 scheduler.step() 后,会输出类似:

      复制代码
      Epoch 10: reducing learning rate of group 0 to 1.0000e-02.
    • 仅在 PyTorch 1.9 及以上版本支持。便于调试,但正式训练中通常关闭。


✅ 使用示例

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# 模型和优化器
model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 创建调度器:每 5 个 epoch 衰减一次,衰减为原来的 0.5 倍
scheduler = StepLR(optimizer, step_size=5, gamma=0.5, verbose=True)

# 训练循环
for epoch in range(15):
    # ... 执行训练步骤 ...
    
    # 更新学习率(必须在每个 epoch 结束后调用!)
    scheduler.step() # 详情在后面: 调度器本身不自动计数,它依赖你每次调用 step() 来推进内部的 epoch 计数器。

输出示例(verbose=True)

复制代码
Epoch 5: reducing learning rate of group 0 to 5.0000e-02.
Epoch 10: reducing learning rate of group 0 to 2.5000e-02.

在 PyTorch ≥1.1.0 中,你应该先调用 optimizer.step(),再调用 scheduler.step()

否则,PyTorch 会跳过学习率调度器的第一个值(即初始学习率)。

在一个标准训练循环中,每个 epoch 内部 的典型流程是:

python 复制代码
for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()        # 1️⃣ 先更新参数
    scheduler.step()            # 2️⃣ 再更新学习率(每个 epoch 一次)

⚠️ 注意事项

  • 调用时机scheduler.step() 必须在每个 epoch 结束后调用一次(不是每个 batch!),否则学习率不会按预期衰减。
  • 与训练循环对齐 :如果你使用自定义的 epoch 计数(比如从 checkpoint 恢复),务必正确设置 last_epoch
  • 多参数组支持 :如果 optimizer 包含多个参数组(如不同层用不同学习率),StepLR 会对所有组统一应用相同的衰减规则。

🔗 相关调度器(扩展)

如果你需要非等间隔但预设的衰减点,可以使用:

  • MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
    在指定 epoch(如 30、80)衰减,常用于 ResNet 等经典训练策略。

但严格意义上的"等间隔衰减",首选仍是 StepLR

3、StepLR 的工作原理

虽然 step_size=5 表示"每 5 个 epoch 衰减一次",但 scheduler.step() 仍然需要在每个 epoch 都调用一次

这是因为:调度器本身不自动计数,它依赖你每次调用 step() 来推进内部的 epoch 计数器


🧠 详细解释

🔁 StepLR 的工作原理

  • StepLR 内部维护一个计数器(本质上是 self.last_epoch)。
  • 每次你调用 scheduler.step(),它就认为"又过了一轮训练"(即 +1 epoch)。
  • 然后它检查:当前 epoch 是否是 step_size 的整数倍?如果是,就执行衰减。

所以:

调用次数(epoch) 内部 epoch 计数 是否触发衰减(step_size=5)
第 0 次(epoch=0) 0
第 1 次(epoch=1) 1
... ... ...
第 5 次(epoch=5) 5 ✅ 是(5 % 5 == 0)
第 10 次(epoch=10) 10 ✅ 是

💡 注意:StepLR 的衰减发生在 step_size 个 epoch 结束时 ,也就是从第 step_size 轮开始使用新学习率。


❌ 常见误解

有些人会误以为:

"既然每 5 个 epoch 衰减一次,那我应该每 5 次才调用一次 scheduler.step()。"

这是错误的!如果你这样做:

python 复制代码
for epoch in range(15):
    train(...)
    if epoch % 5 == 0:
        scheduler.step()   # ❌ 错误做法!

那么调度器内部的 epoch 计数器只会在 epoch=0,5,10 时 +1,变成 0→1→2→3...

结果:实际上每 5×5=25 个 epoch 才衰减一次,完全偏离预期!


✅ 正确做法(再次强调)

python 复制代码
for epoch in range(num_epochs):
    train_one_epoch(...)      # 训练当前 epoch
    scheduler.step()          # 每个 epoch 结束后都调用!
  • step_size=5 控制的是衰减频率
  • scheduler.step() 的调用频率必须是每个 epoch 一次,用于推进调度器的内部状态。

🔍 验证小实验

你可以运行这段代码观察学习率变化:

python 复制代码
optimizer = optim.SGD([torch.tensor(0.0)], lr=1.0)
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

for epoch in range(12):
    print(f"Epoch {epoch}: LR = {optimizer.param_groups[0]['lr']:.4f}")
    scheduler.step()

输出:

复制代码
Epoch 0: LR = 1.0000
Epoch 1: LR = 1.0000
Epoch 2: LR = 1.0000
Epoch 3: LR = 1.0000
Epoch 4: LR = 1.0000
Epoch 5: LR = 0.1000   ← 衰减发生!
Epoch 6: LR = 0.1000
...
Epoch 10: LR = 0.0100  ← 再次衰减

可以看到:衰减确实发生在第 5、10 轮之后 ,而 step() 是每轮都调用的。


🔍 对比其他调度器命名逻辑

PyTorch 的调度器命名通常反映其行为特征:

调度器 名称含义
StepLR 阶梯式衰减:每隔固定步长,学习率突降
MultiStepLR 多个指定台阶点衰减
ExponentialLR 学习率按指数函数连续衰减
CosineAnnealingLR 学习率按余弦函数平滑下降
ReduceLROnPlateau 当指标停滞(plateau) 时才衰减

所以,"Step" 强调的是 不连续、阶段性、台阶式的跳跃行为,而非"步骤"(step as in iteration)的意思------尽管它确实通过"每 step_size 个 epoch 执行一次衰减"来实现。


📝 总结

问题 回答
step_size=5 是否意味着每 5 个 epoch 衰减一次? ✅ 是
是否需要每个 epoch 都调用 scheduler.step() ✅ 是!必须每个 epoch 调用一次
如果跳着调用 step() 会怎样? ❌ 调度器计数错乱,衰减时机错误

4、scheduler 属性 / 方法

基于 PyTorch ≥1.4 的主流版本:

属性 / 方法 详细说明 类型 / 返回值 注意事项与使用建议
scheduler.last_epoch 表示调度器内部记录的"已执行的 epoch 数"(即调用 step() 的次数)。初始为 -1,每调用一次 step() 自增 1。 int 已完成的step次数 :调度器内部计数器,记录已调用 step() 方法的次数。从-1开始,第一次调用后变为0。 - 实际对应的是 下一个将要使用的 epoch 编号 。 - 若从 checkpoint 恢复训练,可通过 last_epoch=N 初始化调度器,使其从第 N+1 轮开始计数。
scheduler.get_last_lr() 返回上一次 step() 调用后生效的学习率列表(即当前正在使用的 lr)。 List[float] 推荐方式 获取当前学习率。 - 每个参数组一个 lr(多数情况只有一个元素)。 - PyTorch ≥1.4 引入,替代旧的 get_lr()
scheduler.get_lr() 已弃用/不推荐使用。在旧版本中用于计算下一轮的学习率,但行为不稳定。 List[float] 不要在训练循环中调用 ! - 在 PyTorch 新版本中,此方法仅用于内部计算。 - 使用 get_last_lr() 替代。
scheduler.base_lrs 调度器初始化时从 optimizer 中保存的原始学习率列表(每个参数组一个)。 List[float] - 即使学习率被衰减多次,该值不会改变 。 - 可用于调试或重置学习率。
scheduler.gamma 学习率衰减系数(每次衰减时乘以此值)。 float - 由构造函数传入,默认为 0.1。 - 只读属性(部分版本可修改,但不建议)。
scheduler.step_size 学习率衰减的间隔(单位:epoch)。 int - 例如 step_size=20 表示每 20 个 epoch 衰减一次。 - 只读属性。
scheduler.state_dict() 返回调度器当前状态的字典,包含 last_epoch 等关键信息。 Dict[str, Any] 用于保存训练状态python<br>torch.save({'sched': scheduler.state_dict()}, 'ckpt.pth')<br>
scheduler.load_state_dict(state_dict) 从字典恢复调度器状态(通常与 state_dict() 配对使用)。 None 用于断点续训python<br>scheduler.load_state_dict(ckpt['sched'])<br> - 必须在 optimizer 加载之后调用。
scheduler.step(epoch=None) 推进调度器一步(默认自动递增 last_epoch),或跳转到指定 epoch。 None - 标准用法scheduler.step()(无参),放在每个 epoch 末尾。 - 高级用法scheduler.step(epoch=50) 可手动设置内部 epoch 计数(慎用,一般用于调试或特殊调度逻辑)。

💡 补充说明

  • 为什么有 get_last_lr()get_lr() 两个方法?

    PyTorch 在 1.4 版本重构了调度器 API,明确区分:

    • get_last_lr():返回已经应用的学习率(安全、可靠)。
    • get_lr():原本设计为返回"下一步"的 lr,但因实现混乱被弃用。
  • 如何打印当前学习率?

    python 复制代码
    print(f"Current LR: {scheduler.get_last_lr()[0]:.6f}")
    # 或
    print(f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")

    两者在正常调用 step() 后结果一致,但前者更语义清晰。

  • 多参数组支持

    如果 optimizer 有多个参数组(如不同层设不同 lr),上述所有列表(base_lrs, get_last_lr() 等)长度 >1,按组一一对应。

5、代码 & 学习率趋势图:

python 复制代码
import torch
from torch import optim
import matplotlib.pyplot as plt

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # ←←← 关键!放在最前面(解决报错)
from pylab import mpl
mpl.rcParams["font.sans-serif"] = ["SimHei"]    # 设置显示中文字体
mpl.rcParams["axes.unicode_minus"] = False      # 设置正常显示符号

w = torch.tensor(data=[0.0], requires_grad=True, dtype=torch.float32)

optimizer = optim.SGD(params=[w], lr=0.01)

# 【 scheduler   n.调度器 】
# gamma=0.5: lr = lr * 0.5
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.5)

lr_list = []

# 训练循环
epochs = 100
batch_size = 20
for epoch in range(1, epochs + 1):
    print(f'第 {epoch} 个 epoch 训练: ')

    # 假装使用 batch 
    for batch in range(batch_size):
        optimizer.zero_grad()     # 1. 清零梯度
        loss = (w - 5) ** 2  # 只在第10,20,...步计算梯度
        loss.backward()           # 4. 反向传播, 计算梯度
        optimizer.step()          # 5. 更新参数

    print(f'梯度: {w.grad.item()}')  # 只有一个梯度, 直接打印值就行了
    print(f'跟新后的权重: {w.item()}')

    # get_last_lr() 返回一个 list,即使只有一个参数组,也会返回 [0.01], 取出这个值就行
    # 否则绘图时可能会出错(比如 plt.plot(...) 不知道如何处理嵌套列表)
    # 记录当前学习率(取第一个)
    lr_list.append(scheduler.get_last_lr()[0])

    # 更新调度器
    scheduler.step()

plt.style.use('fivethirtyeight')
plt.figure(figsize=(13, 10))
plt.xlabel('epoch')
plt.ylabel('学习率')
plt.plot(range(1, epochs + 1), lr_list)
plt.title('间隔学习率衰减')
plt.show()

# 第 1 个 epoch 训练:
# 梯度: -6.812325954437256
# 跟新后的权重: 1.661960244178772
# 第 2 个 epoch 训练:
# 梯度: -4.547963619232178
# 跟新后的权重: 2.7714977264404297
# 第 3 个 epoch 训练:
# 梯度: -3.036256790161133
# 跟新后的权重: 3.5122342109680176
# 第 4 个 epoch 训练:
# 梯度: -2.0270285606384277
# 跟新后的权重: 4.006755828857422
# 第 5 个 epoch 训练:
# 梯度: -1.3532600402832031
# 跟新后的权重: 4.336902618408203
# 第 6 个 epoch 训练:
# 梯度: -0.903447151184082
# 跟新后的权重: 4.557311058044434
# ...
# 第 95 个 epoch 训练:
# 梯度: -0.0032291412353515625
# 跟新后的权重: 4.998385429382324
# 第 96 个 epoch 训练:
# 梯度: -0.0032291412353515625
# 跟新后的权重: 4.998385429382324
# 第 97 个 epoch 训练:
# 梯度: -0.0032291412353515625
# 跟新后的权重: 4.998385429382324
# 第 98 个 epoch 训练:
# 梯度: -0.0032291412353515625
# 跟新后的权重: 4.998385429382324
# 第 99 个 epoch 训练:
# 梯度: -0.0032291412353515625
# 跟新后的权重: 4.998385429382324
# 第 100 个 epoch 训练:
# 梯度: -0.0032291412353515625
# 跟新后的权重: 4.998385429382324
相关推荐
Wis4e3 小时前
基于PyTorch的深度学习——迁移学习1
pytorch·深度学习·机器学习
北山小恐龙4 小时前
针对性模型压缩:YOLOv8n安全帽检测模型剪枝方案
人工智能·深度学习·算法·计算机视觉·剪枝
Wis4e4 小时前
基于PyTorch的深度学习——迁移学习2
pytorch·深度学习·迁移学习
从负无穷开始的三次元代码生活4 小时前
深度学习知识点概念速通——人工智能专业考试基础知识点
人工智能·深度学习
BB_CC_DD14 小时前
超简单搭建AI去水印和图像修复算法lama-cleaner二
人工智能·深度学习
高洁0115 小时前
DNN案例一步步构建深层神经网络(二)
人工智能·python·深度学习·算法·机器学习
Coding茶水间15 小时前
基于深度学习的螺栓螺母检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉
AI小怪兽16 小时前
RF-DETR:实时检测Transformer的神经架构搜索,首个突破 60 AP 的实时检测器 | ICLR 2026 in Submission
人工智能·深度学习·yolo·目标检测·架构·transformer
【建模先锋】16 小时前
故障诊断模型讲解:基于1D-CNN、2D-CNN分类模型的详细教程!
人工智能·深度学习·分类·cnn·卷积神经网络·故障诊断·轴承故障诊断