一、早停策略核心原理
1. 为什么需要早停?
- 过拟合:模型训练轮次过多时,会 "死记" 训练集噪声,验证集指标先升后降;
- 效率提升:避免无效的训练轮次(比如验证集损失已最优,继续训练只是浪费资源);
- 无需手动调训练轮次:自动找到 "最优训练轮数"。
2. 早停关键参数
| 参数 | 作用 |
|---|---|
patience |
容忍多少轮验证集指标无提升(如patience=5:连续 5 轮没提升就停止) |
min_delta |
指标的 "最小提升幅度"(过滤微小波动,如min_delta=1e-4:提升 < 0.0001 视为无提升) |
monitor |
监控的指标(如val_loss:越小越好;val_acc:越大越好) |
verbose |
是否打印早停日志(如 "验证损失未提升,剩余 patience:4") |
3. 模型权重保存的两种方式
| 保存类型 | 用途 | 保存内容 | 示例路径 |
|---|---|---|---|
| 最优模型权重 | 最终推理 / 部署 | 仅model.state_dict() |
best_model.pth |
| 完整 Checkpoint | 断点续训 + 早停恢复 | 模型 + 优化器 + epoch + 早停状态 + 指标 | checkpoint.pth |
二、核心逻辑解释
1. 早停类的核心设计
__call__方法:每轮验证后调用,简化调用逻辑(early_stopper(avg_val_loss, model, optimizer, epoch));- 区分监控指标:自动适配
val_loss(越小越好)和val_acc(越大越好); - 双保存机制:
_save_best_model:仅保存模型权重(轻量,用于最终推理);_save_checkpoint:保存完整状态(模型 + 优化器 + epoch + 早停计数器),用于续训。
2. 训练流程关键步骤
- 初始化早停 :指定监控指标(如
val_loss)、patience 等参数; - 续训检查:若有 checkpoint,自动加载并恢复训练状态;
- 每轮验证后调用早停:判断是否提升,提升则保存最优模型,未提升则累加计数器;
- 触发早停则终止训练:避免无效训练,直接加载最优模型。
3. 最优模型加载与推理
训练完成后,通过model.load_state_dict(torch.load('best_credit_model.pth'))加载最优模型,用于后续推理:
# 最优模型推理示例
def infer(model, X_test):
model.eval()
with torch.no_grad():
X_test_tensor = torch.from_numpy(X_test).to(DEVICE, dtype=torch.float32)
preds = model(X_test_tensor)
preds = (preds > 0.5).int().cpu().numpy()
return preds
# 加载最优模型
best_model = CreditDefaultMLP(input_dim=20).to(DEVICE)
best_model.load_state_dict(torch.load('best_credit_model.pth'))
# 推理
X_test = np.random.randn(10, 20) # 10个测试样本
preds = infer(best_model, X_test)
print("推理结果:", preds)
三、模型保存的 3 种核心形式
| 保存方式 | 保存内容 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|---|
仅保存state_dict(推荐) |
model.state_dict()(模型权重) |
轻量、灵活、版本兼容性好 | 需重建模型结构 | 推理部署、最优模型保存 |
| 保存 Checkpoint | 模型 + 优化器 + 训练状态(epoch/loss) | 支持断点续训、恢复完整状态 | 体积较大 | 长时训练、断点续训 |
| 保存完整模型 | 整个model对象 |
无需重建结构,加载简单 | 兼容性差(版本 / 设备) | 临时测试、小模型 |
1. 仅保存模型权重(state_dict)
state_dict是 PyTorch 模型的「参数字典」(key = 参数名,value = 参数张量),仅保存可训练参数(权重 / 偏置),不保存模型结构,是最灵活、轻量的方式。
(1)保存代码
# 训练完成后,保存最优模型的权重(示例:早停后的最优模型)
torch.save(model.state_dict(), "best_credit_model.pth")
# 命名规范(推荐):含关键参数,便于区分
# torch.save(model.state_dict(), "best_mlp_input20_hidden128_64.pth")
(2)加载代码(核心:先重建结构,再加载权重)
加载时必须先重建和保存时完全一致的模型结构,否则会因参数不匹配报错:
import torch
from your_model_file import CreditDefaultMLP # 导入你的MLP模型类
# 设备配置
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Step1:重建模型结构(必须和保存时一致!)
# 示例:保存时input_dim=20,hidden_dims=[128,64],加载时参数必须完全相同
model = CreditDefaultMLP(input_dim=20, hidden_dims=[128, 64], dropout=0.2).to(DEVICE)
# Step2:加载权重(适配设备)
# 场景1:GPU保存 → CPU加载
model.load_state_dict(torch.load("best_credit_model.pth", map_location=torch.device("cpu")))
# 场景2:CPU保存 → GPU加载
model.load_state_dict(torch.load("best_credit_model.pth", map_location="cuda:0"))
# 场景3:默认加载(和保存时同设备)
model.load_state_dict(torch.load("best_credit_model.pth"))
# Step3:推理前必做!切换到评估模式(关闭Dropout/BatchNorm训练行为)
model.eval()
(3)推理示例(加载后使用)
# 模拟测试数据(20维特征,和预处理一致)
X_test = np.random.randn(10, 20) # 10个测试样本
X_test_tensor = torch.from_numpy(X_test).to(DEVICE, dtype=torch.float32)
# 关闭梯度计算(省显存、加速)
with torch.no_grad():
preds = model(X_test_tensor)
preds = (preds > 0.5).int().cpu().numpy() # 转换为0/1标签(信用违约:1,未违约:0)
print("信用违约预测结果:", preds)
2. 保存 Checkpoint(断点续训专用)
Checkpoint 是「训练快照」,除模型权重外,还保存优化器状态、训练轮次、损失等,用于训练中断后续训(如服务器宕机、手动停止)。
(1)保存代码
# 训练过程中保存Checkpoint(示例:每5轮保存/验证损失最优时保存)
checkpoint = {
"model_state_dict": model.state_dict(), # 模型权重
"optimizer_state_dict": optimizer.state_dict(), # 优化器状态(如Adam的动量)
"epoch": epoch, # 当前训练轮次
"train_loss": avg_train_loss, # 当前训练损失
"val_loss": avg_val_loss, # 当前验证损失
"best_val_loss": best_val_loss, # 最优验证损失
"early_stop_counter": early_stop_counter # 早停计数器(可选)
}
torch.save(checkpoint, "credit_checkpoint.pth")
(2)加载 + 续训代码
# Step1:重建模型和优化器结构(必须和保存时一致!)
model = CreditDefaultMLP(input_dim=20).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # 优化器参数也需一致
# Step2:加载Checkpoint
checkpoint = torch.load("credit_checkpoint.pth", map_location=DEVICE)
# Step3:恢复状态
model.load_state_dict(checkpoint["model_state_dict"]) # 恢复模型权重
optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) # 恢复优化器状态
start_epoch = checkpoint["epoch"] + 1 # 从下一轮开始续训
best_val_loss = checkpoint["best_val_loss"] # 恢复最优损失
early_stop_counter = checkpoint.get("early_stop_counter", 0) # 恢复早停计数器
# Step4:继续训练
for epoch in range(start_epoch, 100): # 从断点到目标轮次
train_one_epoch(model, optimizer, train_loader) # 你的训练函数
val_loss = evaluate(model, val_loader) # 你的验证函数
# ... 早停逻辑 ...
3. 保存完整模型
直接保存整个模型对象,加载时无需重建结构,但兼容性极差(PyTorch 版本、Python 环境、设备变化都可能导致加载失败)。
(1)保存代码
# 不推荐!仅临时测试用
torch.save(model, "full_credit_model.pth")
(2)加载代码
# 直接加载,无需重建结构
model = torch.load("full_credit_model.pth")
# 推理前仍需切换评估模式
model.eval()
作业
import torch
import numpy as np
from your_model_file import CreditDefaultMLP
from your_early_stop_file import EarlyStopping
# 设备/种子配置
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
SEED = 42
torch.manual_seed(SEED)
# ===================== 1. 训练+早停+保存 =====================
def train_and_save():
# 初始化模型/优化器/早停
model = CreditDefaultMLP(input_dim=20).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
early_stopper = EarlyStopping(patience=5, best_model_path="best_credit_model.pth")
# 模拟训练(替换为你的真实训练逻辑)
num_epochs = 100
best_val_loss = float('inf')
for epoch in range(num_epochs):
# 训练轮次(省略)
avg_train_loss = 0.1 # 模拟训练损失
avg_val_loss = 0.08 # 模拟验证损失
# 早停判断+保存最优模型
early_stopper(avg_val_loss, model, optimizer, epoch)
if early_stopper.early_stop:
break
# 保存Checkpoint(续训用)
checkpoint = {
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch,
"best_val_loss": early_stopper.best_score
}
torch.save(checkpoint, "credit_checkpoint.pth")
print("训练完成:最优模型+Checkpoint已保存")
# ===================== 2. 加载模型推理 =====================
def load_and_infer():
# 重建模型结构
model = CreditDefaultMLP(input_dim=20).to(DEVICE)
# 加载最优权重
model.load_state_dict(torch.load("best_credit_model.pth", map_location=DEVICE))
# 评估模式
model.eval()
# 推理
X_test = np.random.randn(10, 20) # 10个测试样本
X_test_tensor = torch.from_numpy(X_test).to(DEVICE, dtype=torch.float32)
with torch.no_grad():
preds = model(X_test_tensor)
preds = (preds > 0.5).int().cpu().numpy()
print("推理结果:", preds)
# ===================== 3. 加载Checkpoint续训 =====================
def load_checkpoint_resume():
# 重建模型/优化器
model = CreditDefaultMLP(input_dim=20).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 加载Checkpoint
checkpoint = torch.load("credit_checkpoint.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
best_val_loss = checkpoint["best_val_loss"]
# 继续训练
for epoch in range(start_epoch, 100):
# 续训逻辑(省略)
print(f"续训Epoch:{epoch}")
# ===================== 执行 =====================
if __name__ == "__main__":
train_and_save() # 训练+保存
load_and_infer() # 加载最优模型推理
# load_checkpoint_resume() # 续训(按需执行)