PyTorch 基础学习(17)- 过拟合

系列文章:
《PyTorch 基础学习》文章索引

介绍

在深度学习中,过拟合是一个常见的问题,它会导致模型在训练数据上表现很好,但在新数据上的表现却很差。为了提高模型的泛化能力,防止过拟合是至关重要的。本教程将详细介绍过拟合的概念、常用的防止过拟合的方法,并展示如何在 PyTorch 中实施这些方法。

什么是过拟合?

过拟合发生在模型过于复杂时,即模型包含的参数数量远超过训练数据的量。当模型在训练数据上表现得非常好时,它可能只是"记住"了这些数据的噪声或细节,而不是学到了数据的普遍规律。这会导致模型在遇到新数据时无法很好地预测,从而表现不佳。

典型的过拟合表现:

  • 训练集上的损失很低:模型在训练集上表现得很好。
  • 验证集上的损失较高:模型在验证集或测试集上的表现不如训练集。

防止过拟合的常用方法

防止过拟合的方法有很多,以下是一些常用的技术:

  • 正则化(Regularization):通过在损失函数中添加额外的惩罚项,限制模型的复杂度。
  • Dropout:在训练过程中随机丢弃部分神经元,减少模型对某些特征的依赖。
  • 数据增强(Data Augmentation):通过对训练数据进行变换,增加数据的多样性。
  • 提前停止(Early Stopping):在验证集上监控模型的性能,当性能不再提升时停止训练。
  • 减少模型复杂度:通过减少参数数量或层数来简化模型。

在 PyTorch 中防止过拟合的方法与实例

我们将分别讲解如何在 PyTorch 中使用这些技术,并提供具体的代码示例。

L2 正则化(权重衰减)

L2 正则化 通过在损失函数中加入所有权重的平方和的惩罚项,鼓励模型参数的大小趋向于零,从而防止模型变得过于复杂。在 PyTorch 中,可以通过在优化器中设置 weight_decay 参数来实现。

python 复制代码
import torch.optim as optim

# 在优化器中设置 weight_decay 来进行 L2 正则化
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

参数说明

  • weight_decay:L2 正则化的系数,值越大,正则化效果越强。

Dropout

Dropout 是一种防止过拟合的技术,通过在训练过程中随机丢弃部分神经元,使得模型不会对某些特定路径过度依赖。PyTorch 中通过 torch.nn.Dropout 层来实现 Dropout。

python 复制代码
import torch.nn as nn
import torch.nn.functional as F

class DropoutModel(nn.Module):
    def __init__(self):
        super(DropoutModel, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(50, 1)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)  # 在训练时以0.5的概率丢弃神经元
        x = self.fc2(x)
        return x

参数说明

  • p:Dropout的概率,表示每个神经元被丢弃的概率。

数据增强(Data Augmentation)

数据增强 通过对训练数据进行随机变换,增加数据的多样性,从而使模型在不同的数据上表现更好。PyTorch 中通常使用 torchvision.transforms 来进行数据增强。

python 复制代码
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# 定义数据增强操作
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# 应用数据增强到训练数据
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

常见的数据增强操作

  • RandomHorizontalFlip():随机水平翻转图片。
  • RandomRotation(degrees):随机旋转图片一定角度。

提前停止(Early Stopping)

提前停止是在训练过程中监控验证集上的损失,当损失不再下降时停止训练,以避免过拟合。这通常通过手动实现,监控模型在验证集上的表现。

python 复制代码
best_loss = float('inf')
patience = 10  # 在没有提升的情况下等待的轮数
counter = 0

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    output = model(train_data)
    loss = criterion(output, train_labels)
    loss.backward()
    optimizer.step()
    
    # 验证阶段
    model.eval()
    val_output = model(val_data)
    val_loss = criterion(val_output, val_labels)

    if val_loss < best_loss:
        best_loss = val_loss
        counter = 0  # 重置计数器
        torch.save(model.state_dict(), 'best_model.pth')  # 保存最佳模型
    else:
        counter += 1
    
    if counter >= patience:
        print("Early stopping!")
        break

参数说明

  • patience:当验证集上的损失不再改善时,允许的最大等待轮数。

减少模型复杂度

