深度学习早停(early stop)训练策略

深度学习早停(Early Stopping)训练策略

早停(Early Stopping)是一种防止深度学习模型过拟合的正则化技术。在训练过程中,当模型在验证集上的性能不再显著提高时,早停策略会提前停止训练。这样可以避免模型在训练集上表现得越来越好,但在验证集上表现变差。

早停策略的步骤
  1. 划分数据集:将数据集分为训练集和验证集。
  2. 定义监控指标:通常是验证集上的损失或精度。
  3. 设定耐心值(Patience):耐心值表示在验证指标不再改善的情况下,允许继续训练的最大次数。
  4. 训练模型:在每个训练轮次后,计算验证集上的指标。如果在耐心值内验证指标没有改善,则停止训练。
示例代码实现

我们使用TensorFlow和Keras来实现早停策略。假设我们使用一个简单的全连接神经网络来分类MNIST手写数字数据集。

python 复制代码
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.callbacks import EarlyStopping

# 加载MNIST数据集
(x_train, y_train), (x_val, y_val) = mnist.load_data()

# 数据归一化处理
x_train = x_train / 255.0
x_val = x_val / 255.0

# 定义模型
model = Sequential([
    Flatten(input_shape=(28, 28)),  # 将28x28的图片展平为一维向量
    Dense(128, activation='relu'),  # 第一个全连接层,128个神经元,激活函数为ReLU
    Dense(10, activation='softmax') # 输出层,10个神经元(10个类别),激活函数为softmax
])

# 编译模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# 定义早停回调函数
early_stopping = EarlyStopping(
    monitor='val_loss',   # 监控验证集上的损失
    patience=3,           # 如果验证集上的损失在3个轮次内没有改善,则停止训练
    restore_best_weights=True  # 恢复验证集损失最好的模型权重
)

# 训练模型
history = model.fit(
    x_train, y_train,             # 训练数据
    epochs=50,                    # 最大训练轮次
    validation_data=(x_val, y_val),# 验证数据
    callbacks=[early_stopping]    # 早停回调函数
)
代码解释
  1. 导入必要的库:导入TensorFlow和Keras相关的模块。
  2. 加载数据集:加载MNIST手写数字数据集,并划分为训练集和验证集。
  3. 数据归一化处理:将数据归一化到0-1范围内。
  4. 定义模型:使用Keras的Sequential API定义一个简单的全连接神经网络。
  5. 编译模型:指定优化器、损失函数和评估指标。
  6. 定义早停回调函数:使用Keras的EarlyStopping回调函数,设定监控指标为验证集上的损失,耐心值为3,训练过程中恢复验证集上损失最小的模型权重。
  7. 训练模型 :调用model.fit方法训练模型,同时传入早停回调函数。模型会在验证损失不再改善时提前停止训练。

这个例子演示了如何使用早停策略来防止模型过拟合,从而提高模型在验证集上的性能。

pytorch代码

以下是一个使用PyTorch实现早停策略的例子,同样使用MNIST手写数字数据集。

使用PyTorch实现早停策略

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# 定义一个简单的全连接神经网络
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()  # 将输入展平为一维
        self.fc1 = nn.Linear(28 * 28, 128)  # 定义一个全连接层,输入大小为28*28,输出大小为128
        self.relu = nn.ReLU()  # 定义ReLU激活函数
        self.fc2 = nn.Linear(128, 10)  # 定义另一个全连接层,输入大小为128,输出大小为10(对应10个类别)
        self.softmax = nn.Softmax(dim=1)  # 定义Softmax输出层,沿着维度1进行

    def forward(self, x):
        x = self.flatten(x)  # 将输入展平
        x = self.fc1(x)  # 输入到第一个全连接层
        x = self.relu(x)  # 通过ReLU激活函数
        x = self.fc2(x)  # 输入到第二个全连接层
        x = self.softmax(x)  # 通过Softmax激活函数
        return x

# 数据预处理:转换为张量并归一化到[-1, 1]范围内
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_size = 10000  # 验证集大小
train_size = len(train_dataset) - val_size  # 训练集大小
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])  # 划分训练集和验证集

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # 训练集数据加载器
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)  # 验证集数据加载器

# 初始化模型、损失函数和优化器
model = SimpleNN()  # 创建模型实例
criterion = nn.CrossEntropyLoss()  # 定义交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 使用Adam优化器

