Day 41 早停策略和模型权重的保存

一、早停策略(Early Stopping)

1. 核心问题:为什么需要早停?

深度学习模型训练时,随着 epoch 增加,模型在训练集 上的误差会持续下降,但在验证集 上的误差会先下降(模型学习到泛化能力),后上升(模型开始过拟合训练数据)。早停的本质是:在验证集性能达到峰值时停止训练,避免模型继续学习训练集的噪声,从而保留泛化能力最强的模型状态。

2. 核心原理与关键参数

早停的核心逻辑是 "监控验证集指标,当指标不再提升时停止训练",需明确 3 个关键参数(缺一不可):

补充细节:
  • "指标改善" 的定义 :默认是 "严格优于历史最优",但实际中会设置min_delta(最小改善幅度),例如min_delta=0.001:只有当指标变化超过 0.001 时,才认为是 "改善",避免因微小波动误判。
  • 恢复最优权重 :早停时,模型的最后一个 epoch 权重可能不是最优的(因为patience期间指标已下降),因此需要在训练中实时保存验证集最优的权重,早停后加载该权重。

3. 常见实现方式

早停通常通过框架自带的回调函数(Callback)实现,无需手动编写逻辑,主流框架(TensorFlow/Keras、PyTorch Lightning)均支持:

(1)TensorFlow/Keras 实现

Keras 内置EarlyStopping回调函数,直接传入训练的callbacks列表即可:

复制代码
from tensorflow.keras.callbacks import EarlyStopping

# 定义早停策略
early_stopping = EarlyStopping(
    monitor='val_loss',        # 监控验证集损失
    patience=5,                # 连续5个epoch无改善则停止
    min_delta=0.0001,          # 最小改善幅度(避免微小波动)
    mode='min',                # 损失越小越好
    restore_best_weights=True  # 早停后恢复验证集最优的权重(关键!)
)

# 训练时传入callbacks
model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),  # 必须有验证集,否则早停无意义
    epochs=100,  # 最大epoch数(早停会提前终止)
    batch_size=32,
    callbacks=[early_stopping]  # 加入早停回调
)
(2)PyTorch 实现(需手动逻辑或用 Lightning)

PyTorch 原生无内置早停,需手动记录验证集指标并判断,或使用PyTorch LightningEarlyStopping

复制代码
# PyTorch Lightning 实现(推荐,简洁高效)
from pytorch_lightning.callbacks import EarlyStopping

# 定义早停策略
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    min_delta=0.0001,
    mode='min',
    restore_best_weights=True
)

# 训练时传入callbacks
trainer = Trainer(callbacks=[early_stopping], max_epochs=100)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
(3)PyTorch 原生手动实现(了解逻辑)
复制代码
import torch

# 初始化参数
best_val_loss = float('inf')
patience = 5
current_patience = 0
max_epochs = 100

for epoch in range(max_epochs):
    # 训练步骤
    model.train()
    train_loss = train_one_epoch(model, train_loader)
    
    # 验证步骤
    model.eval()
    with torch.no_grad():
        val_loss = val_one_epoch(model, val_loader)
    
    # 早停判断
    if val_loss < best_val_loss - 0.0001:  # 满足最小改善幅度
        best_val_loss = val_loss
        current_patience = 0  # 重置耐心值
        # 保存最优权重(见下文"模型权重保存")
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        current_patience += 1
        if current_patience >= patience:
            print(f"早停触发:epoch {epoch+1},验证损失无改善")
            break  # 停止训练

4. 早停的注意事项

2. 保存的核心内容

深度学习模型的 "权重" 本质是模型中可学习的参数(如卷积核、全连接层的权重矩阵)

  • 必须有独立验证集:验证集不能与训练集重叠,否则无法反映泛化能力(早停会失效)。

  • 避免监控训练集指标 :若监控loss(训练集损失),早停会永远不触发(训练损失持续下降),导致过拟合。

  • restore_best_weights的重要性 :若不设置为True,早停后模型会保留 "最后一个 epoch" 的权重(可能已过拟合),而非 "验证集最优" 的权重。

  • patience的选择 :根据任务调整,简单任务(如 MNIST 分类)可设3-5,复杂任务(如 CNN 图像分割)可设10-20(避免因指标波动误停)

    二、模型权重保存(Model Checkpointing)

    1. 核心目的:为什么要保存权重?

  • 保留最优模型:训练过程中验证集性能最好的权重(用于最终部署)。

  • 断点续训:训练中断(如服务器宕机、手动停止)后,可加载中间权重继续训练,无需从头开始。

  • 复现实验:保存权重便于后续复现结果、微调模型。

3. 主流框架实现

(1)TensorFlow/Keras 保存权重

Keras 提供ModelCheckpoint回调函数,可与早停搭配,自动保存最优权重

