Pytorch | 对比Pytorch中的十种优化器:基于CIFAR10上的ResNet分类器

Pytorch | 对比Pytorch中的十种优化器:基于CIFAR10上的ResNet分类器

上篇文章中实现了十种优化算法:Python | 从零实现10种优化算法并比较

这篇文章我们用Pytorch上不同的优化器在CIFAR10数据集上训练ResNet模型,比较不同优化器的效果。

CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:

ResNet

ResNet(Residual Network)即残差网络,是由微软研究院的何恺明等人在2015年提出的一种深度卷积神经网络架构,在图像识别等计算机视觉任务中取得了巨大成功。

提出背景

随着神经网络深度的增加,出现了梯度消失/爆炸以及网络退化等问题,导致训练难度增大,精度饱和甚至下降。ResNet通过引入残差连接(shortcut connection)有效地解决了这些问题,使得训练极深的网络成为可能。

网络结构特点

  • 残差块(Residual Block) :这是ResNet的核心结构。它由多个卷积层组成,并且在卷积层之间引入了shortcut connection。一个基本的残差块包含两个3×3卷积层,中间有一个ReLU激活函数,其输入可以直接跳过这两个卷积层与输出相加,这种结构使得网络能够学习到残差函数,即输入与输出之间的差异,而不是直接学习输出本身。

  • 多种层数的网络结构 :ResNet有多种不同层数的架构,如ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等,其中数字表示网络的层数。层数越深,模型的表示能力越强,但计算成本也越高,训练难度也相应增大。

  • 瓶颈结构(Bottleneck):在较深的ResNet架构如ResNet-50及以上中,使用了瓶颈结构来减少计算量。它由1×1、3×3和1×1三个卷积层组成,1×1卷积层用于降低输入特征图的通道数,3×3卷积层进行主要的特征提取,最后1×1卷积层用于恢复通道数。

工作原理

在正向传播时,输入特征图通过残差块中的卷积层进行特征提取,得到输出特征图。然后将输入特征图与输出特征图相加,得到最终的输出。如果残差块中的卷积层没有学到有用的特征,那么它们的输出接近于零,此时最终的输出就近似等于输入,即网络可以学习到恒等映射。在反向传播时,由于shortcut connection的存在,梯度可以直接通过捷径传播到较早的层,避免了梯度消失或爆炸的问题,使得网络能够更容易地训练深层网络。

优势

  • 有效解决梯度消失和退化问题:使得训练非常深的网络成为可能,能够提取更高级的图像特征,从而提高了模型的准确性和泛化能力。
  • 降低模型训练难度:残差连接使得网络在训练过程中更容易收敛,减少了对超参数调整的依赖,提高了训练效率。
  • 模型具有很强的可扩展性:可以通过增加残差块的数量来构建更深的网络,以适应不同的任务和数据集。

代码实现分析

utils.py

该文件中我们预先为调用不同的优化器做好准备,对于不同的优化器参数,为了方便,这里只设置学习率,其余参数可根据需要自己设置。

python 复制代码
import torch
import torch.optim as optim

