Day 41 早停策略和模型权重的保存

@浙大疏锦行

一、早停策略核心原理

1. 为什么需要早停?

  • 过拟合:模型训练轮次过多时,会 "死记" 训练集噪声,验证集指标先升后降;
  • 效率提升:避免无效的训练轮次(比如验证集损失已最优,继续训练只是浪费资源);
  • 无需手动调训练轮次:自动找到 "最优训练轮数"。

2. 早停关键参数

参数 作用
patience 容忍多少轮验证集指标无提升(如patience=5:连续 5 轮没提升就停止)
min_delta 指标的 "最小提升幅度"(过滤微小波动,如min_delta=1e-4:提升 < 0.0001 视为无提升)
monitor 监控的指标(如val_loss:越小越好;val_acc:越大越好)
verbose 是否打印早停日志(如 "验证损失未提升,剩余 patience:4")

3. 模型权重保存的两种方式

保存类型 用途 保存内容 示例路径
最优模型权重 最终推理 / 部署 model.state_dict() best_model.pth
完整 Checkpoint 断点续训 + 早停恢复 模型 + 优化器 + epoch + 早停状态 + 指标 checkpoint.pth

二、核心逻辑解释

1. 早停类的核心设计

  • __call__ 方法:每轮验证后调用,简化调用逻辑(early_stopper(avg_val_loss, model, optimizer, epoch));
  • 区分监控指标:自动适配val_loss(越小越好)和val_acc(越大越好);
  • 双保存机制:
    • _save_best_model:仅保存模型权重(轻量,用于最终推理);
    • _save_checkpoint:保存完整状态(模型 + 优化器 + epoch + 早停计数器),用于续训。

2. 训练流程关键步骤

  1. 初始化早停 :指定监控指标(如val_loss)、patience 等参数;
  2. 续训检查:若有 checkpoint,自动加载并恢复训练状态;
  3. 每轮验证后调用早停:判断是否提升,提升则保存最优模型,未提升则累加计数器;
  4. 触发早停则终止训练:避免无效训练,直接加载最优模型。

3. 最优模型加载与推理

训练完成后,通过model.load_state_dict(torch.load('best_credit_model.pth'))加载最优模型,用于后续推理:

复制代码
# 最优模型推理示例
def infer(model, X_test):
    model.eval()
    with torch.no_grad():
        X_test_tensor = torch.from_numpy(X_test).to(DEVICE, dtype=torch.float32)
        preds = model(X_test_tensor)
        preds = (preds > 0.5).int().cpu().numpy()
    return preds

# 加载最优模型
best_model = CreditDefaultMLP(input_dim=20).to(DEVICE)
best_model.load_state_dict(torch.load('best_credit_model.pth'))

# 推理
X_test = np.random.randn(10, 20)  # 10个测试样本
preds = infer(best_model, X_test)
print("推理结果:", preds)

三、模型保存的 3 种核心形式

保存方式 保存内容 优点 缺点 适用场景
仅保存state_dict(推荐) model.state_dict()(模型权重) 轻量、灵活、版本兼容性好 需重建模型结构 推理部署、最优模型保存
保存 Checkpoint 模型 + 优化器 + 训练状态(epoch/loss) 支持断点续训、恢复完整状态 体积较大 长时训练、断点续训
保存完整模型 整个model对象 无需重建结构,加载简单 兼容性差(版本 / 设备) 临时测试、小模型