# 定义早停策略
class EarlyStopping:
    def __init__(self, patience=3, delta=0):
        self.patience = patience  # 设置耐心值,表示验证损失可以不改善的最大次数
        self.delta = delta  # 设置阈值,如果损失改善小于该值则认为没有改善
        self.best_loss = None  # 初始化最佳损失为None
        self.counter = 0  # 初始化计数器为0
        self.early_stop = False  # 初始化早停标志为False
        self.best_model_state = None  # 初始化最佳模型状态为None

    def __call__(self, val_loss, model):
        if self.best_loss is None:  # 如果最佳损失为None,说明是第一次调用
            self.best_loss = val_loss  # 将当前验证损失设为最佳损失
            self.best_model_state = model.state_dict()  # 保存模型的当前状态
        elif val_loss > self.best_loss + self.delta:  # 如果当前验证损失没有改善
            self.counter += 1  # 计数器加1
            if self.counter >= self.patience:  # 如果计数器达到耐心值
                self.early_stop = True  # 设置早停标志为True
                model.load_state_dict(self.best_model_state)  # 恢复模型到最佳状态
        else:  # 如果验证损失改善了
            self.best_loss = val_loss  # 更新最佳损失
            self.best_model_state = model.state_dict()  # 保存模型的当前状态
            self.counter = 0  # 重置计数器

early_stopping = EarlyStopping(patience=3, delta=0.01)  # 创建早停策略实例

# 训练模型
num_epochs = 50  # 最大训练轮次
for epoch in range(num_epochs):
    model.train()  # 设置模型为训练模式
    for batch in train_loader:
        images, labels = batch  # 获取一批数据和标签
        outputs = model(images)  # 将数据输入模型,获得输出
        loss = criterion(outputs, labels)  # 计算损失

        optimizer.zero_grad()  # 清空梯度
        loss.backward()  # 反向传播
        optimizer.step()  # 更新模型参数

    # 验证模型
    model.eval()  # 设置模型为评估模式
    val_loss = 0.0  # 初始化验证损失
    with torch.no_grad():  # 禁用梯度计算
        for batch in val_loader:
            images, labels = batch  # 获取一批数据和标签
            outputs = model(images)  # 将数据输入模型,获得输出
            loss = criterion(outputs, labels)  # 计算损失
            val_loss += loss.item()  # 累加损失

    val_loss /= len(val_loader)  # 计算验证集上的平均损失
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss}')  # 打印当前轮次的验证损失

    # 检查早停条件
    early_stopping(val_loss, model)  # 调用早停策略
    if early_stopping.early_stop:  # 如果早停标志为True
        print("Early stopping")  # 打印早停信息
        break  # 退出训练循环

# 模型训练完成
代码解释
  1. 定义模型:定义一个简单的全连接神经网络,包括展平层、全连接层、ReLU激活函数和Softmax输出层。
  2. 数据预处理 :使用transforms对MNIST数据集进行标准化处理。
  3. 加载数据集:下载MNIST数据集,并将其划分为训练集和验证集。
  4. 初始化模型、损失函数和优化器:创建模型实例,定义交叉熵损失函数,并使用Adam优化器。
  5. 定义早停策略类 :创建EarlyStopping类,包含早停所需的参数和逻辑。在验证损失不再改善时,保存模型的最佳状态,并在达到耐心值后停止训练。
  6. 训练模型:在每个训练轮次后,计算验证集上的损失,并使用早停策略检查是否需要停止训练。

这个PyTorch示例展示了如何实现早停策略,以防止模型过拟合并提高验证集上的性能。

相关推荐
qzhqbb2 小时前
基于统计方法的语言模型
人工智能·语言模型·easyui
冷眼看人间恩怨3 小时前
【话题讨论】AI大模型重塑软件开发:定义、应用、优势与挑战
人工智能·ai编程·软件开发
2401_883041083 小时前
新锐品牌电商代运营公司都有哪些?
大数据·人工智能
AI极客菌4 小时前
Controlnet作者新作IC-light V2:基于FLUX训练,支持处理风格化图像,细节远高于SD1.5。
人工智能·计算机视觉·ai作画·stable diffusion·aigc·flux·人工智能作画
阿_旭4 小时前
一文读懂| 自注意力与交叉注意力机制在计算机视觉中作用与基本原理
人工智能·深度学习·计算机视觉·cross-attention·self-attention
王哈哈^_^4 小时前
【数据集】【YOLO】【目标检测】交通事故识别数据集 8939 张,YOLO道路事故目标检测实战训练教程!
前端·人工智能·深度学习·yolo·目标检测·计算机视觉·pyqt
Power20246665 小时前
NLP论文速读|LongReward:基于AI反馈来提升长上下文大语言模型
人工智能·深度学习·机器学习·自然语言处理·nlp
数据猎手小k5 小时前
AIDOVECL数据集:包含超过15000张AI生成的车辆图像数据集,目的解决旨在解决眼水平分类和定位问题。
人工智能·分类·数据挖掘
好奇龙猫5 小时前
【学习AI-相关路程-mnist手写数字分类-win-硬件:windows-自我学习AI-实验步骤-全连接神经网络(BPnetwork)-操作流程(3) 】
人工智能·算法
沉下心来学鲁班5 小时前
复现LLM:带你从零认识语言模型
人工智能·语言模型