复制代码
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# 定义权重保存回调(保存验证集最优权重)
checkpoint = ModelCheckpoint(
    filepath='best_model_keras.h5',  # 保存路径(.h5格式)
    monitor='val_loss',              # 与早停监控同一指标
    mode='min',
    save_best_only=True,             # 只保存最优模型(关键)
    save_weights_only=False,         # False:保存整个模型(结构+权重);True:仅保存权重
    verbose=1                        # 保存时打印日志
)

# 搭配早停(注意:早停的restore_best_weights可省略,直接加载checkpoint文件)
early_stopping = EarlyStopping(monitor='val_loss', patience=5, mode='min')

# 训练时传入两个回调
model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=100,
    callbacks=[early_stopping, checkpoint]
)

# 加载权重(后续使用)
model.load_weights('best_model_keras.h5')  # 仅加载权重(需先定义相同结构的模型)
# 或加载整个模型(无需提前定义结构)
from tensorflow.keras.models import load_model
loaded_model = load_model('best_model_keras.h5')
2)PyTorch 保存权重

PyTorch 中常用torch.save()保存,torch.load()加载,需注意 "模型结构与权重匹配":

复制代码
import torch
import torch.nn as nn

# 1. 定义模型结构(示例)
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)
    def forward(self, x):
        return self.fc(x)

model = SimpleModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 2. 保存权重(三种常见场景)
## 场景1:仅保存最优权重(State Dict,推荐部署)
torch.save(model.state_dict(), 'best_model_pytorch.pth')

## 场景2:保存断点(用于续训,包含权重+优化器+epoch)
checkpoint = {
    'epoch': 20,                  # 当前epoch
    'model_state_dict': model.state_dict(),  # 模型权重
    'optimizer_state_dict': optimizer.state_dict(),  # 优化器状态(学习率等)
    'val_loss': 0.123,            # 当前验证损失
}
torch.save(checkpoint, 'checkpoint_pytorch.pth')

# 3. 加载权重
## 场景1:加载仅权重(需先定义模型结构)
loaded_model = SimpleModel()  # 必须先实例化相同结构的模型
loaded_model.load_state_dict(torch.load('best_model_pytorch.pth'))
loaded_model.eval()  # 部署前需切换到评估模式(禁用Dropout、BatchNorm更新)

## 场景2:加载断点(续训)
checkpoint = torch.load('checkpoint_pytorch.pth')
loaded_model = SimpleModel()
loaded_optimizer = torch.optim.Adam(loaded_model.parameters(), lr=1e-3)

loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # 从下一个epoch继续训练
best_val_loss = checkpoint['val_loss']

# 继续训练
for epoch in range(start_epoch, 100):
    train_one_epoch(loaded_model, loaded_optimizer, train_loader)
    # ...
(3)PyTorch Lightning 保存权重

Lightning 内置ModelCheckpoint回调,与早停无缝搭配:

复制代码
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# 定义权重保存回调
checkpoint = ModelCheckpoint(
    dirpath='./checkpoints/',  # 保存目录
    filename='best-model-{epoch:02d}-{val_loss:.4f}',  # 文件名(包含epoch和损失)
    monitor='val_loss',
    mode='min',
    save_best_only=True,  # 只保存最优模型
    save_weights_only=False,  # 保存整个模型(LightningModule)
)

# 搭配早停
early_stopping = EarlyStopping(monitor='val_loss', patience=5, mode='min')

# 训练
trainer = Trainer(
    callbacks=[early_stopping, checkpoint],
    max_epochs=100,
    default_root_dir='./logs/'
)
trainer.fit(model, train_loader, val_loader)

# 加载最优模型
from pytorch_lightning import Trainer
loaded_model = SimpleModel.load_from_checkpoint(checkpoint.best_model_path)

@浙大疏锦行

相关推荐
龙腾AI白云2 小时前
DNN案例一步步构建深层神经网络(二)三、深层神经网络
人工智能·神经网络
洛阳泰山2 小时前
快速上手 MaxKB4J:开源企业级 Agentic 工作流系统在 Sealos 上的完整部署指南
java·人工智能·后端
爱打代码的小林2 小时前
机器学习(决策树)
人工智能·决策树·机器学习
光羽隹衡2 小时前
机器学习——决策树
人工智能·决策树·机器学习
roman_日积跬步-终至千里2 小时前
【计算机视觉(17)】语义理解-训练神经网络2_优化器_正则化_超参数
人工智能·神经网络·计算机视觉
档案宝档案管理2 小时前
电子会计档案管理系统:档案宝如何发挥会计档案的价值?
大数据·数据库·人工智能·档案·档案管理
世岩清上2 小时前
AI绘就文化新画卷:数字化保护留存历史瑰宝,普惠创作绽放艺术繁花
人工智能
像风一样自由20202 小时前
从GAN到WGAN-GP:生成对抗网络的进化之路与实战详解
人工智能·神经网络·生成对抗网络
性感博主在线瞎搞2 小时前
【神经网络】超参调优策略(二):Batch Normalization批量归一化
人工智能·神经网络·机器学习·batch·批次正规化