1. 仅保存模型权重(state_dict

state_dict是 PyTorch 模型的「参数字典」(key = 参数名,value = 参数张量),仅保存可训练参数(权重 / 偏置),不保存模型结构,是最灵活、轻量的方式。

(1)保存代码

复制代码
# 训练完成后,保存最优模型的权重(示例:早停后的最优模型)
torch.save(model.state_dict(), "best_credit_model.pth")

# 命名规范(推荐):含关键参数,便于区分
# torch.save(model.state_dict(), "best_mlp_input20_hidden128_64.pth")

(2)加载代码(核心:先重建结构,再加载权重)

加载时必须先重建和保存时完全一致的模型结构,否则会因参数不匹配报错:

复制代码
import torch
from your_model_file import CreditDefaultMLP  # 导入你的MLP模型类

# 设备配置
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Step1:重建模型结构(必须和保存时一致!)
# 示例:保存时input_dim=20,hidden_dims=[128,64],加载时参数必须完全相同
model = CreditDefaultMLP(input_dim=20, hidden_dims=[128, 64], dropout=0.2).to(DEVICE)

# Step2:加载权重(适配设备)
# 场景1:GPU保存 → CPU加载
model.load_state_dict(torch.load("best_credit_model.pth", map_location=torch.device("cpu")))

# 场景2:CPU保存 → GPU加载
model.load_state_dict(torch.load("best_credit_model.pth", map_location="cuda:0"))

# 场景3:默认加载(和保存时同设备)
model.load_state_dict(torch.load("best_credit_model.pth"))

# Step3:推理前必做!切换到评估模式(关闭Dropout/BatchNorm训练行为)
model.eval()

(3)推理示例(加载后使用)

复制代码
# 模拟测试数据(20维特征,和预处理一致)
X_test = np.random.randn(10, 20)  # 10个测试样本
X_test_tensor = torch.from_numpy(X_test).to(DEVICE, dtype=torch.float32)

# 关闭梯度计算(省显存、加速)
with torch.no_grad():
    preds = model(X_test_tensor)
    preds = (preds > 0.5).int().cpu().numpy()  # 转换为0/1标签(信用违约:1,未违约:0)

print("信用违约预测结果:", preds)

2. 保存 Checkpoint(断点续训专用)

Checkpoint 是「训练快照」,除模型权重外,还保存优化器状态、训练轮次、损失等,用于训练中断后续训(如服务器宕机、手动停止)。

(1)保存代码

复制代码
# 训练过程中保存Checkpoint(示例:每5轮保存/验证损失最优时保存)
checkpoint = {
    "model_state_dict": model.state_dict(),       # 模型权重
    "optimizer_state_dict": optimizer.state_dict(),  # 优化器状态(如Adam的动量)
    "epoch": epoch,                               # 当前训练轮次
    "train_loss": avg_train_loss,                 # 当前训练损失
    "val_loss": avg_val_loss,                     # 当前验证损失
    "best_val_loss": best_val_loss,               # 最优验证损失
    "early_stop_counter": early_stop_counter      # 早停计数器(可选)
}
torch.save(checkpoint, "credit_checkpoint.pth")

(2)加载 + 续训代码

复制代码
# Step1:重建模型和优化器结构(必须和保存时一致!)
model = CreditDefaultMLP(input_dim=20).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # 优化器参数也需一致

# Step2:加载Checkpoint
checkpoint = torch.load("credit_checkpoint.pth", map_location=DEVICE)

# Step3:恢复状态
model.load_state_dict(checkpoint["model_state_dict"])          # 恢复模型权重
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  # 恢复优化器状态
start_epoch = checkpoint["epoch"] + 1                          # 从下一轮开始续训
best_val_loss = checkpoint["best_val_loss"]                    # 恢复最优损失
early_stop_counter = checkpoint.get("early_stop_counter", 0)   # 恢复早停计数器

# Step4:继续训练
for epoch in range(start_epoch, 100):  # 从断点到目标轮次
    train_one_epoch(model, optimizer, train_loader)  # 你的训练函数
    val_loss = evaluate(model, val_loader)            # 你的验证函数
    # ... 早停逻辑 ...

3. 保存完整模型

直接保存整个模型对象,加载时无需重建结构,但兼容性极差(PyTorch 版本、Python 环境、设备变化都可能导致加载失败)。

(1)保存代码

复制代码
# 不推荐!仅临时测试用
torch.save(model, "full_credit_model.pth")

(2)加载代码

复制代码
# 直接加载,无需重建结构
model = torch.load("full_credit_model.pth")

# 推理前仍需切换评估模式
model.eval()

作业

复制代码
import torch
import numpy as np
from your_model_file import CreditDefaultMLP
from your_early_stop_file import EarlyStopping

# 设备/种子配置
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
SEED = 42
torch.manual_seed(SEED)

# ===================== 1. 训练+早停+保存 =====================
def train_and_save():
    # 初始化模型/优化器/早停
    model = CreditDefaultMLP(input_dim=20).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    early_stopper = EarlyStopping(patience=5, best_model_path="best_credit_model.pth")

    # 模拟训练(替换为你的真实训练逻辑)
    num_epochs = 100
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        # 训练轮次(省略)
        avg_train_loss = 0.1  # 模拟训练损失
        avg_val_loss = 0.08   # 模拟验证损失

        # 早停判断+保存最优模型
        early_stopper(avg_val_loss, model, optimizer, epoch)
        if early_stopper.early_stop:
            break

    # 保存Checkpoint(续训用)
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "best_val_loss": early_stopper.best_score
    }
    torch.save(checkpoint, "credit_checkpoint.pth")
    print("训练完成:最优模型+Checkpoint已保存")

