PyTorch中实现早停机制(EarlyStopping)附代码

1. 核心目的

  • 当模型在验证集上的性能不再提升时,提前终止训练
  • 防止过拟合,节省计算资源

2. 实现方法

监控验证集指标(如损失、准确率),设置耐心值(Patience)

3. 代码:

python 复制代码
class EarlyStopping:
    def __init__(self,patience =10,delta=0):
        """
        Early stopping
        Args:
            patience: int, number of epochs to wait before stopping
            delta: float, the minimum improvements
        """
        self.patience = patience
        self.delta = delta
        self.counter =0 
        self.early_stop = False
        self.best_loss = float('inf')

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter =0 
        else:
            self.counter+=1
            if self.counter >= self.patience:
                self.early_stop = True
    

解释__call__ 方法的作用

在 Python 中,当一个类定义了 __call__ 方法时,这个类的实例就可以被当作函数来调用。例如:

复制代码
early_stopper = EarlyStopping(patience=3)  # 创建实例
early_stopper(val_loss=0.5)  # 调用实例,实际执行 __call__ 方法