CIFAR-10 Classification with a Pretrained ResNet50 Model

I. Introduction to the ResNet Pretrained Model

ResNet is one of the best-known pretrained model families in computer vision (CV), and ResNet50, proposed by Kaiming He et al. in 2015, is among the most influential deep-learning architectures in the field. It addresses the degradation problem of deep neural networks and makes training very deep networks feasible. ResNet50's architecture and design ideas are described in detail below.

  • Plain network: directly learns the target mapping H(x)
  • Residual network: learns the residual mapping F(x) = H(x) - x, so the final output is F(x) + x

This design lets gradients flow directly through the identity-mapping path, effectively alleviating the vanishing-gradient problem and making it possible to train networks with over a hundred layers.
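As a concrete illustration, here is a minimal sketch of a residual block in PyTorch. It is not the exact block used inside ResNet50 (the two 3×3 convolutions and the placement of the activations are simplifying assumptions); it only shows the identity shortcut at work.

import torch
import torch.nn as nn

class BasicResidualBlock(nn.Module):
    """Minimal residual block: output = F(x) + x, with F built from two 3x3 convolutions."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x                                 # skip (identity) path
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))              # residual function F(x)
        return self.relu(out + identity)             # element-wise addition requires matching shapes

# Shape check: the block preserves the input shape, e.g.
# BasicResidualBlock(64)(torch.randn(1, 64, 32, 32)).shape == torch.Size([1, 64, 32, 32])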

ResNet50 has 50 weight layers in total and is composed of the following parts (a small inspection sketch follows the list):

  1. Stem: a 7×7 convolution followed by max pooling
  2. Four residual stages, each containing several residual blocks
  3. A global average pooling layer
  4. A 1000-class classifier (the original version targets ImageNet)
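These parts can be listed directly from torchvision's implementation; the sketch below assumes a recent torchvision (>= 0.13) and only prints the top-level modules (conv1, bn1, relu, maxpool, layer1 through layer4, avgpool, fc).

import torchvision

# Build a ResNet50 without pretrained weights and list its top-level modules.
model = torchvision.models.resnet50(weights=None)
for name, module in model.named_children():
    print(name, "->", module.__class__.__name__)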

ResNet50 uses the Bottleneck block as its basic unit. Each Bottleneck contains three convolution layers:

  1. 1×1 convolution: reduces the channel dimension to cut computation
  2. 3×3 convolution: extracts spatial features
  3. 1×1 convolution: expands the channel dimension back

The residual connection is written as y = F(x, {W_i}) + x (a minimal Bottleneck sketch follows this list), where:

  • F(x) is the residual function
  • x is the input
  • y is the output
  • the "+" is element-wise addition, which requires F(x) and x to have the same dimensions

II. Implementing CIFAR-10 Classification with ResNet50

The complete implementation of CIFAR-10 classification with ResNet50 is given below, step by step.

1. Imports and loading the pretrained model

# 1. Download and import the pretrained model
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
import numpy as np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

# Select the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

2. Downloading and preprocessing the dataset

  • Training set: random crop, random horizontal flip, normalization
  • Test set: normalization only
  • The mean (0.4914, 0.4822, 0.4465) and standard deviation (0.2023, 0.1994, 0.2010) are statistics of the CIFAR-10 dataset (a sketch after the code below shows how such statistics can be computed)

# 2. Download and preprocess the dataset
def prepare_data():
    """Download and preprocess the CIFAR-10 dataset."""
    # Training-set data augmentation
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    # Test-set preprocessing
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    # Load the datasets
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2)
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')
    
    return trainloader, testloader, classes
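For reference, here is a rough sketch of how per-channel statistics like these can be computed from the raw training images. Note that the standard deviation obtained this way (computed globally over all pixels) comes out slightly higher than the (0.2023, 0.1994, 0.2010) values quoted above, which stem from a different computation convention; both are widely used for CIFAR-10 normalization.

import numpy as np
import torchvision

def compute_channel_stats():
    """Sketch: per-channel mean/std of the CIFAR-10 training images scaled to [0, 1]."""
    dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
    data = dataset.data.astype(np.float32) / 255.0   # numpy array of shape (50000, 32, 32, 3)
    mean = data.mean(axis=(0, 1, 2))   # per-channel mean over every pixel
    std = data.std(axis=(0, 1, 2))     # per-channel std over every pixel
    return mean, std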

3. Building the network model

# 3. Build the network model
def build_model():
    """Build a ResNet50 model adapted to CIFAR-10."""
    # Load the pretrained weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    
    # Freeze all pretrained layers
    for param in model.parameters():
        param.requires_grad = False
    
    # Adapt the architecture to 32x32 input images:
    # replace the first 7x7 convolution with a 3x3 one to shrink the receptive field
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # and remove the initial max pooling so the small feature maps are not over-downsampled
    model.maxpool = nn.Identity()
    
    # Replace the fully connected layer with a 10-class output head
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 10)
    
    # Unfreeze the layers that will be fine-tuned
    for param in model.conv1.parameters():
        param.requires_grad = True
    for param in model.layer4.parameters():
        param.requires_grad = True
    for param in model.fc.parameters():
        param.requires_grad = True
    
    model = model.to(device)
    return model
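To confirm that only conv1, layer4, and fc are left trainable, a small helper (hypothetical, not part of the original script) can count trainable versus frozen parameters:

def count_parameters(model):
    """Print how many parameters will be updated vs. kept frozen."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f"Trainable parameters: {trainable:,}")
    print(f"Frozen parameters:    {frozen:,}")

# Example: count_parameters(build_model())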

4. Writing the training function

# 4. Training function
def train(model, trainloader, criterion, optimizer, epoch):
    """Train the model for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    # Progress bar over the training batches
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}", leave=False)
    for batch_idx, (inputs, targets) in enumerate(progress_bar):
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backward pass and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Accumulate statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        # Update the progress bar
        progress_bar.set_postfix(loss=running_loss/(batch_idx+1), acc=100.*correct/total)
    
    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

5. Writing the evaluation function

# 5. Evaluation function
def test(model, testloader, criterion):
    """Evaluate the model on the test set."""
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    test_loss = test_loss / len(testloader)
    test_acc = 100. * correct / total
    return test_loss, test_acc
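If a per-class breakdown is wanted on top of the overall accuracy, the evaluation loop can be extended along the following lines (an illustrative sketch, not part of the original script; it reuses the classes tuple returned by prepare_data()):

def per_class_accuracy(model, testloader, classes):
    """Sketch: report accuracy separately for each CIFAR-10 class."""
    model.eval()
    correct = torch.zeros(len(classes))
    total = torch.zeros(len(classes))
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            for label, pred in zip(targets.cpu(), predicted.cpu()):
                total[label] += 1
                correct[label] += int(label == pred)
    for name, c, t in zip(classes, correct.tolist(), total.tolist()):
        print(f"{name:>6s}: {100.0 * c / t:.2f}%")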

6. Writing the main function

# 6. Main entry point
def main():
    """Main function: orchestrate the training and evaluation pipeline."""
    # Prepare the data
    print("Preparing data...")
    trainloader, testloader, classes = prepare_data()
    
    # Build the model
    print("Building model...")
    model = build_model()
    
    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    # Only optimize the parameters that require gradients
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 
                         lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    
    # Training hyperparameters
    num_epochs = 100
    best_acc = 0
    
    # Track the training history
    train_losses = []
    train_accs = []
    test_losses = []
    test_accs = []
    
    print("Starting training...")
    start_time = time.time()
    
    # Training loop
    for epoch in range(num_epochs):
        # Train
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, epoch)
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        
        # Evaluate
        test_loss, test_acc = test(model, testloader, criterion)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        
        # Update the learning rate
        scheduler.step()
        
        # Save the best model
        if test_acc > best_acc:
            print(f'Saving best model with accuracy: {test_acc:.2f}%')
            best_acc = test_acc
            torch.save(model.state_dict(), 'resnet50_cifar10_best.pth')
        
        # Print the epoch results
        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
    
    total_time = time.time() - start_time
    print(f'Training completed in {total_time//60:.0f}m {total_time%60:.0f}s')
    print(f'Best test accuracy: {best_acc:.2f}%')
    
    # Save the final model
    torch.save(model.state_dict(), 'resnet50_cifar10_final.pth')
    
    # Plot the training curves
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(test_accs, label='Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    
    plt.savefig('training_results.png')
    plt.close()

if __name__ == '__main__':
    main()
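After training, the saved checkpoint can be reloaded for inference. A minimal sketch (the file name matches what main() saves; everything else here is illustrative):

def load_best_model():
    """Sketch: rebuild the adapted architecture and load the best saved weights."""
    model = build_model()
    model.load_state_dict(torch.load('resnet50_cifar10_best.pth', map_location=device))
    model.eval()
    return model

# Example: classify one batch of test images
# _, testloader, classes = prepare_data()
# images, labels = next(iter(testloader))
# preds = load_best_model()(images.to(device)).argmax(dim=1)
# print("predicted:", classes[preds[0]], "| ground truth:", classes[labels[0]])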

7. Run results

(mlstat) [haichao@node01 cifar]$ python demo1.py

Using device: cpu

Preparing data...

100%|██████████| 170M/170M [1:09:08<00:00, 41.1kB/s]

Building model...

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /home/haichao/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth

100%|██████████| 97.8M/97.8M [01:30<00:00, 1.13MB/s]

Starting training...

Saving best model with accuracy: 76.80%

Epoch 1/100, Train Loss: 1.1939, Train Acc: 57.72%, Test Loss: 0.6765, Test Acc: 76.80%

Saving best model with accuracy: 83.43%

Epoch 2/100, Train Loss: 0.5763, Train Acc: 80.08%, Test Loss: 0.4767, Test Acc: 83.43%

Saving best model with accuracy: 87.68%

Epoch 3/100, Train Loss: 0.4238, Train Acc: 85.34%, Test Loss: 0.3579, Test Acc: 87.68%

Saving best model with accuracy: 88.45%

Epoch 4/100, Train Loss: 0.3421, Train Acc: 88.15%, Test Loss: 0.3430, Test Acc: 88.45%

Saving best model with accuracy: 89.71%

Epoch 5/100, Train Loss: 0.2856, Train Acc: 90.00%, Test Loss: 0.2943, Test Acc: 89.71%

Saving best model with accuracy: 90.81%

Epoch 6/100, Train Loss: 0.2477, Train Acc: 91.43%, Test Loss: 0.2668, Test Acc: 90.81%

Saving best model with accuracy: 91.19%

Epoch 7/100, Train Loss: 0.2098, Train Acc: 92.81%, Test Loss: 0.2614, Test Acc: 91.19%

Saving best model with accuracy: 91.64%

Epoch 8/100, Train Loss: 0.1903, Train Acc: 93.34%, Test Loss: 0.2588, Test Acc: 91.64%

Saving best model with accuracy: 91.67%

Epoch 9/100, Train Loss: 0.1713, Train Acc: 94.04%, Test Loss: 0.2498, Test Acc: 91.67%

Saving best model with accuracy: 91.74%

Epoch 10/100, Train Loss: 0.1534, Train Acc: 94.66%, Test Loss: 0.2517, Test Acc: 91.74%

Saving best model with accuracy: 91.88%

Epoch 11/100, Train Loss: 0.1398, Train Acc: 95.09%, Test Loss: 0.2474, Test Acc: 91.88%

Epoch 12/100, Train Loss: 0.1248, Train Acc: 95.49%, Test Loss: 0.2576, Test Acc: 91.78%

Saving best model with accuracy: 92.38%

Epoch 13/100, Train Loss: 0.1128, Train Acc: 95.99%, Test Loss: 0.2469, Test Acc: 92.38%

Epoch 14/100, Train Loss: 0.1031, Train Acc: 96.38%, Test Loss: 0.2492, Test Acc: 92.28%

Saving best model with accuracy: 92.42%

Epoch 15/100, Train Loss: 0.0950, Train Acc: 96.62%, Test Loss: 0.2462, Test Acc: 92.42%

Epoch 16/100, Train Loss: 0.0830, Train Acc: 97.09%, Test Loss: 0.2650, Test Acc: 92.34%

Epoch 17/100, Train Loss: 0.0807, Train Acc: 97.17%, Test Loss: 0.2655, Test Acc: 92.25%

Epoch 18/100, Train Loss: 0.0707, Train Acc: 97.52%, Test Loss: 0.2650, Test Acc: 92.08%

Epoch 19/100, Train Loss: 0.0719, Train Acc: 97.50%, Test Loss: 0.2786, Test Acc: 92.13%

Epoch 20/100, Train Loss: 0.0644, Train Acc: 97.78%, Test Loss: 0.2768, Test Acc: 91.91%

Epoch 21/100, Train Loss: 0.0604, Train Acc: 97.92%, Test Loss: 0.2606, Test Acc: 92.33%

Saving best model with accuracy: 92.43%

Epoch 22/100, Train Loss: 0.0576, Train Acc: 98.00%, Test Loss: 0.2669, Test Acc: 92.43%

Saving best model with accuracy: 92.52%

Epoch 23/100, Train Loss: 0.0652, Train Acc: 97.70%, Test Loss: 0.2641, Test Acc: 92.52%

Saving best model with accuracy: 92.84%

Epoch 24/100, Train Loss: 0.0520, Train Acc: 98.23%, Test Loss: 0.2680, Test Acc: 92.84%

Epoch 25/100, Train Loss: 0.0501, Train Acc: 98.33%, Test Loss: 0.2845, Test Acc: 92.15%

Epoch 26/100, Train Loss: 0.0443, Train Acc: 98.45%, Test Loss: 0.2744, Test Acc: 92.62%

Epoch 27/100, Train Loss: 0.0530, Train Acc: 98.15%, Test Loss: 0.2827, Test Acc: 92.09%

Epoch 28/100, Train Loss: 0.0618, Train Acc: 97.83%, Test Loss: 0.2654, Test Acc: 92.58%

Epoch 29/100, Train Loss: 0.0470, Train Acc: 98.42%, Test Loss: 0.2585, Test Acc: 92.75%

Epoch 30/100, Train Loss: 0.0468, Train Acc: 98.43%, Test Loss: 0.2678, Test Acc: 92.33%

Epoch 31/100, Train Loss: 0.0455, Train Acc: 98.45%, Test Loss: 0.2765, Test Acc: 92.38%

Saving best model with accuracy: 93.07%

Epoch 32/100, Train Loss: 0.0383, Train Acc: 98.67%, Test Loss: 0.2709, Test Acc: 93.07%

Saving best model with accuracy: 93.13%

Epoch 33/100, Train Loss: 0.0341, Train Acc: 98.83%, Test Loss: 0.2599, Test Acc: 93.13%

Epoch 34/100, Train Loss: 0.0308, Train Acc: 98.96%, Test Loss: 0.2650, Test Acc: 92.95%

Epoch 35/100, Train Loss: 0.0302, Train Acc: 99.00%, Test Loss: 0.2663, Test Acc: 93.10%

Epoch 36/100, Train Loss: 0.0365, Train Acc: 98.78%, Test Loss: 0.2682, Test Acc: 93.03%

Epoch 37/100, Train Loss: 0.0319, Train Acc: 98.90%, Test Loss: 0.2826, Test Acc: 92.80%

Epoch 38/100, Train Loss: 0.0337, Train Acc: 98.89%, Test Loss: 0.2660, Test Acc: 93.06%

Epoch 39/100, Train Loss: 0.0322, Train Acc: 98.89%, Test Loss: 0.2696, Test Acc: 93.02%

Epoch 40/100, Train Loss: 0.0289, Train Acc: 99.05%, Test Loss: 0.2725, Test Acc: 92.81%

Epoch 41/100, Train Loss: 0.0274, Train Acc: 99.08%, Test Loss: 0.2560, Test Acc: 93.07%

Saving best model with accuracy: 93.41%

Epoch 42/100, Train Loss: 0.0244, Train Acc: 99.18%, Test Loss: 0.2690, Test Acc: 93.41%

Epoch 43/100, Train Loss: 0.0222, Train Acc: 99.21%, Test Loss: 0.2687, Test Acc: 93.25%

Epoch 44/100, Train Loss: 0.0216, Train Acc: 99.29%, Test Loss: 0.2630, Test Acc: 93.17%

Epoch 45/100, Train Loss: 0.0217, Train Acc: 99.29%, Test Loss: 0.2694, Test Acc: 93.35%

Epoch 46/100, Train Loss: 0.0214, Train Acc: 99.33%, Test Loss: 0.2632, Test Acc: 93.21%

Saving best model with accuracy: 93.59%

Epoch 47/100, Train Loss: 0.0172, Train Acc: 99.45%, Test Loss: 0.2481, Test Acc: 93.59%

Epoch 48/100, Train Loss: 0.0164, Train Acc: 99.47%, Test Loss: 0.2558, Test Acc: 93.43%


Epoch 49/100, Train Loss: 0.0166, Train Acc: 99.43%, Test Loss: 0.2684, Test Acc: 93.39%

Saving best model with accuracy: 93.78%

Epoch 50/100, Train Loss: 0.0164, Train Acc: 99.45%, Test Loss: 0.2503, Test Acc: 93.78%

Epoch 51/100, Train Loss: 0.0157, Train Acc: 99.50%, Test Loss: 0.2619, Test Acc: 93.63%

Epoch 52/100, Train Loss: 0.0143, Train Acc: 99.54%, Test Loss: 0.2692, Test Acc: 93.52%

Epoch 53/100, Train Loss: 0.0129, Train Acc: 99.62%, Test Loss: 0.2634, Test Acc: 93.35%