# ===================== 2. 加载模型推理 =====================
def load_and_infer():
    # 重建模型结构
    model = CreditDefaultMLP(input_dim=20).to(DEVICE)
    # 加载最优权重
    model.load_state_dict(torch.load("best_credit_model.pth", map_location=DEVICE))
    # 评估模式
    model.eval()

    # 推理
    X_test = np.random.randn(10, 20)  # 10个测试样本
    X_test_tensor = torch.from_numpy(X_test).to(DEVICE, dtype=torch.float32)
    with torch.no_grad():
        preds = model(X_test_tensor)
        preds = (preds > 0.5).int().cpu().numpy()
    print("推理结果:", preds)

# ===================== 3. 加载Checkpoint续训 =====================
def load_checkpoint_resume():
    # 重建模型/优化器
    model = CreditDefaultMLP(input_dim=20).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # 加载Checkpoint
    checkpoint = torch.load("credit_checkpoint.pth", map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    best_val_loss = checkpoint["best_val_loss"]

    # 继续训练
    for epoch in range(start_epoch, 100):
        # 续训逻辑(省略)
        print(f"续训Epoch:{epoch}")

# ===================== 执行 =====================
if __name__ == "__main__":
    train_and_save()       # 训练+保存
    load_and_infer()       # 加载最优模型推理
    # load_checkpoint_resume()  # 续训(按需执行)
相关推荐
傅里叶的耶8 小时前
C++ Primer Plus(第6版):第四章 复合类型
开发语言·c++
MediaTea8 小时前
Python:接口隔离原则(ISP)
开发语言·网络·python·接口隔离原则
Clarence Liu8 小时前
Golang slice 深度原理与面试指南
开发语言·后端·golang
遇印记8 小时前
java期末复习(构造方法和成员方法,重写和重载)
java·开发语言·学习
weixin_307779138 小时前
Jenkins声明式流水线权威指南:从Model API基础到高级实践
开发语言·ci/cd·自动化·jenkins·etl
Aevget8 小时前
DevExtreme JS & ASP.NET Core v25.2预览 - DataGrid/TreeList全新升级
开发语言·javascript·asp.net·界面控件·ui开发·devextreme
破烂pan8 小时前
Elasticsearch 8.x + Python 官方客户端实战教程
python·elasticsearch
海涛高软8 小时前
Qt菜单项切换主界面
开发语言·qt
码界奇点8 小时前
基于Golang与Vue3的全栈博客系统设计与实现
开发语言·后端·golang·车载系统·毕业设计·源代码管理