与 MultiStepLR指定间隔学习率衰减 类似的 StepLR等间隔学习率衰减,链接如下:
文章目录
- [1、基本介绍 & API](#1、基本介绍 & API)
- [2、scheduler 属性 / 方法](#2、scheduler 属性 / 方法)
- [3、代码 & 学习率趋势图:](#3、代码 & 学习率趋势图:)
1、基本介绍 & API
"指定间隔学习率衰减" 在 PyTorch 中,这通常指的是:
torch.optim.lr_scheduler.MultiStepLR
📌 什么是"指定间隔学习率衰减"?
与 StepLR 每隔固定 epoch 衰减不同,"指定间隔学习率衰减"允许你在预设的、任意的 epoch 节点上衰减学习率。这些节点不需要等距,完全由你指定。
- 例如:在第 30、80、120 个 epoch 时衰减学习率。
- 这在经典论文(如 ResNet、ImageNet 训练)中非常常见。
因此,它也被称为 "多步衰减"(Multi-Step Decay)。
🔧 PyTorch API:MultiStepLR
构造函数
python
torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones, # 【 milestones n.里程碑 】
gamma=0.1,
last_epoch=-1,
verbose=False
)
📘 参数详解
| 参数 | 类型 | 说明 |
|---|---|---|
optimizer |
Optimizer |
绑定的优化器(必需) |
milestones |
list of int |
关键参数 :学习率衰减发生的 epoch 列表,必须严格递增 (如 [30, 80]) |
gamma |
float |
衰减系数,每次衰减时 lr = lr * gamma(默认 0.1) |
last_epoch |
int |
上一个 epoch 编号,用于恢复训练(默认 -1) |
verbose |
bool |
是否打印学习率更新日志(PyTorch ≥1.9) |
⚠️ 注意:
milestones中的 epoch 是从 0 开始计数 的,且调度器会在 该 epoch 结束后 应用衰减。
🌰 举个例子
python
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = MultiStepLR(optimizer, milestones=[10, 25, 40], gamma=0.5)
学习率变化如下:
| Epoch 范围 | 学习率 |
|---|---|
| 0 -- 9 | 0.1 |
| 10 -- 24 | 0.05 |
| 25 -- 39 | 0.025 |
| 40+ | 0.0125 |
✅ 衰减发生在 epoch=10、25、40 的训练结束后,从下一个 epoch 开始使用新学习率。
📈 可视化效果(对比 StepLR)
StepLR(step_size=10):衰减点为 10, 20, 30, 40, 50...(等距)MultiStepLR(milestones=[10, 25, 40]):衰减点为 10, 25, 40(不等距,按需设置)
这种灵活性使得你可以:
- 在验证损失 plateau 时手动设定衰减点;
- 复现经典论文的训练策略;
- 在训练后期更精细地控制收敛。
✅ 使用示例(完整)
python
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import matplotlib.pyplot as plt
# 虚拟模型参数
param = torch.tensor(0.0, requires_grad=True)
optimizer = optim.SGD([param], lr=1.0)
# 在第 5、12、18 个 epoch 衰减学习率
scheduler = MultiStepLR(optimizer, milestones=[5, 12, 18], gamma=0.3)
lr_list = []
for epoch in range(25):
# 模拟训练步骤
optimizer.zero_grad()
loss = (param - 1) ** 2
loss.backward()
optimizer.step()
# 记录当前学习率
lr_list.append(scheduler.get_last_lr()[0])
# 更新调度器
scheduler.step()
# 绘图
plt.plot(range(25), lr_list, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('MultiStepLR: 指定间隔学习率衰减')
plt.grid(True)
plt.show()
输出曲线将显示在 epoch=5、12、18 处出现阶梯式下降。
StepLR vs MultiStepLR
| 特性 | StepLR |
MultiStepLR |
|---|---|---|
| 衰减间隔 | 固定(等距) | 自定义(不等距) |
| 控制粒度 | 粗 | 细 |
| 典型用途 | 简单实验、基线 | 论文复现、精细调优 |
| 参数 | step_size |
milestones(列表) |
| 灵活性 | 低 | 高 |
💡 实际上,
StepLR是MultiStepLR的特例:
StepLR(step_size=T)≈MultiStepLR(milestones=[T, 2T, 3T, ...])
⚠️ 注意事项
milestones必须递增,否则会报错。- 不要把
milestones设得太大(超过总 epoch 数),否则不会触发衰减。 - 和
StepLR一样,必须每个 epoch 调用一次scheduler.step()。 - 调用顺序:先
optimizer.step(),再scheduler.step()。
📚 经典应用场景
-
ResNet 论文(ImageNet):
pythonscheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1) -
CIFAR-10 训练 :常在
[80, 120]或[100, 150]衰减。
✅ 总结
"指定间隔学习率衰减" =
MultiStepLR它让你可以在任意预设的 epoch 节点 上衰减学习率,比
StepLR更灵活,是深度学习训练中的标准工具之一。
当知道模型大概在哪些阶段会"卡住"或需要更小的学习率时,MultiStepLR 就是你的好帮手!
2、scheduler 属性 / 方法
下面专门详细介绍 MultiStepLR 创建的 scheduler 对象 所具备的常用属性和方法,全部基于 PyTorch 官方实现(适用于 PyTorch ≥1.4 的主流版本)。
📌 一、核心属性(只读)
| 属性 | 类型 | 说明 | 示例值 |
|---|---|---|---|
scheduler.optimizer |
torch.optim.Optimizer |
绑定的优化器对象 | <torch.optim.SGD object> |
scheduler.last_epoch |
int |
调度器内部记录的"已完成的 epoch 数" (即已调用 step() 的次数) 初始为 -1,每调用一次 step() 自增 1 |
第 0 次 step 后:last_epoch = 0 |
scheduler.base_lrs |
List[float] |
初始化时从 optimizer 中保存的原始学习率列表 (每个参数组一个 lr,通常长度为 1) |
[0.01] |
scheduler.milestones |
List[int] |
用户传入的衰减里程碑列表 ⚠️ 必须严格递增(如 [10, 30, 60]) |
[10, 40, 80] |
scheduler.gamma |
float |
学习率衰减系数 在每个 milestone 处执行:lr = lr * gamma |
0.5 |
✅ 这些属性都可以直接访问,例如:
pythonprint(scheduler.milestones) # [10, 40, 80] print(scheduler.gamma) # 0.5
📌 二、核心方法
| 方法 | 返回值 | 说明 | 使用示例 |
|---|---|---|---|
scheduler.step(epoch=None) |
None |
推进调度器一步: - 无参:自动将 last_epoch += 1,并更新 lr - 有参(如 epoch=25):强制设置 last_epoch = epoch,然后更新 lr |
python<br>scheduler.step() # 标准用法<br>scheduler.step(epoch=50) # 跳转到第 50 轮(慎用) |
scheduler.get_last_lr() |
List[float] |
返回上一次 step() 后生效的学习率 (即当前正在使用的 lr) ✅ 这是获取当前 lr 的推荐方式 |
python<br>current_lr = scheduler.get_last_lr()[0]<br>print(f"LR: {current_lr}") |
scheduler.state_dict() |
Dict[str, Any] |
返回调度器当前状态字典,包含: - 'last_epoch' - 其他内部状态(如有) |
python<br>ckpt = {'sched': scheduler.state_dict()}<br>torch.save(ckpt, 'model.pth') |
scheduler.load_state_dict(state_dict) |
None |
从字典恢复调度器状态(用于断点续训) | python<br>scheduler.load_state_dict(ckpt['sched']) |
❌ 不推荐使用的方法
| 方法 | 问题 | 替代方案 |
|---|---|---|
scheduler.get_lr() |
在旧版本中用于计算"下一轮 lr",但行为不一致且易出错 PyTorch ≥1.4 已弃用其外部用途 | 改用 get_last_lr() |
⚠️ 即使你能调用
get_lr(),它返回的是"即将应用"的 lr(在step()内部使用),不要在训练循环中依赖它。
🧪 实际使用示例
python
import torch
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
# 初始化
w = torch.tensor([0.0], requires_grad=True)
optimizer = optim.SGD([w], lr=0.1)
scheduler = MultiStepLR(optimizer, milestones=[10, 25, 40], gamma=0.5)
# 查看属性
print("初始 base_lrs:", scheduler.base_lrs) # [0.1]
print("milestones:", scheduler.milestones) # [10, 25, 40]
print("gamma:", scheduler.gamma) # 0.5
# 模拟训练
for epoch in range(1, 51):
optimizer.zero_grad()
loss = (w - 5) ** 2
loss.backward()
optimizer.step()
# 获取当前学习率(推荐方式)
current_lr = scheduler.get_last_lr()[0]
print(f"Epoch {epoch}: LR = {current_lr:.6f}")
# 更新调度器
scheduler.step()
# 保存状态
state = scheduler.state_dict()
print("\n保存的状态:", state) # 包含 last_epoch 等
📝 小结:MultiStepLR scheduler 的关键点
- 衰减只发生在
milestones列表中的 epoch 结束后; last_epoch是调度器的"计数器",决定是否触发衰减;get_last_lr()是获取当前 lr 的唯一可靠方式;state_dict()+load_state_dict()支持完整断点续训;- 所有属性均为只读(不建议修改)。
这些属性和方法足以满足你在训练监控、日志记录、可视化和恢复训练等所有常见需求。
3、代码 & 学习率趋势图:
python
import torch
from torch import optim
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # ←←← 关键!放在最前面(解决报错)
from pylab import mpl
mpl.rcParams["font.sans-serif"] = ["SimHei"] # 设置显示中文字体
mpl.rcParams["axes.unicode_minus"] = False # 设置正常显示符号
w = torch.tensor(data=[0.0], requires_grad=True, dtype=torch.float32)
optimizer = optim.SGD(params=[w], lr=0.01)
# 【 milestones n.里程碑 】
# gamma=0.5: lr = lr * 0.5
scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[10, 40, 80], gamma=0.5)
lr_list = []
# 训练循环
epochs = 100
batch_size = 20
for epoch in range(1, epochs + 1):
print(f'第 {epoch} 个 epoch 训练: ')
# 假装使用 batch
for batch in range(batch_size):
optimizer.zero_grad() # 1. 清零梯度
loss = (w - 5) ** 2 # 只在第10,20,...步计算梯度
loss.backward() # 4. 反向传播, 计算梯度
optimizer.step() # 5. 更新参数
print(f'梯度: {w.grad.item()}') # 只有一个梯度, 直接打印值就行了
print(f'跟新后的权重: {w.item()}')
# get_last_lr() 返回一个 list,即使只有一个参数组,也会返回 [0.01], 取出这个值就行
# 否则绘图时可能会出错(比如 plt.plot(...) 不知道如何处理嵌套列表)
# 记录当前学习率(取第一个)
lr_list.append(scheduler.get_last_lr()[0])
# 更新调度器
scheduler.step()
plt.style.use('fivethirtyeight')
plt.figure(figsize=(13, 10))
plt.xlabel('epoch')
plt.ylabel('学习率')
plt.plot(range(1, epochs + 1), lr_list)
plt.title('MultiStepLR: 指定间隔学习率衰减')
plt.show()
# 第 1 个 epoch 训练:
# 梯度: -6.812325954437256
# 跟新后的权重: 1.661960244178772
# 第 2 个 epoch 训练:
# 梯度: -4.547963619232178
# 跟新后的权重: 2.7714977264404297
# 第 3 个 epoch 训练:
# 梯度: -3.036256790161133
# 跟新后的权重: 3.5122342109680176
# 第 4 个 epoch 训练:
# 梯度: -2.0270285606384277
# 跟新后的权重: 4.006755828857422
# 第 5 个 epoch 训练:
# 梯度: -1.3532600402832031
# 跟新后的权重: 4.336902618408203
# 第 6 个 epoch 训练:
# 梯度: -0.903447151184082
# 跟新后的权重: 4.557311058044434
# 第 7 个 epoch 训练:
# 梯度: -0.6031484603881836
# 跟新后的权重: 4.7044572830200195
# 第 8 个 epoch 训练:
# 梯度: -0.4026670455932617
# 跟新后的权重: 4.8026933670043945
# 第 9 个 epoch 训练:
# 梯度: -0.26882171630859375
# 跟新后的权重: 4.868277549743652
# 第 10 个 epoch 训练:
# 梯度: -0.17946720123291016
# 跟新后的权重: 4.9120612144470215
# ...
# 第 95 个 epoch 训练:
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
# 第 96 个 epoch 训练:
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
# 第 97 个 epoch 训练:
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
# 第 98 个 epoch 训练:
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
# 第 99 个 epoch 训练:
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
# 第 100 个 epoch 训练:
# 梯度: -9.5367431640625e-05
# 跟新后的权重: 4.99995231628418
