PyTorch MultiStepLR:指定间隔学习率衰减的原理、API、参数详解、实战

与 MultiStepLR指定间隔学习率衰减 类似的 StepLR等间隔学习率衰减,链接如下:

PyTorch StepLR:等间隔学习率衰减的原理与实战

文章目录

  • [1、基本介绍 & API](#1、基本介绍 & API)
  • [2、scheduler 属性 / 方法](#2、scheduler 属性 / 方法)
  • [3、代码 & 学习率趋势图:](#3、代码 & 学习率趋势图:)

1、基本介绍 & API

"指定间隔学习率衰减" 在 PyTorch 中,这通常指的是:

torch.optim.lr_scheduler.MultiStepLR


📌 什么是"指定间隔学习率衰减"?

StepLR 每隔固定 epoch 衰减不同,"指定间隔学习率衰减"允许你在预设的、任意的 epoch 节点上衰减学习率。这些节点不需要等距,完全由你指定。

  • 例如:在第 30、80、120 个 epoch 时衰减学习率。
  • 这在经典论文(如 ResNet、ImageNet 训练)中非常常见。

因此,它也被称为 "多步衰减"(Multi-Step Decay)


🔧 PyTorch API:MultiStepLR

构造函数

python 复制代码
torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones,    # 【 milestones   n.里程碑 】
    gamma=0.1,
    last_epoch=-1,
    verbose=False
)

📘 参数详解

参数 类型 说明
optimizer Optimizer 绑定的优化器(必需)
milestones list of int 关键参数 :学习率衰减发生的 epoch 列表,必须严格递增 (如 [30, 80]
gamma float 衰减系数,每次衰减时 lr = lr * gamma(默认 0.1
last_epoch int 上一个 epoch 编号,用于恢复训练(默认 -1
verbose bool 是否打印学习率更新日志(PyTorch ≥1.9)

⚠️ 注意:milestones 中的 epoch 是从 0 开始计数 的,且调度器会在 该 epoch 结束后 应用衰减。


🌰 举个例子

python 复制代码
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = MultiStepLR(optimizer, milestones=[10, 25, 40], gamma=0.5)

学习率变化如下:

Epoch 范围 学习率
0 -- 9 0.1
10 -- 24 0.05
25 -- 39 0.025
40+ 0.0125

✅ 衰减发生在 epoch=10、25、40 的训练结束后,从下一个 epoch 开始使用新学习率。


📈 可视化效果(对比 StepLR)

  • StepLR(step_size=10):衰减点为 10, 20, 30, 40, 50...(等距)
  • MultiStepLR(milestones=[10, 25, 40]):衰减点为 10, 25, 40(不等距,按需设置)

这种灵活性使得你可以:

  • 在验证损失 plateau 时手动设定衰减点;
  • 复现经典论文的训练策略;
  • 在训练后期更精细地控制收敛。

✅ 使用示例(完整)

python 复制代码
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import matplotlib.pyplot as plt

# 虚拟模型参数
param = torch.tensor(0.0, requires_grad=True)
optimizer = optim.SGD([param], lr=1.0)

# 在第 5、12、18 个 epoch 衰减学习率
scheduler = MultiStepLR(optimizer, milestones=[5, 12, 18], gamma=0.3)

lr_list = []
for epoch in range(25):
    # 模拟训练步骤
    optimizer.zero_grad()
    loss = (param - 1) ** 2
    loss.backward()
    optimizer.step()
    
    # 记录当前学习率
    lr_list.append(scheduler.get_last_lr()[0])
    
    # 更新调度器
    scheduler.step()

# 绘图
plt.plot(range(25), lr_list, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('MultiStepLR: 指定间隔学习率衰减')
plt.grid(True)
plt.show()

输出曲线将显示在 epoch=5、12、18 处出现阶梯式下降


StepLR vs MultiStepLR

特性 StepLR MultiStepLR
衰减间隔 固定(等距) 自定义(不等距)
控制粒度
典型用途 简单实验、基线 论文复现、精细调优
参数 step_size milestones(列表)
灵活性

💡 实际上,StepLRMultiStepLR 的特例:
StepLR(step_size=T)MultiStepLR(milestones=[T, 2T, 3T, ...])


⚠️ 注意事项

  1. milestones 必须递增,否则会报错。
  2. 不要把 milestones 设得太大(超过总 epoch 数),否则不会触发衰减。
  3. StepLR 一样,必须每个 epoch 调用一次 scheduler.step()
  4. 调用顺序:先 optimizer.step(),再 scheduler.step()

📚 经典应用场景

  • ResNet 论文(ImageNet):

    python 复制代码
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1)
  • CIFAR-10 训练 :常在 [80, 120][100, 150] 衰减。


✅ 总结

"指定间隔学习率衰减" = MultiStepLR

它让你可以在任意预设的 epoch 节点 上衰减学习率,比 StepLR 更灵活,是深度学习训练中的标准工具之一。

当知道模型大概在哪些阶段会"卡住"或需要更小的学习率时,MultiStepLR 就是你的好帮手!

2、scheduler 属性 / 方法

下面专门详细介绍 MultiStepLR 创建的 scheduler 对象 所具备的常用属性和方法,全部基于 PyTorch 官方实现(适用于 PyTorch ≥1.4 的主流版本)。


📌 一、核心属性(只读)

属性 类型 说明 示例值
scheduler.optimizer torch.optim.Optimizer 绑定的优化器对象 <torch.optim.SGD object>
scheduler.last_epoch int 调度器内部记录的"已完成的 epoch 数" (即已调用 step() 的次数) 初始为 -1,每调用一次 step() 自增 1 第 0 次 step 后:last_epoch = 0
scheduler.base_lrs List[float] 初始化时从 optimizer 中保存的原始学习率列表 (每个参数组一个 lr,通常长度为 1) [0.01]
scheduler.milestones List[int] 用户传入的衰减里程碑列表 ⚠️ 必须严格递增(如 [10, 30, 60] [10, 40, 80]
scheduler.gamma float 学习率衰减系数 在每个 milestone 处执行:lr = lr * gamma 0.5

✅ 这些属性都可以直接访问,例如:

python 复制代码
print(scheduler.milestones)  # [10, 40, 80]
print(scheduler.gamma)       # 0.5

📌 二、核心方法

方法 返回值 说明 使用示例
scheduler.step(epoch=None) None 推进调度器一步: - 无参:自动将 last_epoch += 1,并更新 lr - 有参(如 epoch=25):强制设置 last_epoch = epoch,然后更新 lr python<br>scheduler.step() # 标准用法<br>scheduler.step(epoch=50) # 跳转到第 50 轮(慎用)
scheduler.get_last_lr() List[float] 返回上一次 step() 后生效的学习率 (即当前正在使用的 lr) ✅ 这是获取当前 lr 的推荐方式 python<br>current_lr = scheduler.get_last_lr()[0]<br>print(f"LR: {current_lr}")
scheduler.state_dict() Dict[str, Any] 返回调度器当前状态字典,包含: - 'last_epoch' - 其他内部状态(如有) python<br>ckpt = {'sched': scheduler.state_dict()}<br>torch.save(ckpt, 'model.pth')
scheduler.load_state_dict(state_dict) None 从字典恢复调度器状态(用于断点续训) python<br>scheduler.load_state_dict(ckpt['sched'])

❌ 不推荐使用的方法

方法 问题 替代方案
scheduler.get_lr() 在旧版本中用于计算"下一轮 lr",但行为不一致且易出错 PyTorch ≥1.4 已弃用其外部用途 改用 get_last_lr()

⚠️ 即使你能调用 get_lr(),它返回的是"即将应用"的 lr(在 step() 内部使用),不要在训练循环中依赖它


🧪 实际使用示例

python 复制代码
import torch
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

# 初始化
w = torch.tensor([0.0], requires_grad=True)
optimizer = optim.SGD([w], lr=0.1)
scheduler = MultiStepLR(optimizer, milestones=[10, 25, 40], gamma=0.5)

# 查看属性
print("初始 base_lrs:", scheduler.base_lrs)        # [0.1]
print("milestones:", scheduler.milestones)         # [10, 25, 40]
print("gamma:", scheduler.gamma)                   # 0.5

# 模拟训练
for epoch in range(1, 51):
    optimizer.zero_grad()
    loss = (w - 5) ** 2
    loss.backward()
    optimizer.step()

    # 获取当前学习率(推荐方式)
    current_lr = scheduler.get_last_lr()[0]
    print(f"Epoch {epoch}: LR = {current_lr:.6f}")

    # 更新调度器
    scheduler.step()

# 保存状态
state = scheduler.state_dict()
print("\n保存的状态:", state)  # 包含 last_epoch 等

📝 小结:MultiStepLR scheduler 的关键点

  • 衰减只发生在 milestones 列表中的 epoch 结束后
  • last_epoch 是调度器的"计数器",决定是否触发衰减;
  • get_last_lr() 是获取当前 lr 的唯一可靠方式
  • state_dict() + load_state_dict() 支持完整断点续训
  • 所有属性均为只读(不建议修改)

这些属性和方法足以满足你在训练监控、日志记录、可视化和恢复训练等所有常见需求。

3、代码 & 学习率趋势图:

python 复制代码
import torch
from torch import optim
import matplotlib.pyplot as plt

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # ←←← 关键!放在最前面(解决报错)
from pylab import mpl
mpl.rcParams["font.sans-serif"] = ["SimHei"]    # 设置显示中文字体
mpl.rcParams["axes.unicode_minus"] = False      # 设置正常显示符号

w = torch.tensor(data=[0.0], requires_grad=True, dtype=torch.float32)

optimizer = optim.SGD(params=[w], lr=0.01)

# 【 milestones   n.里程碑 】
# gamma=0.5: lr = lr * 0.5
scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[10, 40, 80], gamma=0.5)

lr_list = []

# 训练循环
epochs = 100
batch_size = 20
for epoch in range(1, epochs + 1):
    print(f'第 {epoch} 个 epoch 训练: ')

    # 假装使用 batch
    for batch in range(batch_size):
        optimizer.zero_grad()     # 1. 清零梯度
        loss = (w - 5) ** 2  # 只在第10,20,...步计算梯度
        loss.backward()           # 4. 反向传播, 计算梯度
        optimizer.step()          # 5. 更新参数

    print(f'梯度: {w.grad.item()}')  # 只有一个梯度, 直接打印值就行了
    print(f'跟新后的权重: {w.item()}')

    # get_last_lr() 返回一个 list,即使只有一个参数组,也会返回 [0.01], 取出这个值就行
    # 否则绘图时可能会出错(比如 plt.plot(...) 不知道如何处理嵌套列表)
    # 记录当前学习率(取第一个)
    lr_list.append(scheduler.get_last_lr()[0])

    # 更新调度器
    scheduler.step()

plt.style.use('fivethirtyeight')
plt.figure(figsize=(13, 10))
plt.xlabel('epoch')
plt.ylabel('学习率')
plt.plot(range(1, epochs + 1), lr_list)
plt.title('MultiStepLR: 指定间隔学习率衰减')
plt.show()

# 第 1 个 epoch 训练: 
# 梯度: -6.812325954437256
# 跟新后的权重: 1.661960244178772
# 第 2 个 epoch 训练: 
# 梯度: -4.547963619232178
# 跟新后的权重: 2.7714977264404297
# 第 3 个 epoch 训练: 
# 梯度: -3.036256790161133
# 跟新后的权重: 3.5122342109680176
# 第 4 个 epoch 训练: 
# 梯度: -2.0270285606384277
# 跟新后的权重: 4.006755828857422
# 第 5 个 epoch 训练: 
# 梯度: -1.3532600402832031
# 跟新后的权重: 4.336902618408203
# 第 6 个 epoch 训练: 
# 梯度: -0.903447151184082
# 跟新后的权重: 4.557311058044434
# 第 7 个 epoch 训练: 
# 梯度: -0.6031484603881836
# 跟新后的权重: 4.7044572830200195
# 第 8 个 epoch 训练: 
# 梯度: -0.4026670455932617
# 跟新后的权重: 4.8026933670043945
# 第 9 个 epoch 训练: 
# 梯度: -0.26882171630859375
# 跟新后的权重: 4.868277549743652
# 第 10 个 epoch 训练: 
# 梯度: -0.17946720123291016
# 跟新后的权重: 4.9120612144470215
# ...
# 第 95 个 epoch 训练: 
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
# 第 96 个 epoch 训练: 
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
# 第 97 个 epoch 训练: 
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
# 第 98 个 epoch 训练: 
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
# 第 99 个 epoch 训练: 
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
# 第 100 个 epoch 训练: 
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
相关推荐
摸鱼仙人~2 小时前
大语言模型微调中的数据分布不均与长尾任务优化策略
人工智能·深度学习·机器学习
Wishell20152 小时前
日拱一卒之Python与matlab的内存读取区别
pytorch
AI即插即用2 小时前
即插即用系列 | CVPR 2024 RMT:既要全局感受野,又要 CNN 的局部性?一种拥有显式空间先验的线性 Transformer
人工智能·深度学习·神经网络·目标检测·计算机视觉·cnn·transformer
渡我白衣2 小时前
导论:什么是机器学习?——破除迷思,建立全景地图
人工智能·深度学习·神经网络·目标检测·microsoft·机器学习·自然语言处理
smile_Iris3 小时前
Day 45 简单CNN
人工智能·深度学习·cnn
渡我白衣3 小时前
计算机组成原理(8):各种码的作用详解
c++·人工智能·深度学习·神经网络·其他·机器学习
此处不留情3 小时前
从零构建智能水果识别系统:数据模块深度解析
人工智能·pytorch
小龙报4 小时前
【算法通关指南:算法基础篇 】双指针专题:1.唯一的雪花 2.逛画展 3.字符串 4.丢手绢
c语言·数据结构·c++·人工智能·深度学习·算法·信息与通信
万俟淋曦4 小时前
【论文速递】2025年第39周(Sep-21-27)(Robotics/Embodied AI/LLM)
人工智能·深度学习·机器学习·机器人·大模型·论文·具身智能