Classifying the CIFAR-10 Dataset with a Pretrained ResNet50 Model

I. Introduction to the ResNet Pretrained Model

ResNet50, proposed by Kaiming He et al. in 2015, is one of the most influential deep learning architectures in computer vision (CV) and a widely used pretrained model. It addresses the degradation problem of deep neural networks, making it practical to train extremely deep networks. The architecture and design ideas of ResNet50 are described below.

  • Plain network: directly learns the target mapping H(x)
  • Residual network: learns the residual mapping F(x) = H(x) - x, so the final output is F(x) + x

This design lets gradients flow directly through the identity-mapping path, which effectively mitigates the vanishing-gradient problem and makes it possible to train networks with hundreds of layers.
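The idea above can be sketched as a minimal residual block in PyTorch (a simplified illustration for intuition, not the exact block used inside ResNet50):

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Minimal residual block: output = F(x) + x."""
    def __init__(self, channels):
        super().__init__()
        # F(x): two 3x3 convolutions with batch norm
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Identity shortcut: gradients can flow straight through the "+ x" path
        return self.relu(out + x)

block = ResidualBlock(16)
y = block(torch.randn(2, 16, 8, 8))
print(y.shape)  # torch.Size([2, 16, 8, 8])
```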

ResNet50 has 50 weight layers in total and consists of:

  1. Stem: 7×7 convolution + max pooling
  2. 4 residual stages, each containing several residual blocks
  3. Global average pooling layer
  4. 1000-way classifier (the original version targets ImageNet)
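The count of 50 follows directly from the block layout: the four stages contain 3, 4, 6, and 3 Bottleneck blocks, each with 3 weight layers, plus the stem convolution and the final classifier:

```python
# Weighted layers in ResNet50: stem conv + 3 convs per bottleneck + fc
blocks_per_stage = [3, 4, 6, 3]
num_layers = 1 + sum(b * 3 for b in blocks_per_stage) + 1
print(num_layers)  # 50
```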

ResNet50 uses the Bottleneck block as its basic unit; each Bottleneck contains 3 convolutional layers:

  1. 1×1 convolution: reduces channels to cut computation
  2. 3×3 convolution: extracts spatial features
  3. 1×1 convolution: expands channels back to the full width

Formally: y = F(x, {Wi}) + x, where

  • F(x) is the residual function
  • x is the input
  • y is the output
  • "+" is element-wise addition, which requires F(x) and x to have the same dimensions
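A Bottleneck block following the three-step recipe above can be sketched as follows (a simplified version of torchvision's implementation; stride and downsampling details are omitted):

```python
import torch
import torch.nn as nn

class Bottleneck(nn.Module):
    """Sketch of a ResNet Bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut."""
    expansion = 4  # output channels = mid_ch * 4

    def __init__(self, in_ch, mid_ch):
        super().__init__()
        out_ch = mid_ch * self.expansion
        self.conv1 = nn.Conv2d(in_ch, mid_ch, 1, bias=False)              # 1x1: reduce
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, mid_ch, 3, padding=1, bias=False)  # 3x3: spatial
        self.bn2 = nn.BatchNorm2d(mid_ch)
        self.conv3 = nn.Conv2d(mid_ch, out_ch, 1, bias=False)             # 1x1: expand
        self.bn3 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        # When F(x) and x differ in channel count, a 1x1 projection aligns them
        self.shortcut = (nn.Identity() if in_ch == out_ch
                         else nn.Conv2d(in_ch, out_ch, 1, bias=False))

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + self.shortcut(x))  # y = F(x) + x

y = Bottleneck(64, 64)(torch.randn(1, 64, 8, 8))
print(y.shape)  # torch.Size([1, 256, 8, 8])
```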

II. Classifying the CIFAR-10 Dataset with ResNet50

The full implementation for classifying the CIFAR-10 dataset with ResNet50 is given below:

1. Imports and loading the pretrained model

python
# 1. Imports and the pretrained model
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
import numpy as np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

# Select device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

2. Dataset download and preprocessing

  • Training set: random crop, horizontal flip, normalization
  • Test set: normalization only
  • The mean (0.4914, 0.4822, 0.4465) and standard deviation (0.2023, 0.1994, 0.2010) are per-channel statistics of the CIFAR-10 training set
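These statistics can be recomputed from the raw training images. The sketch below shows the per-channel computation on random stand-in data; replacing `data` with `trainset.data / 255.0` would reproduce the CIFAR-10 values:

```python
import numpy as np

# Stand-in for trainset.data / 255.0, which has shape (50000, 32, 32, 3)
data = np.random.rand(100, 32, 32, 3)

# Average over all images and pixels, keeping only the channel axis
mean = data.mean(axis=(0, 1, 2))
std = data.std(axis=(0, 1, 2))
print(mean.shape, std.shape)  # (3,) (3,)
```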
python
# 2. Dataset download and preprocessing
def prepare_data():
    """Download and preprocess the CIFAR-10 dataset."""
    # Training-set augmentation
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    # Test-set preprocessing
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    # Load the datasets
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2)
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')
    
    return trainloader, testloader, classes

3. Building the network model

python
# 3. Build the network model
def build_model():
    """Build a ResNet50 model adapted to CIFAR-10."""
    # Load pretrained weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    
    # Freeze all pretrained layers
    for param in model.parameters():
        param.requires_grad = False
    
    # Adapt the architecture to small 32x32 images:
    # replace the first conv layer to shrink the receptive field
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # Remove the max-pooling layer so small images are not over-downsampled
    model.maxpool = nn.Identity()
    
    # Replace the fully connected layer to output 10 classes
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 10)
    
    # Unfreeze the layers we want to fine-tune
    for param in model.conv1.parameters():
        param.requires_grad = True
    for param in model.layer4.parameters():
        param.requires_grad = True
    for param in model.fc.parameters():
        param.requires_grad = True
    
    model = model.to(device)
    return model
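After partially freezing the network, it can be worth verifying which parameters will actually be updated. The helper below is a hypothetical utility (not part of the original script) that counts trainable vs. total parameters; a toy two-layer model stands in for the ResNet:

```python
import torch.nn as nn

def count_params(model):
    """Return (trainable, total) parameter counts for a model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total

# Toy example: freeze the first layer, as build_model() freezes the backbone
toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in toy[0].parameters():
    p.requires_grad = False
print(count_params(toy))  # (10, 30)
```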

4. The training function

python
# 4. Training function
def train(model, trainloader, criterion, optimizer, epoch):
    """Train the model for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    # Progress bar over the training set
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}", leave=False)
    for batch_idx, (inputs, targets) in enumerate(progress_bar):
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Accumulate statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        # Update the progress bar
        progress_bar.set_postfix(loss=running_loss/(batch_idx+1), acc=100.*correct/total)
    
    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

5. The evaluation function

python
# 5. Evaluation function
def test(model, testloader, criterion):
    """Evaluate the model on the test set."""
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    test_loss = test_loss / len(testloader)
    test_acc = 100. * correct / total
    return test_loss, test_acc

6. The main function

python
# 6. Main entry point
def main():
    """Main function: orchestrates training and testing."""
    # Prepare the data
    print("Preparing data...")
    trainloader, testloader, classes = prepare_data()
    
    # Build the model
    print("Building model...")
    model = build_model()
    
    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    # Optimize only the trainable parameters
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 
                         lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    
    # Training hyperparameters
    num_epochs = 100
    best_acc = 0
    
    # Track the training history
    train_losses = []
    train_accs = []
    test_losses = []
    test_accs = []
    
    print("Starting training...")
    start_time = time.time()
    
    # Training loop
    for epoch in range(num_epochs):
        # Train
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, epoch)
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        
        # Evaluate
        test_loss, test_acc = test(model, testloader, criterion)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        
        # Update the learning rate
        scheduler.step()
        
        # Save the best model
        if test_acc > best_acc:
            print(f'Saving best model with accuracy: {test_acc:.2f}%')
            best_acc = test_acc
            torch.save(model.state_dict(), 'resnet50_cifar10_best.pth')
        
        # Print results
        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
    
    total_time = time.time() - start_time
    print(f'Training completed in {total_time//60:.0f}m {total_time%60:.0f}s')
    print(f'Best test accuracy: {best_acc:.2f}%')
    
    # Save the final model
    torch.save(model.state_dict(), 'resnet50_cifar10_final.pth')
    
    # Plot the training curves
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(test_accs, label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    
    plt.savefig('training_results.png')
    plt.close()

if __name__ == '__main__':
    main()
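The CosineAnnealingLR schedule used above follows eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T_max)). A quick sanity check of the values it produces over the 100 epochs (a pure-Python reimplementation for illustration, not the PyTorch scheduler itself):

```python
import math

def cosine_lr(epoch, base_lr=0.01, t_max=100, eta_min=0.0):
    """Learning rate at a given epoch under cosine annealing."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

print(cosine_lr(0))    # 0.01  (starts at the base learning rate)
print(cosine_lr(50))   # 0.005 (halfway: half the base rate)
print(cosine_lr(100))  # 0.0   (decays to eta_min at T_max)
```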

7. Run results

(mlstat) [haichao@node01 cifar]$ python demo1.py

Using device: cpu

Preparing data...

100%|██████████| 170M/170M [1:09:08<00:00, 41.1kB/s]

Building model...

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /home/haichao/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth

100%|██████████| 97.8M/97.8M [01:30<00:00, 1.13MB/s]

Starting training...

Saving best model with accuracy: 76.80%

Epoch 1/100, Train Loss: 1.1939, Train Acc: 57.72%, Test Loss: 0.6765, Test Acc: 76.80%

Saving best model with accuracy: 83.43%

Epoch 2/100, Train Loss: 0.5763, Train Acc: 80.08%, Test Loss: 0.4767, Test Acc: 83.43%

Saving best model with accuracy: 87.68%

Epoch 3/100, Train Loss: 0.4238, Train Acc: 85.34%, Test Loss: 0.3579, Test Acc: 87.68%

Saving best model with accuracy: 88.45%

Epoch 4/100, Train Loss: 0.3421, Train Acc: 88.15%, Test Loss: 0.3430, Test Acc: 88.45%

Saving best model with accuracy: 89.71%

Epoch 5/100, Train Loss: 0.2856, Train Acc: 90.00%, Test Loss: 0.2943, Test Acc: 89.71%

Saving best model with accuracy: 90.81%

Epoch 6/100, Train Loss: 0.2477, Train Acc: 91.43%, Test Loss: 0.2668, Test Acc: 90.81%

Saving best model with accuracy: 91.19%

Epoch 7/100, Train Loss: 0.2098, Train Acc: 92.81%, Test Loss: 0.2614, Test Acc: 91.19%

Saving best model with accuracy: 91.64%

Epoch 8/100, Train Loss: 0.1903, Train Acc: 93.34%, Test Loss: 0.2588, Test Acc: 91.64%

Saving best model with accuracy: 91.67%

Epoch 9/100, Train Loss: 0.1713, Train Acc: 94.04%, Test Loss: 0.2498, Test Acc: 91.67%

Saving best model with accuracy: 91.74%

Epoch 10/100, Train Loss: 0.1534, Train Acc: 94.66%, Test Loss: 0.2517, Test Acc: 91.74%

Saving best model with accuracy: 91.88%

Epoch 11/100, Train Loss: 0.1398, Train Acc: 95.09%, Test Loss: 0.2474, Test Acc: 91.88%

Epoch 12/100, Train Loss: 0.1248, Train Acc: 95.49%, Test Loss: 0.2576, Test Acc: 91.78%

Saving best model with accuracy: 92.38%

Epoch 13/100, Train Loss: 0.1128, Train Acc: 95.99%, Test Loss: 0.2469, Test Acc: 92.38%

Epoch 14/100, Train Loss: 0.1031, Train Acc: 96.38%, Test Loss: 0.2492, Test Acc: 92.28%

Saving best model with accuracy: 92.42%

Epoch 15/100, Train Loss: 0.0950, Train Acc: 96.62%, Test Loss: 0.2462, Test Acc: 92.42%

Epoch 16/100, Train Loss: 0.0830, Train Acc: 97.09%, Test Loss: 0.2650, Test Acc: 92.34%

Epoch 17/100, Train Loss: 0.0807, Train Acc: 97.17%, Test Loss: 0.2655, Test Acc: 92.25%

Epoch 18/100, Train Loss: 0.0707, Train Acc: 97.52%, Test Loss: 0.2650, Test Acc: 92.08%

Epoch 19/100, Train Loss: 0.0719, Train Acc: 97.50%, Test Loss: 0.2786, Test Acc: 92.13%

Epoch 20/100, Train Loss: 0.0644, Train Acc: 97.78%, Test Loss: 0.2768, Test Acc: 91.91%

Epoch 21/100, Train Loss: 0.0604, Train Acc: 97.92%, Test Loss: 0.2606, Test Acc: 92.33%

Saving best model with accuracy: 92.43%

Epoch 22/100, Train Loss: 0.0576, Train Acc: 98.00%, Test Loss: 0.2669, Test Acc: 92.43%

Saving best model with accuracy: 92.52%

Epoch 23/100, Train Loss: 0.0652, Train Acc: 97.70%, Test Loss: 0.2641, Test Acc: 92.52%

Saving best model with accuracy: 92.84%

Epoch 24/100, Train Loss: 0.0520, Train Acc: 98.23%, Test Loss: 0.2680, Test Acc: 92.84%

Epoch 25/100, Train Loss: 0.0501, Train Acc: 98.33%, Test Loss: 0.2845, Test Acc: 92.15%

Epoch 26/100, Train Loss: 0.0443, Train Acc: 98.45%, Test Loss: 0.2744, Test Acc: 92.62%

Epoch 27/100, Train Loss: 0.0530, Train Acc: 98.15%, Test Loss: 0.2827, Test Acc: 92.09%

Epoch 28/100, Train Loss: 0.0618, Train Acc: 97.83%, Test Loss: 0.2654, Test Acc: 92.58%

Epoch 29/100, Train Loss: 0.0470, Train Acc: 98.42%, Test Loss: 0.2585, Test Acc: 92.75%

Epoch 30/100, Train Loss: 0.0468, Train Acc: 98.43%, Test Loss: 0.2678, Test Acc: 92.33%

Epoch 31/100, Train Loss: 0.0455, Train Acc: 98.45%, Test Loss: 0.2765, Test Acc: 92.38%

Saving best model with accuracy: 93.07%

Epoch 32/100, Train Loss: 0.0383, Train Acc: 98.67%, Test Loss: 0.2709, Test Acc: 93.07%

Saving best model with accuracy: 93.13%

Epoch 33/100, Train Loss: 0.0341, Train Acc: 98.83%, Test Loss: 0.2599, Test Acc: 93.13%

Epoch 34/100, Train Loss: 0.0308, Train Acc: 98.96%, Test Loss: 0.2650, Test Acc: 92.95%

Epoch 35/100, Train Loss: 0.0302, Train Acc: 99.00%, Test Loss: 0.2663, Test Acc: 93.10%

Epoch 36/100, Train Loss: 0.0365, Train Acc: 98.78%, Test Loss: 0.2682, Test Acc: 93.03%

Epoch 37/100, Train Loss: 0.0319, Train Acc: 98.90%, Test Loss: 0.2826, Test Acc: 92.80%

Epoch 38/100, Train Loss: 0.0337, Train Acc: 98.89%, Test Loss: 0.2660, Test Acc: 93.06%

Epoch 39/100, Train Loss: 0.0322, Train Acc: 98.89%, Test Loss: 0.2696, Test Acc: 93.02%

Epoch 40/100, Train Loss: 0.0289, Train Acc: 99.05%, Test Loss: 0.2725, Test Acc: 92.81%

Epoch 41/100, Train Loss: 0.0274, Train Acc: 99.08%, Test Loss: 0.2560, Test Acc: 93.07%

Saving best model with accuracy: 93.41%

Epoch 42/100, Train Loss: 0.0244, Train Acc: 99.18%, Test Loss: 0.2690, Test Acc: 93.41%

Epoch 43/100, Train Loss: 0.0222, Train Acc: 99.21%, Test Loss: 0.2687, Test Acc: 93.25%

Epoch 44/100, Train Loss: 0.0216, Train Acc: 99.29%, Test Loss: 0.2630, Test Acc: 93.17%

Epoch 45/100, Train Loss: 0.0217, Train Acc: 99.29%, Test Loss: 0.2694, Test Acc: 93.35%

Epoch 46/100, Train Loss: 0.0214, Train Acc: 99.33%, Test Loss: 0.2632, Test Acc: 93.21%

Saving best model with accuracy: 93.59%

Epoch 47/100, Train Loss: 0.0172, Train Acc: 99.45%, Test Loss: 0.2481, Test Acc: 93.59%

Epoch 48/100, Train Loss: 0.0164, Train Acc: 99.47%, Test Loss: 0.2558, Test Acc: 93.43%

Epoch 49/100, Train Loss: 0.0166, Train Acc: 99.43%, Test Loss: 0.2684, Test Acc: 93.39%

Saving best model with accuracy: 93.78%

Epoch 50/100, Train Loss: 0.0164, Train Acc: 99.45%, Test Loss: 0.2503, Test Acc: 93.78%

Epoch 51/100, Train Loss: 0.0157, Train Acc: 99.50%, Test Loss: 0.2619, Test Acc: 93.63%

Epoch 52/100, Train Loss: 0.0143, Train Acc: 99.54%, Test Loss: 0.2692, Test Acc: 93.52%

Epoch 53/100, Train Loss: 0.0129, Train Acc: 99.62%, Test Loss: 0.2634, Test Acc: 93.35%
