深度学习中的早停法

早停法(Early Stopping)是一种用于防止模型过拟合的技术,在训练过程中监视验证集(或者测试集)上的损失值。具体设立早停的限制包括两个主要参数:

  1. Patience(耐心):这是指验证集损失在连续多少个epoch没有显著改善时,才触发早停。当验证集损失连续几个epoch没有下降或者停止减少时,表示模型可能已经过拟合或者陷入局部最优点,这时候早停就会被触发。

  2. Best Loss(最佳损失):这是指在早停过程中保存的最低验证集损失值。当验证集损失值低于当前最佳损失时,更新最佳损失并重置耐心计数器。如果验证集损失连续不降,耐心计数器超过设定的耐心值时,早停就会被触发,训练过程停止。

    早停的具体设立是基于验证集上的损失值 val_loss。每次验证后,如果当前的 val_lossbest_loss 还要低,就更新 best_loss 并重置 patience_counter;否则,增加 patience_counter。当 patience_counter 达到设定的 patience 值时,早停被触发,即停止训练过程以防止模型过拟合。

    总结来说,早停的设立限制是基于耐心参数和最佳损失值,用来判断模型是否应该停止训练以避免过拟合。

python 复制代码
# 训练模型
num_epochs = 200  # 总的训练轮数
best_loss = float('inf')  # 初始化最佳验证损失为正无穷大
patience = 10  # 早停的耐心值
patience_counter = 0  # 耐心计数器

for epoch in range(num_epochs):
    model.train()
    for geno, pheno in train_loader:
        optimizer.zero_grad()  # 梯度清零
        outputs = model(geno)  # 前向传播
        loss = criterion(outputs.squeeze(), pheno)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 优化模型参数

    model.eval()
    val_loss = 0
    with torch.no_grad():  # 不计算梯度
        for geno, pheno in test_loader:
            outputs = model(geno)  # 前向传播
            val_loss += criterion(outputs.squeeze(), pheno).item()  # 计算验证损失
    val_loss /= len(test_loader)  # 计算平均验证损失
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}')

    scheduler.step(val_loss)  # 更新学习率

    # 早停法
    if val_loss < best_loss:
        best_loss = val_loss  # 更新最佳验证损失
        patience_counter = 0  # 重置耐心计数器
    else:
        patience_counter += 1  # 增加耐心计数器
        if patience_counter >= patience:  # 如果耐心计数器达到设定的耐心值
            print("Early stopping triggered")  # 触发早停
            break
  1. EarlyStopping
    • __init__ 方法初始化早停的参数,如 patience(耐心值)、verbose(是否打印消息)和 delta(损失改进的最小变化)。
    • __call__ 方法根据验证损失来决定是否更新 best_loss,以及是否增加计数器或者触发早停。
  2. 训练循环
    • 训练和验证过程与之前相同。
    • 每个epoch结束时,调用 early_stopping 对象,传入当前的验证损失。
    • 检查 early_stopping.early_stop 标志,如果为 True,则打印消息并停止训练。

通过使用 EarlyStopping 类,你可以更简洁和模块化地实现早停功能,使代码更易于维护和扩展。

python 复制代码
import torch
import numpy as np

class EarlyStopping:
    def __init__(self, patience=10, verbose=False, delta=0):
        """
        EarlyStopping 初始化.
        Args:
            patience (int): 当验证集损失在指定的epoch数内没有减少时触发早停.
            verbose (bool): 如果为True,则每次验证集损失改进时会打印一条消息.
            delta (float): 验证集损失改进的最小变化.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
            if self.verbose:
                print(f'Validation loss decreased to {self.best_loss:.6f}. Resetting counter.')

# 初始化EarlyStopping对象
early_stopping = EarlyStopping(patience=10, verbose=True)

# 训练模型
num_epochs = 200
for epoch in range(num_epochs):
    model.train()
    for geno, pheno in train_loader:
        optimizer.zero_grad()
        outputs = model(geno)
        loss = criterion(outputs.squeeze(), pheno)
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for geno, pheno in test_loader:
            outputs = model(geno)
            val_loss += criterion(outputs.squeeze(), pheno).item()
    val_loss /= len(test_loader)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}')

    scheduler.step(val_loss)

    # 检查是否触发早停
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break
相关推荐
CoovallyAIHub5 小时前
无人机方案如何让桥梁监测更安全、更智能?融合RTK与超高分辨率成像,优于毫米精度
深度学习·算法·计算机视觉
大学生毕业题目6 小时前
毕业项目推荐:83-基于yolov8/yolov5/yolo11的农作物杂草检测识别系统(Python+卷积神经网络)
人工智能·python·yolo·目标检测·cnn·pyqt·杂草识别
居7然6 小时前
美团大模型“龙猫”登场,能否重塑本地生活新战局?
人工智能·大模型·生活·美团
说私域6 小时前
社交新零售时代本地化微商的发展路径研究——基于开源AI智能名片链动2+1模式S2B2C商城小程序源的创新实践
人工智能·开源·零售
IT_陈寒6 小时前
Python性能优化:5个被低估的魔法方法让你的代码提速50%
前端·人工智能·后端
Deng_Xian_Sheng6 小时前
有哪些任务可以使用无监督的方式训练深度学习模型?
人工智能·深度学习·无监督
数据科学作家9 小时前
学数据分析必囤!数据分析必看!清华社9本书覆盖Stata/SPSS/Python全阶段学习路径
人工智能·python·机器学习·数据分析·统计·stata·spss
CV缝合救星10 小时前
【Arxiv 2025 预发行论文】重磅突破!STAR-DSSA 模块横空出世:显著性+拓扑双重加持,小目标、大场景统统拿下!
人工智能·深度学习·计算机视觉·目标跟踪·即插即用模块
TDengine (老段)12 小时前
从 ETL 到 Agentic AI:工业数据管理变革与 TDengine IDMP 的治理之道
数据库·数据仓库·人工智能·物联网·时序数据库·etl·tdengine
蓝桉80213 小时前
如何进行神经网络的模型训练(视频代码中的知识点记录)
人工智能·深度学习·神经网络