减少模型复杂度是最直接的防止过拟合的方法之一。这可以通过减少模型的层数、每层的神经元数量或减少模型参数的数量来实现。

python 复制代码
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)  # 减少神经元数量
        self.fc2 = nn.Linear(20, 1)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

思路

  • 在设计模型时,根据数据的复杂度合理设计网络结构,避免使用过于复杂的模型。

神经网络预防过拟合

我们可以通过一个更复杂的神经网络示例,结合多种防止过拟合的方法来进行演示。以下是一个包含 L2 正则化、Dropout 和减少复杂度的神经网络实例。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class ComplexModel(nn.Module):
    def __init__(self):
        super(ComplexModel, self).__init__()
        self.fc1 = nn.Linear(100, 50)
        self.fc2 = nn.Linear(50, 20)
        self.fc3 = nn.Linear(20, 10)
        self.dropout = nn.Dropout(p=0.5)  # Dropout
        self.output = nn.Linear(10, 1)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(F.relu(self.fc2(x)))  # 使用 Dropout
        x = F.relu(self.fc3(x))
        x = self.output(x)
        return x

# 初始化模型
model = ComplexModel()

# 使用 L2 正则化的优化器
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

# 损失函数
criterion = nn.MSELoss()

# 模拟训练
for epoch in range(100):
    optimizer.zero_grad()
    output = model(torch.randn(64, 100))
    loss = criterion(output, torch.randn(64, 1))
    loss.backward()
    optimizer.step()

在这个示例中,我们使用了一个多层神经网络模型,并结合了 L2 正则化(通过 weight_decay)、Dropout 层和适度减少神经元数量的策略来防止过拟合。

PyTorch 自带模型中使用防止过拟合

PyTorch 自带了一些预训练的模型(如 torchvision.models),你可以在这些模型中使用上述技巧来防止过拟合。

示例:在 ResNet 中使用 L2 正则化和 Dropout

python 复制代码
import torchvision.models as models

# 加载预训练的 ResNet 模型
model = models.resnet18(pretrained=True)

# 替换最后一层,以适应你的任务
model.fc = nn.Sequential(
    nn.Dropout(p=0.5),  # 添加 Dropout 层
    nn.Linear(model.fc.in_features, 10)  # 假设有10个类别
)

# 使用 L2 正则化
optimizer = optim.SGD(model.parameters(), lr=0.001, weight_decay=1e-4)

在这个示例中,我们使用了预训练的 ResNet 模型,并在最后一层添加了 Dropout

层,同时使用了 L2 正则化来训练模型。

总结

防止过拟合是深度学习中的关键问题,通过合理使用正则化、Dropout、数据增强、提前停止等技术,可以显著提升模型的泛化能力。PyTorch 提供了灵活的接口,可以很方便地将这些技术应用到你的模型中。掌握这些技巧,将帮助你在实际项目中训练出更加稳健和高效的模型。

相关推荐
边缘计算社区24 分钟前
首个!艾灵参编的工业边缘计算国家标准正式发布
大数据·人工智能·边缘计算
游客52034 分钟前
opencv中的各种滤波器简介
图像处理·人工智能·python·opencv·计算机视觉
一位小说男主35 分钟前
编码器与解码器:从‘乱码’到‘通话’
人工智能·深度学习
深圳南柯电子1 小时前
深圳南柯电子|电子设备EMC测试整改:常见问题与解决方案
人工智能
Kai HVZ1 小时前
《OpenCV计算机视觉》--介绍及基础操作
人工智能·opencv·计算机视觉
biter00881 小时前
opencv(15) OpenCV背景减除器(Background Subtractors)学习
人工智能·opencv·学习
吃个糖糖1 小时前
35 Opencv 亚像素角点检测
人工智能·opencv·计算机视觉
IT古董2 小时前
【漫话机器学习系列】017.大O算法(Big-O Notation)
人工智能·机器学习
凯哥是个大帅比2 小时前
人工智能ACA(五)--深度学习基础
人工智能·深度学习
Code哈哈笑2 小时前
【Java 学习】深度剖析Java多态:从向上转型到向下转型,解锁动态绑定的奥秘,让代码更优雅灵活
java·开发语言·学习