def get_optimizer(optimizer_name, model_parameters, lr=0.001, **kwargs):
    """
    根据传入的优化器名称返回对应的优化器实例。

    参数:
    - optimizer_name: 优化器名称,如 "Adadelta", "Adagrad" 等。
    - model_parameters: 模型的可训练参数,通常通过 model.parameters() 获取。
    - lr: 学习率,默认值为 0.001。
    - **kwargs: 其他特定优化器需要的额外参数。

    返回:
    - optimizer: 对应的优化器实例。
    """
    if optimizer_name == "Adadelta":
        return optim.Adadelta(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adagrad":
        return optim.Adagrad(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adam":
        return optim.Adam(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adamax":
        return optim.Adamax(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "AdamW":
        return optim.AdamW(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "NAdam":
        return optim.NAdam(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "RMSprop":
        return optim.RMSprop(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Rprop":
        return optim.Rprop(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "SGD":
        return optim.SGD(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "SparseAdam":
        return optim.SparseAdam(model_parameters, lr=lr, **kwargs)
    else:
        raise ValueError(f"不支持的优化器名称: {optimizer_name}")

main.py

本文这里使用Pytorch自带的ResNet18模型。

以下是对上述代码的分块讲解:

导入必要的库

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from utils import get_optimizer

import warnings
warnings.filterwarnings("ignore")

import ssl

ssl._create_default_https_context = ssl._create_unverified_context
  • from utils import get_optimizer 从前面的 utils 模块中导入 get_optimizer 函数,用于调用不同的优化器;
  • warnings.filterwarnings("ignore") 用于忽略一些可能出现的警告信息,让代码运行时输出更简洁;
  • ssl 相关的代码是为了解决在下载数据集时可能出现的SSL验证问题,通过创建一个默认的不验证SSL证书的上下文来允许数据正常下载。

设备选择与数据预处理定义

python 复制代码
# 检查cuda是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义数据预处理操作,将图像转换为张量并进行归一化
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))]
    )
  • 检查 gpu(cuda) 是否可用;
  • transforms.Normalize 对图像张量进行归一化操作,传入的两个元组分别表示每个通道的均值和标准差

加载训练集和测试集

python 复制代码
# 加载训练集,设置batch_size等参数
# 这里batch_size设为128,可按需调整
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

这部分代码用于加载CIFAR-10数据集,它是一个常用的图像分类数据集,包含10个不同类别的图像。

对于训练集:

  • torchvision.datasets.CIFAR10 函数用于创建训练集对象 trainset,其中 root='./data' 指定了数据集下载和保存的根目录(如果不存在会自动下载到该目录下),train=True 表示加载的是训练集部分,download=True 表示如果数据集不存在则自动下载,transform=transform 表示应用前面定义好的数据预处理操作。
  • torch.utils.data.DataLoader 用于将数据集 trainset 包装成一个可迭代的数据加载器 trainloader,设置 batch_size=128 意味着每次迭代会返回一个包含128张图像及其对应标签的批次数据,shuffle=True 会在每个训练轮次开始时打乱数据顺序,num_workers=2 表示使用2个子进程来并行加载数据,加快数据读取速度。

对于测试集:

  • 同样使用 torchvision.datasets.CIFAR10 函数创建 testset,不过 train=False 表示加载的是测试集部分,其他参数作用和训练集加载时类似。
  • 再通过 torch.utils.data.DataLoader 创建 testloader,只是 shuffle=False 因为测试集一般不需要打乱顺序。

主函数部分

训练部分

训练部分实现了使用10种优化器训练,共训练10个epoch,并记录损失下降情况以可视化,具体可参考下面代码和其中注释:

python 复制代码
# 定义类别标签,CIFAR-10有10个类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
if __name__ == "__main__":
    # 使用不同的优化器
    optimizer_names = ["Adadelta", "Adagrad", "Adam", "Adamax", "AdamW", 
                       "NAdam", "RMSProp", "RProp", "SGD", "SparseAdam"]
	for optimizer_name in optimizer_names:
	    print(f"========= Optimizer: {optimizer_name} ==========")
	    # 使用PyTorch自带的ResNet18模型,修改全连接层输出维度为10(对应CIFAR-10的类别数)
	    model = torchvision.models.resnet18(pretrained=False)
	    num_ftrs = model.fc.in_features
	    model.fc = nn.Linear(num_ftrs, 10)
	    model = model.to(device)
	
	    # 定义交叉熵损失函数,常用于分类任务
	    criterion = nn.CrossEntropyLoss()
	    # 定义优化器
	    optimizer = get_optimizer(optimizer_name, model.parameters(), lr=0.001)
	    # 训练轮数,设为10轮,可根据实际情况更改
	    num_epochs = 10
	    loss_history = []
	    for epoch in range(num_epochs):
	        running_loss = 0.0
	        for i, data in enumerate(trainloader, 0):
	            # 获取输入数据和标签
	            inputs, labels = data
	            inputs = inputs.to(device)
	            labels = labels.to(device)
	            # 梯度清零
	            optimizer.zero_grad()
	
	            # 前向传播 + 计算损失
	            outputs = model(inputs)
	            loss = criterion(outputs, labels)
	
	            # 反向传播并更新权重
	            loss.backward()
	            optimizer.step()
	
	            running_loss += loss.item()
	            loss_history.append(loss.item())
	            if i % 100 == 99:    # 每100个小批次打印一次平均损失
	                print(f'Epoch: {epoch + 1}  Batch: {i + 1}  loss: {running_loss / 100}')
	                running_loss = 0.0
	    
	    # 绘制并保存损失下降曲线
	    plt.plot(loss_history)
	    plt.xlabel('Iteration')
	    plt.ylabel('Loss')
	    plt.title(f'Loss Curve by {optimizer_name}')
	    plt.savefig(f'results\\loss_curve_{optimizer_name}.png')  # 保存图像为 loss_curve.png 文件,可根据需求修改文件名和路径
	    plt.close()
测试部分

测试部分使用已经训练好的模型对测试集进行预测,具体参考下面代码:

python 复制代码
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Optimizer: {optimizer_name} -- Accuracy of the network on the 10000 test images: {100 * correct / total}%')

结果

10种优化器对应的训练损失下降曲线

Adadelta

Adagrad

Adam

Adamax

AdamW

NAdam

RMSprop

Rprop

SGD

SparseAdam

本文训练到这里时,报错:RuntimeError: SparseAdam does not support dense gradients, please consider Adam instead .
原因: SparseAdam 优化器主要是设计用于处理稀疏梯度的场景,也就是梯度张量中大部分元素为零的情况(比如在处理稀疏的文本数据表示等情况时)。而在当前的代码应用场景中,很可能模型计算得到的梯度是密集的(即梯度张量中元素大多是非零值),这就导致 SparseAdam 优化器无法正常处理这样的梯度,进而抛出这个错误提示,建议你考虑使用 Adam 优化器来替代 SparseAdam 优化器。
----因此这里不再给出SparseAdam的优化结果----

测试结果

从测试结果来看,刨除 SparseAdamAdam 优化器的训练效果最优,Adadelta 优化器的训练效果最差.

代码汇总

项目结构:

|--data
|--results
|--utils.py
|--main.py

utils.py

python 复制代码
import torch
import torch.optim as optim

def get_optimizer(optimizer_name, model_parameters, lr=0.001, **kwargs):
    """
    根据传入的优化器名称返回对应的优化器实例。

    参数:
    - optimizer_name: 优化器名称,如 "Adadelta", "Adagrad" 等。
    - model_parameters: 模型的可训练参数,通常通过 model.parameters() 获取。
    - lr: 学习率,默认值为 0.001。
    - **kwargs: 其他特定优化器需要的额外参数。

    返回:
    - optimizer: 对应的优化器实例。
    """
    if optimizer_name == "Adadelta":
        return optim.Adadelta(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adagrad":
        return optim.Adagrad(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adam":
        return optim.Adam(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adamax":
        return optim.Adamax(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "AdamW":
        return optim.AdamW(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "NAdam":
        return optim.NAdam(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "RMSprop":
        return optim.RMSprop(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Rprop":
        return optim.Rprop(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "SGD":
        return optim.SGD(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "SparseAdam":
        return optim.SparseAdam(model_parameters, lr=lr, **kwargs)
    else:
        raise ValueError(f"不支持的优化器名称: {optimizer_name}")

main.py

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from utils import get_optimizer

import warnings
warnings.filterwarnings("ignore")

import ssl

ssl._create_default_https_context = ssl._create_unverified_context

# 检查cuda是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义数据预处理操作,将图像转换为张量并进行归一化
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))]
    )

# 加载训练集,设置batch_size等参数
# 这里batch_size设为128,可按需调整
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=4)

# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=4)

# 定义类别标签,CIFAR-10有10个类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


if __name__ == "__main__":
    # 使用不同的优化器
    optimizer_names = ["Adadelta", "Adagrad", "Adam", "Adamax", "AdamW", 
                       "NAdam", "RMSprop", "Rprop", "SGD", "SparseAdam"]
    
    for optimizer_name in optimizer_names:
        print(f"========= Optimizer: {optimizer_name} ==========")
        # 使用PyTorch自带的ResNet18模型,修改全连接层输出维度为10(对应CIFAR-10的类别数)
        model = torchvision.models.resnet18(pretrained=False)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, 10)
        model = model.to(device)

        # 定义交叉熵损失函数,常用于分类任务
        criterion = nn.CrossEntropyLoss()
        # 定义优化器
        optimizer = get_optimizer(optimizer_name, model.parameters(), lr=0.001)
        # 训练轮数,设为10轮,可根据实际情况更改
        num_epochs = 10
        loss_history = []
        for epoch in range(num_epochs):
            running_loss = 0.0
            for i, data in enumerate(trainloader, 0):
                # 获取输入数据和标签
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                # 梯度清零
                optimizer.zero_grad()

                # 前向传播 + 计算损失
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # 反向传播并更新权重
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                loss_history.append(loss.item())
                if i % 100 == 99:    # 每100个小批次打印一次平均损失
                    # print(f'Epoch: {epoch + 1}  Batch: {i + 1}  loss: {running_loss / 100}')
                    running_loss = 0.0

        # 绘制并保存损失下降曲线
        plt.plot(loss_history)
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.title(f'Loss Curve by {optimizer_name}')
        plt.savefig(f'results\\loss_curve_{optimizer_name}.png')  # 保存图像为 loss_curve.png 文件,可根据需求修改文件名和路径
        plt.close()

        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'Optimizer: {optimizer_name} -- Accuracy of the network on the 10000 test images: {100 * correct / total}%')
相关推荐
scdifsn2 分钟前
动手学深度学习11.1. 优化和深度学习-笔记&练习(PyTorch)
pytorch·笔记·深度学习·深度学习优化
知来者逆5 分钟前
计算机视觉单阶段实例分割实践指南与综述
人工智能·深度学习·机器学习·计算机视觉·目标跟踪·目标分割
Charge_A11 分钟前
深度学习作业 - 作业十一 - LSTM
人工智能·深度学习·lstm
行学AI16 分钟前
AI 赋能:医学科研审稿邀请的优化之道
人工智能
SchrodingerSDOG31 分钟前
算法刷题Day18: BM41 输出二叉树的右视图
数据结构·python·算法
QQ_77813297433 分钟前
基于机器学习的新闻分类系统
人工智能·机器学习·课程设计
B站计算机毕业设计超人35 分钟前
计算机毕业设计Python+CNN卷积神经网络高考推荐系统 高考分数线预测 高考爬虫 协同过滤推荐算法 Vue.js Django Hadoop 大数据毕设
大数据·爬虫·python·机器学习·课程设计·数据可视化·推荐算法
Aix9591 小时前
Dijkstra算法最短路径可视化(新)
python·opencv·算法
范桂飓1 小时前
AWS re:Invent 2024 — AI 基础设施架构
人工智能·架构·aws