在深度学习的训练过程中,我们经常面临两个核心问题:"训练到什么时候停止?" 和 "训练好的模型怎么存?"。
如果训练时间太短,模型欠拟合;训练时间太长,模型过拟合。手动盯着Loss曲线决定何时停止既累人又不精确。早停策略 (Early Stopping) 就是为了解决这个问题而生的自动化工具。而模型保存则是将我们消耗算力炼出的"丹"(模型参数)持久化存储的关键步骤。
一、 过拟合与监控机制
1.1 什么是过拟合的信号?
在训练过程中,我们通常会观察到以下现象:
- 训练集 Loss:持续下降(因为模型在死记硬背训练数据)。
- 测试集/验证集 Loss:先下降,达到一个最低点后,开始震荡甚至反弹上升。
关键点:当训练集 Loss 下降但测试集 Loss 不再下降(甚至上升)时,就是过拟合的开始。这就是我们应该停止训练的最佳时机。
1.2 如何监控?
我们需要在训练循环中,每隔一定的 Epoch(例如每1个或每100个Epoch),暂停训练模式,切换到评估模式 (model.eval()),计算测试集上的 Loss。
# 伪代码逻辑
for epoch in range(num_epochs):
train(...) # 训练
if epoch % check_interval == 0:
model.eval()
test_loss = validate(...) # 验证
print(f"Train Loss: {train_loss}, Test Loss: {test_loss}")
二、 早停策略 (Early Stopping) 实战
早停策略的核心思想是:给模型几次机会(Patience),如果它在验证集上的表现连续几次都没有提升,那就强制停止。
2.1 核心参数
- best_score**/** best_loss: 记录历史最好的指标。
- patience**(耐心值)**: 允许模型连续多少次没有提升。比如设为 10,意味着即使 Loss 上升了,我也再等你 10 轮,万一后面又降了呢?
- counter: 计数器,记录连续没有提升的次数。
- min_delta: 只有当提升幅度超过这个阈值时,才算作"提升"(防止微小的抖动被误判)。
2.2 代码实现模板
这是一个可以直接套用的标准早停逻辑代码块:
# 初始化早停参数
best_test_loss = float('inf') # 初始最佳Loss设为无穷大
patience = 20 # 耐心值:20轮不降就停
counter = 0 # 计数器
early_stopped = False # 停止标志
for epoch in range(num_epochs):
# ... (训练代码省略) ...
# --- 验证阶段 ---
if (epoch + 1) % 10 == 0: # 假设每10轮验证一次
model.eval()
with torch.no_grad():
outputs = model(X_test)
test_loss = criterion(outputs, y_test)
model.train() # 切回训练模式
current_loss = test_loss.item()
# --- 早停核心逻辑 ---
if current_loss < best_test_loss:
# 情况1:Loss 创新低(表现更好)
best_test_loss = current_loss
counter = 0 # 重置计数器
# 【关键】保存当前最好的模型,防止后面训练这就"烂"了
torch.save(model.state_dict(), 'best_model.pth')
print(f"Epoch {epoch}: New best loss {best_test_loss:.4f}, model saved.")
else:
# 情况2:Loss 没创新低(表现变差或持平)
counter += 1
print(f"Epoch {epoch}: No improvement. Counter {counter}/{patience}")
if counter >= patience:
print("早停触发!停止训练。")
early_stopped = True
break # 跳出训练循环
重要提示:
早停触发后,模型当前的状态通常已经过拟合了(因为最后 patience 轮都在变差)。所以,必须 在训练结束后,重新加载我们中间保存的那个 best_model.pth,那才是真正的最佳模型。
if early_stopped:
model.load_state_dict(torch.load('best_model.pth'))
print("已回滚至最佳模型参数。")
三、 模型的保存与加载
PyTorch 提供了多种保存方式,但在工业界和学术界,只保存参数(state_dict) 是绝对的主流和最佳实践。
3.1 方式一:仅保存模型参数 (推荐) ⭐⭐⭐⭐⭐
这是最轻量级、最灵活的方式。它只保存模型的权重(Tensor数据),不保存模型的类定义。
-
保存:
model.state_dict() 是一个字典,包含所有层的权重
torch.save(model.state_dict(), "model_weights.pth")
-
加载:
需要先实例化模型对象(代码中必须有 class MLP(...) 的定义),然后把参数填进去。
model = MLP() # 1. 先实例化结构
model.load_state_dict(torch.load("model_weights.pth")) # 2. 填充参数
model.eval() # 3. 如果用于推理,记得切到eval模式
3.2 方式二:保存整个模型 (不推荐) ⭐
这种方式会把模型结构和参数打包一起存。
- 保存 :
torch.save(model, "full_model.pth") - 加载 :
model = torch.load("full_model.pth") - 缺点 :它严重依赖代码目录结构。如果你把代码发给别人,或者把模型类定义的 py 文件改了个名字/移了个位置,加载就会直接报错 (
AttributeError)。它是基于 Python 的pickle序列化的,非常脆弱。
3.3 方式三:保存 Checkpoint (断点续训) ⭐⭐⭐⭐
如果你跑一个大模型需要训练几天几夜,你肯定不希望电脑死机后重头再来。这时需要保存所有训练状态。
-
保存:
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), # 优化器也有参数(如动量),必须存!
'loss': loss,
}
torch.save(checkpoint, 'checkpoint.pth') -
加载与恢复:
model = MLP()
optimizer = optim.Adam(model.parameters())checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] # 从断掉的下一轮开始继续训练...
for epoch in range(start_epoch, num_epochs):
...
四、 总结
- 不要盲目训练:始终监控测试集 Loss,它是过拟合的"报警器"。
- 早停是标配 :设置合理的
patience(通常 10-50,视数据波动情况而定),配合best_model保存机制,可以让你获得泛化能力最好的模型。 - 只存参数 :养成使用
model.state_dict()的好习惯,避免使用torch.save(model)。 - 断点保护:对于长时训练,务必定期保存 Checkpoint,包含优化器状态。