VGG16 Transfer Learning in Practice: An Optimized CIFAR-10 Image Classification Scheme

Introduction

VGGNet, the runner-up of the 2014 ILSVRC competition, has become a classic deep learning model thanks to its simple design philosophy (small convolution kernels stacked deeply) and strong feature-extraction ability. This article uses the PyTorch framework, combines transfer learning with several optimization strategies, and applies VGG16 to classify the CIFAR-10 dataset, achieving relatively high accuracy while keeping training efficient.

I. VGG16 Theoretical Foundations

1. Core Design Philosophy

The core design idea of VGGNet is to replace large convolution kernels with stacks of small 3×3 kernels and to improve performance by increasing network depth. This design has the following advantages:

  • Higher parameter efficiency: three stacked 3×3 convolutions cover the same receptive field as a single 7×7 convolution, but with fewer parameters (3×(3×3×C²) < 7×7×C²); a quick numerical check follows this list
  • Stronger feature expressiveness: the additional non-linear activations (ReLU) between the stacked convolutions increase the network's representational power
  • More flexible receptive fields: deeply stacked small kernels can learn more complex feature hierarchies
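
The claim in the first bullet is easy to verify numerically. The snippet below is a minimal check (C = 256 is chosen purely for illustration and is not a value used elsewhere in this article); bias terms are ignored:

python
# Parameter count: three stacked 3×3 convs vs. one 7×7 conv, with C input and C output channels
C = 256  # illustrative channel count
params_three_3x3 = 3 * (3 * 3 * C * C)  # 3 × (3×3×C²)
params_one_7x7 = 7 * 7 * C * C          # 7×7×C²
print(params_three_3x3, params_one_7x7)  # 1769472 vs. 3211264 → roughly 45% fewer parameters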

2. VGG16 Network Architecture

VGG16 contains 16 trainable layers (13 convolutional layers + 3 fully connected layers).
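
This count can be checked directly against torchvision's implementation. The short inspection sketch below (assuming torchvision is installed; it builds the architecture without downloading pretrained weights) prints the module indices that the freezing strategy in the next subsection refers to:

python
from torchvision import models
import torch.nn as nn

vgg = models.vgg16(weights=None)  # architecture only, no pretrained weights downloaded
conv_idx = [i for i, m in enumerate(vgg.features) if isinstance(m, nn.Conv2d)]
fc_idx = [i for i, m in enumerate(vgg.classifier) if isinstance(m, nn.Linear)]
print(len(conv_idx), "conv layers at features indices:", conv_idx)
print(len(fc_idx), "fully connected layers at classifier indices:", fc_idx)
# expected: 13 conv layers (indices 0 through 28 of features) and 3 linear layers (classifier indices 0, 3, 6)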

3. Transfer Learning Strategy

For the CIFAR-10 dataset, we adopt a partial-unfreezing transfer learning strategy:

  • Freeze the first 24 modules of model.features (covering most of the convolutional layers), preserving the pretrained feature-extraction capability
  • Unfreeze feature modules 24-29 (the last convolutional block, containing the final three conv layers), allowing the model to fine-tune to CIFAR-10 features; the exact modules are listed in the sketch after this list
  • Replace the final fully connected layer, changing the output size from 1000 ImageNet classes to 10
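
To see exactly what the unfreezing touches, the following inspection sketch (illustrative only, reusing the vgg object from the previous snippet) lists the modules from index 24 onward; only the Conv2d entries carry trainable parameters:

python
# Which modules does "unfreeze from index 24 onward" actually touch?
for i, m in enumerate(vgg.features[24:], start=24):
    print(i, m.__class__.__name__)
# → 24 Conv2d, 25 ReLU, 26 Conv2d, 27 ReLU, 28 Conv2d, 29 ReLU, 30 MaxPool2d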

II. Experimental Configuration

1. Core Configuration Parameters

python
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 10  # balances memory use and speed
EPOCHS = 8  # enough for fast convergence
LEARNING_RATE = 4e-4  # a learning rate suited to transfer learning
NUM_CLASSES = 10  # number of CIFAR-10 classes

III. Code Implementation and Optimizations

1. Data Preprocessing Optimizations

Tailored to the characteristics of CIFAR-10, we designed an efficient data preprocessing pipeline:

python
# CIFAR-10-specific normalization statistics
cifar_mean = [0.4914, 0.4822, 0.4465]
cifar_std = [0.2023, 0.1994, 0.2010]

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # keep the core augmentation for better generalization
    transforms.Resize((144, 144)),  # optimization 1: 144×144 input (speed/accuracy trade-off)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std)
])

Optimization notes

  • 144×144 input size: compared with the 224×224 resolution the ImageNet-pretrained network normally expects, this reduces computation considerably while retaining enough detail for feature extraction
  • Core augmentation retained: random cropping and horizontal flipping effectively reduce overfitting
  • CIFAR-10-specific normalization: these statistics fit the dataset better than the ImageNet normalization constants (the matching test-time transform is shown below)
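
For reference, the matching test-time pipeline (repeated in the complete listing at the end of the article) applies only resizing and normalization, with no augmentation:

python
transform_test = transforms.Compose([
    transforms.Resize((144, 144)),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std)
])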

2. Model Construction and Freezing Strategy

python
model = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

# optimization 2: partial-unfreezing strategy
for param in model.features[:24].parameters():
    param.requires_grad = False  # freeze the first 24 feature modules
for param in model.features[24:].parameters():
    param.requires_grad = True  # unfreeze the last conv block (the final 3 conv layers)

# optimization 3: keep Dropout to curb overfitting
in_features = model.classifier[6].in_features
model.classifier[6] = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(in_features, NUM_CLASSES)
)
model = model.to(DEVICE)

Optimization notes

  • Partial unfreezing: only the last convolutional block (features modules 24-29) is unfrozen, balancing feature fine-tuning against training speed (see the parameter check after this list)
  • Dropout retained: a Dropout(0.5) placed before the new classification layer helps prevent overfitting over multiple epochs
  • Pretrained weights: loading the ImageNet-pretrained weights speeds up convergence
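
As a quick sanity check on this strategy, the snippet below (an illustrative sketch to be run right after the model-construction code above) reports how many parameters actually remain trainable. Note that because only model.features is frozen, the fully connected classifier layers stay trainable as well:

python
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable: {trainable / 1e6:.1f}M of {total / 1e6:.1f}M parameters "
      f"({100 * trainable / total:.1f}%)")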

3. Optimizer and Learning-Rate Scheduling

python
# optimization 4: AdamW optimizer with weight decay
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LEARNING_RATE, weight_decay=8e-5  # weight decay suppresses overfitting
)

# optimization 5: dynamic learning-rate scheduling
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=2, factor=0.5
)

Optimization notes

  • AdamW optimizer: unlike plain Adam, it applies decoupled weight decay, which regularizes fine-tuning more reliably
  • Dynamic learning rate: when test accuracy stops improving, the learning rate is automatically halved, helping the final stages of convergence
  • Only trainable parameters are optimized: the filter call passes just the unfrozen parameters to the optimizer, reducing bookkeeping (a small check follows this list)
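
A small check (illustrative only, run after the optimizer is created) confirms that the filter call hands the optimizer exactly the tensors left trainable by the freezing step:

python
n_trainable = sum(1 for p in model.parameters() if p.requires_grad)
n_in_optimizer = sum(len(group['params']) for group in optimizer.param_groups)
print(n_trainable, n_in_optimizer)  # the two counts should be equal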

4. Training-Loop Optimizations

python
def train_model(model, train_loader, test_loader, criterion, optimizer, scheduler, epochs):
    train_losses = []
    test_accuracies = []
    best_acc = 0.0
    batch_print_interval = 600  # optimization 6: print less often to cut CPU I/O overhead

    for epoch in range(epochs):
        start_time = time.time()
        model.train()
        running_loss = 0.0

        # training phase
        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(DEVICE), targets.to(DEVICE)

            # forward + backward pass
            outputs = model(data)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * data.size(0)

            # batch-level logging (once every 600 batches)
            if batch_idx % batch_print_interval == 0 and batch_idx != 0:
                print(f'  Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

        # epoch-level training loss
        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)

        # validation phase (kept lean to avoid redundant computation)
        model.eval()
        test_correct = 0
        with torch.no_grad():
            for data, targets in test_loader:
                data, targets = data.to(DEVICE), targets.to(DEVICE)
                outputs = model(data)
                _, predicted = torch.max(outputs, 1)
                test_correct += (predicted == targets).sum().item()

        test_acc = 100 * test_correct / len(test_loader.dataset)
        test_accuracies.append(test_acc)
        epoch_time = time.time() - start_time

        # learning-rate scheduling
        scheduler.step(test_acc)

        # save only the best model to reduce disk I/O
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), "best_vgg16_cifar10_final.pth")

        print(f'Epoch [{epoch+1}/{epochs}] | Loss: {epoch_loss:.4f} | Test Acc: {test_acc:.2f}% | Time: {epoch_time:.1f}s')

    return train_losses, test_accuracies

Optimization notes

  • Lower logging frequency: printing once every 600 batches reduces CPU I/O overhead
  • Lean validation code: redundant computation is removed, speeding up evaluation
  • Best-model-only checkpointing: the model is written to disk only when accuracy improves, reducing I/O
  • Core metrics only: only the training loss and test accuracy are recorded, keeping the logs simple

5. Visualization Optimizations

python
def plot_training_curves(train_losses, test_accuracies):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))  # optimization 7: smaller figure, faster rendering

    # loss curve
    ax1.plot(range(1, EPOCHS+1), train_losses, 'b-', linewidth=2, label='Train Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('VGG16 Training Loss')
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    # accuracy curve
    ax2.plot(range(1, EPOCHS+1), test_accuracies, 'r-', linewidth=2, label='Test Acc')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('VGG16 Test Accuracy')
    ax2.grid(True, alpha=0.3)
    ax2.legend()

    plt.suptitle('VGG16 Training History (8 Epochs)', fontsize=14)
    plt.tight_layout()
    plt.savefig('vgg_training_curves_final.png', dpi=100)  # optimization 8: lower dpi for faster saving
    plt.show()

Optimization notes

  • Smaller figure size: reduced from 15×5 to 10×4, speeding up rendering
  • Lower save dpi: reduced from 300 to 100, shrinking the output file and speeding up saving
  • Simplified plot style: plain lines and titles improve readability

IV. Experimental Results and Analysis

1. Dataset Samples

2. Training Curves

3. Classification Results

Complete Code

python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torchvision.models import VGG16_Weights
from torch.utils.data import DataLoader
import time
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# ---------------------- 1. Core configuration ----------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 10  # balances memory use and speed; light enough for CPU
EPOCHS = 8
LEARNING_RATE = 4e-4
NUM_CLASSES = 10
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# ---------------------- 2. Data preprocessing ----------------------
cifar_mean = [0.4914, 0.4822, 0.4465]
cifar_std = [0.2023, 0.1994, 0.2010]

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # keep the core augmentation (better generalization, fewer misclassifications)
    transforms.Resize((144, 144)),  # core optimization: 144×144 (speed/accuracy trade-off)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std)
])

transform_test = transforms.Compose([
    transforms.Resize((144, 144)),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std)
])

# ---------------------- 3. Data loading ----------------------
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=0, pin_memory=False  # on CPU, disable worker processes and pinned memory
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=0, pin_memory=False
)


# ---------------------- 4. Dataset visualization ----------------------
def show_dataset_samples():
    data_iter = iter(train_loader)
    images, labels = next(data_iter)

    fig, axes = plt.subplots(2, 5, figsize=(10, 4))  # compact layout to cut plotting time
    fig.suptitle('CIFAR-10 Dataset Samples (144x144)', fontsize=14, fontweight='bold')

    for i in range(10):
        row, col = i // 5, i % 5
        img = images[i].numpy().transpose((1, 2, 0))
        img = img * cifar_std + cifar_mean
        img = np.clip(img, 0, 1)

        axes[row, col].imshow(img)
        axes[row, col].set_title(classes[labels[i]], fontsize=8)
        axes[row, col].axis('off')

    plt.tight_layout()
    plt.savefig('vgg_dataset_samples_fast.png', dpi=100)  # lower dpi for faster saving
    plt.show()


print("展示数据集样本...")
show_dataset_samples()

# ---------------------- 5. VGG16 model ----------------------
model = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

# freezing strategy: unfreeze the last conv block (features modules 24-29), balancing fine-tuning and training speed
for param in model.features[:24].parameters():
    param.requires_grad = False
for param in model.features[24:].parameters():
    param.requires_grad = True

# keep dropout + adapt the head to 10 classes (prevents overfitting over many epochs)
in_features = model.classifier[6].in_features
model.classifier[6] = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(in_features, NUM_CLASSES)
)
model = model.to(DEVICE)

# ---------------------- 6. Optimizer ----------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LEARNING_RATE, weight_decay=8e-5  # weight decay suppresses overfitting
)

# learning-rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=2, factor=0.5
)


# ---------------------- 7. Training function ----------------------
def train_model(model, train_loader, test_loader, criterion, optimizer, scheduler, epochs):
    train_losses = []
    test_accuracies = []
    best_acc = 0.0
    prev_lr = optimizer.param_groups[0]['lr']
    batch_print_interval = 600  # print less often to cut CPU I/O overhead

    for epoch in range(epochs):
        start_time = time.time()
        model.train()
        running_loss = 0.0

        # training phase (lean loop, no redundant computation)
        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(DEVICE), targets.to(DEVICE)

            # forward + backward pass
            outputs = model(data)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * data.size(0)

            # batch-level logging (once every 600 batches)
            if batch_idx % batch_print_interval == 0 and batch_idx != 0:
                print(f'  Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

        # epoch-level training loss
        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)

        # validation phase
        model.eval()
        test_correct = 0
        with torch.no_grad():
            for data, targets in test_loader:
                data, targets = data.to(DEVICE), targets.to(DEVICE)
                outputs = model(data)
                _, predicted = torch.max(outputs, 1)
                test_correct += (predicted == targets).sum().item()

        test_acc = 100 * test_correct / len(test_loader.dataset)
        test_accuracies.append(test_acc)
        epoch_time = time.time() - start_time

        # learning-rate scheduling
        scheduler.step(test_acc)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr != prev_lr:
            print(f'  Learning rate adjusted: {prev_lr:.6f} → {current_lr:.6f}')
            prev_lr = current_lr

        # epoch summary
        print(
            f'Epoch [{epoch + 1}/{epochs}] | Loss: {epoch_loss:.4f} | Test Acc: {test_acc:.2f}% | Time: {epoch_time:.1f}s')

        # save only the best model to reduce disk I/O
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), "best_vgg16_cifar10_final.pth")
            print(f'  Best model saved! Acc: {best_acc:.2f}%')

    return train_losses, test_accuracies


# ---------------------- 8. Training-curve visualization ----------------------
def plot_training_curves(train_losses, test_accuracies):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))  # smaller figure, faster rendering

    # loss curve
    ax1.plot(range(1, EPOCHS + 1), train_losses, 'b-', linewidth=2, label='Train Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('VGG16 Training Loss')
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    # accuracy curve
    ax2.plot(range(1, EPOCHS + 1), test_accuracies, 'r-', linewidth=2, label='Test Acc')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('VGG16 Test Accuracy')
    ax2.grid(True, alpha=0.3)
    ax2.legend()

    plt.suptitle('VGG16 Training History (8 Epochs)', fontsize=14)
    plt.tight_layout()
    plt.savefig('vgg_training_curves_final.png', dpi=100)
    plt.show()


# ---------------------- 9. Classification results (verify predictions) ----------------------
def show_classification_results(model, test_loader):
    model.eval()
    images, labels = next(iter(test_loader))
    images = images.to(DEVICE)

    with torch.no_grad():
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)

    fig, axes = plt.subplots(2, 5, figsize=(12, 5))  # show 10 images for a broad check
    fig.suptitle('VGG16 Classification Results (8 Epochs)', fontsize=14, fontweight='bold')

    for i in range(10):
        row, col = i // 5, i % 5
        img = images[i].cpu().numpy().transpose((1, 2, 0))
        img = img * cifar_std + cifar_mean
        img = np.clip(img, 0, 1)

        true_label = classes[labels[i]]
        pred_label = classes[predictions[i]]
        color = 'green' if true_label == pred_label else 'red'
        axes[row, col].set_title(f'True: {true_label}\nPred: {pred_label}', color=color, fontsize=9)
        axes[row, col].imshow(img)
        axes[row, col].axis('off')

    plt.tight_layout()
    plt.savefig('vgg_classification_results_final.png', dpi=100)
    plt.show()

    # accuracy on this sample batch (quick overall check)
    batch_acc = 100 * (predictions.cpu() == labels).sum().item() / len(labels)
    print(f'Sample batch accuracy: {batch_acc:.2f}%')


# ---------------------- 10. Main training flow ----------------------
if __name__ == "__main__":
    print("=" * 60)
    print(f"VGG16最终训练启动 | Epochs={EPOCHS}, Batch={BATCH_SIZE}, LR={LEARNING_RATE}")
    print(f"训练设备: {DEVICE} | 输入尺寸: 144x144 | 解冻6层卷积")
    print("=" * 60)

    train_losses, test_accuracies = train_model(
        model, train_loader, test_loader, criterion, optimizer, scheduler, EPOCHS
    )

    # report key results
    final_acc = test_accuracies[-1]
    best_acc = max(test_accuracies)
    print(f"\n训练完成! 最终准确率: {final_acc:.2f}% | 最佳准确率: {best_acc:.2f}%")

    # core visualizations (verify predictions)
    plot_training_curves(train_losses, test_accuracies)
    show_classification_results(model, test_loader)


# ---------------------- 11. Single-image prediction ----------------------
def predict_single_image(image_path="test_image.jpg"):
    # build the model and load the best checkpoint
    model = models.vgg16(weights=None)
    in_features = model.classifier[6].in_features
    model.classifier[6] = nn.Sequential(nn.Dropout(0.5), nn.Linear(in_features, NUM_CLASSES))
    model.load_state_dict(torch.load("best_vgg16_cifar10_final.pth", map_location=DEVICE))
    model.to(DEVICE).eval()

    # preprocessing to match the 144x144 input
    transform = transforms.Compose([
        transforms.Resize((144, 144)),
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar_mean, std=cifar_std)
    ])

    try:
        image = Image.open(image_path).convert("RGB")
        plt.figure(figsize=(5, 5))
        plt.imshow(image)
        plt.axis('off')
        plt.title("Input Image")
        plt.show()

        with torch.no_grad():
            img_tensor = transform(image).unsqueeze(0).to(DEVICE)
            outputs = model(img_tensor)
            probs = torch.softmax(outputs, dim=1)
            conf, pred = torch.max(probs, 1)

        # print detailed output to verify the prediction
        print(f"\nVGG16 prediction: {classes[pred.item()]}")
        print(f"Confidence: {conf.item():.4f}")
        print("\nTop-3 class probabilities:")
        top3_idx = torch.topk(probs, 3)[1].cpu().numpy()[0]
        for idx in top3_idx:
            print(f"  {classes[idx]:12s}: {probs[0][idx].item():.4f}")
    except FileNotFoundError:
        print(f"错误:未找到图像文件 {image_path}")


# example call (run after training has finished)
predict_single_image("test_image.jpg")

Training this model still takes a fairly long time; to speed it up, try reducing the input size or unfreezing fewer convolutional layers, as sketched below.
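
For example, a lighter configuration could look like the sketch below; the names INPUT_SIZE and UNFREEZE_FROM and the specific values are illustrative only and were not benchmarked in this article:

python
INPUT_SIZE = 112      # hypothetical smaller input instead of 144×144
UNFREEZE_FROM = 28    # hypothetical: unfreeze only the very last conv layer of features

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std)
])

for param in model.features[:UNFREEZE_FROM].parameters():
    param.requires_grad = False
for param in model.features[UNFREEZE_FROM:].parameters():
    param.requires_grad = True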
