AlexNet 迁移学习实战:CIFAR-10 图像分类实验

引言

2012 年,AlexNet 在 ImageNet 图像分类比赛中以显著优势夺冠,开启了深度学习在计算机视觉领域的黄金时代。AlexNet 通过更深的网络结构、ReLU 激活函数、Dropout 正则化等创新,大幅提升了图像分类精度,成为卷积神经网络(CNN)发展史上的里程碑。

一、AlexNet 理论基础

1. 网络结构

AlexNet 包含 8 个可训练层:5 个卷积层(Conv)和 3 个全连接层(FC)。AlexNet 结构图如下:

2. 核心创新点

  • ReLU 激活函数:解决了 Sigmoid 的梯度消失问题,加速训练
  • Dropout:随机失活神经元,减少过拟合
  • LRN(局部响应归一化):增强模型泛化能力
  • 数据增强:通过随机裁剪、翻转等方式扩充数据集
  • 多 GPU 训练:将网络分为两部分,分别在两个 GPU 上训练

3. 迁移学习原理

迁移学习是将预训练模型在大规模数据集(如 ImageNet)上学到的特征提取能力迁移到新任务中。对于 CIFAR-10 分类任务,可以:

  • 冻结 AlexNet 除最后一层全连接层外的全部层(5 个卷积层和前 2 个全连接层)
  • 替换最后一层全连接层,输出类别数从 1000 改为 10
  • 仅训练新的全连接层,快速适应新任务

二、实验环境

  • 框架:PyTorch 2.0+
  • 数据集:CIFAR-10(60000 张 32×32 彩色图像,10 个类别)
  • 设备:CPU
  • 主要库:torchvision、matplotlib、numpy、PIL

三、代码实现

1. 配置参数

python 复制代码
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16  # CPU optimization: smaller batch size
EPOCHS = 10       # CPU optimization: fewer training epochs
LEARNING_RATE = 0.0008
NUM_CLASSES = 10
# Class names, index-aligned with CIFAR-10 integer labels
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

2. 数据预处理

python 复制代码
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),  # AlexNet expects 224x224 input
    transforms.RandomHorizontalFlip(p=0.5),  # data augmentation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet mean
                         std=[0.229, 0.224, 0.225])   # ImageNet std
])

# Test pipeline: same resize/normalize, but no augmentation
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

3. 加载数据集

python 复制代码
# CIFAR-10 is downloaded to ./data on first run; train/test use different transforms
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# num_workers=0 on CPU: avoids multiprocess data-loading overhead
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0 if DEVICE.type == 'cpu' else 4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0 if DEVICE.type == 'cpu' else 4)

4. 数据集展示

python 复制代码
def show_dataset_samples(dataset, num_samples=16, title="CIFAR-10 Samples"):
    """Plot a 4x4 grid of randomly chosen, de-normalized dataset samples."""
    fig, axes = plt.subplots(4, 4, figsize=(12, 10))
    for i in range(num_samples):
        idx = np.random.randint(0, len(dataset))
        img, label = dataset[idx]
        img = img.permute(1, 2, 0).cpu().numpy()  # CHW tensor -> HWC array
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])  # undo ImageNet normalization
        img = np.clip(img, 0, 1)
        axes[i//4, i%4].imshow(img)
        axes[i//4, i%4].set_title(f"Class: {CLASSES[label]}")
        axes[i//4, i%4].axis('off')
    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()

5. 构建 AlexNet 模型

python 复制代码
def build_alexnet(num_classes=NUM_CLASSES):
    """Build an ImageNet-pretrained AlexNet with a frozen backbone and a new head."""
    model = models.alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)  # load pretrained weights
    for param in model.parameters():
        param.requires_grad = False  # freeze all pretrained layers
    in_features = model.classifier[6].in_features  # input size of the final FC layer
    model.classifier[6] = nn.Linear(in_features, num_classes)  # replace head; new layer is trainable by default
    return model.to(DEVICE)

6. 训练函数

python 复制代码
def train_model(model, train_loader, test_loader, criterion, optimizer, epochs):
    """Train `model` and evaluate on the test set after every epoch.

    Saves the best-accuracy checkpoint to 'best_alexnet_cifar10.pth' and
    returns per-epoch lists: (train losses, test losses, test accuracies %).
    """
    best_acc = 0.0
    train_loss_history = []
    test_loss_history = []
    test_acc_history = []

    for epoch in range(epochs):
        start_time = time.time()
        running_train_loss = 0.0

        # Training phase
        model.train()
        for data, targets in train_loader:
            data, targets = data.to(DEVICE), targets.to(DEVICE)
            outputs = model(data)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per-sample
            running_train_loss += loss.item() * data.size(0)

        # Evaluation phase
        model.eval()
        running_test_loss = 0.0
        test_correct = 0
        test_total = 0

        with torch.no_grad():
            for data, targets in test_loader:
                data, targets = data.to(DEVICE), targets.to(DEVICE)
                outputs = model(data)
                loss = criterion(outputs, targets)

                running_test_loss += loss.item() * data.size(0)
                _, predicted = torch.max(outputs.data, 1)
                test_total += targets.size(0)
                test_correct += (predicted == targets).sum().item()

        # Per-epoch metrics
        epoch_train_loss = running_train_loss / len(train_loader.dataset)
        epoch_test_loss = running_test_loss / len(test_loader.dataset)
        epoch_test_acc = 100 * test_correct / test_total

        # Save best model so far
        if epoch_test_acc > best_acc:
            best_acc = epoch_test_acc
            torch.save(model.state_dict(), "best_alexnet_cifar10.pth")

        # Record history
        train_loss_history.append(epoch_train_loss)
        test_loss_history.append(epoch_test_loss)
        test_acc_history.append(epoch_test_acc)

        print(f"Epoch [{epoch+1}/{epochs}] | Train Loss: {epoch_train_loss:.4f} | Test Loss: {epoch_test_loss:.4f} | Test Acc: {epoch_test_acc:.2f}% | Time: {time.time()-start_time:.1f}s")

    return train_loss_history, test_loss_history, test_acc_history

7. 训练过程可视化

python 复制代码
def plot_training_history(train_loss, test_loss, test_acc):
    """Draw side-by-side loss and accuracy curves over training epochs."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    # Left panel: train vs. test loss
    ax1.plot(range(1, len(train_loss)+1), train_loss, label='Train Loss', color='blue')
    ax1.plot(range(1, len(test_loss)+1), test_loss, label='Test Loss', color='red')
    ax1.set_title('Loss Curve')
    ax1.legend()

    # Right panel: test accuracy
    ax2.plot(range(1, len(test_acc)+1), test_acc, label='Test Accuracy', color='green')
    ax2.set_title('Accuracy Curve')
    ax2.legend()

    plt.suptitle('AlexNet Training History', fontsize=16)
    plt.tight_layout()
    plt.show()

8. 分类结果展示

python 复制代码
def show_classification_results(model, test_loader, num_samples=16):
    """Visualize model predictions on one test batch.

    Shows a 4x4 grid of de-normalized images; titles are green when the
    prediction matches the ground truth, red otherwise.
    """
    model.eval()
    fig, axes = plt.subplots(4, 4, figsize=(14, 12))
    # Inference only: no_grad avoids building the autograd graph
    # (the full program version already does this; the snippet did not)
    with torch.no_grad():
        data, targets = next(iter(test_loader))
        data, targets = data.to(DEVICE), targets.to(DEVICE)
        outputs = model(data)
        _, predicted = torch.max(outputs, 1)

    for i in range(num_samples):
        img = data[i].cpu().permute(1, 2, 0).numpy()  # CHW tensor -> HWC array
        # Undo ImageNet normalization for display
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img = np.clip(img, 0, 1)
        true_label = CLASSES[targets[i].item()]
        pred_label = CLASSES[predicted[i].item()]
        color = 'green' if true_label == pred_label else 'red'

        axes[i//4, i%4].imshow(img)
        axes[i//4, i%4].set_title(f"True: {true_label}\nPred: {pred_label}", color=color)
        axes[i//4, i%4].axis('off')

    plt.suptitle('Classification Results', fontsize=16)
    plt.tight_layout()
    plt.show()

9. 单张图像预测

python 复制代码
def predict_single_image(image_path, model):
    """Classify a single image file and display it with the predicted class.

    Prints a warning and returns (instead of crashing) when the file does
    not exist, consistent with the full program version of this function.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    try:
        image = Image.open(image_path).convert("RGB")
    except FileNotFoundError:
        print(f"Warning: File '{image_path}' not found. Skip single image prediction.")
        return
    img_tensor = transform(image).unsqueeze(0).to(DEVICE)  # add batch dimension
    model.eval()
    with torch.no_grad():
        output = model(img_tensor)
        _, pred = torch.max(output, 1)
    plt.figure(figsize=(6, 4))
    plt.imshow(image)
    plt.title(f"Predicted Class: {CLASSES[pred.item()]}", fontsize=14)
    plt.axis('off')
    plt.show()

10. 主函数

python 复制代码
if __name__ == "__main__":
    # Show dataset samples
    show_dataset_samples(train_dataset)

    # Build the model (frozen pretrained backbone + new classifier head)
    model = build_alexnet()

    # Loss and optimizer (only the new final layer's parameters are trained)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.classifier[6].parameters(), lr=LEARNING_RATE)

    # Train the model
    train_loss, test_loss, test_acc = train_model(model, train_loader, test_loader, criterion, optimizer, EPOCHS)

    # Visualize the training history
    plot_training_history(train_loss, test_loss, test_acc)

    # Show classification results
    show_classification_results(model, test_loader)

    # Predict a single image
    predict_single_image("test_image.jpg", model)

四、实验结果

1. 数据集样本

2. 训练历史曲线

3. 分类结果

五、分析与讨论

1. 迁移学习效果

使用预训练 AlexNet 仅训练分类层,在 10 轮训练后即达到 81.65% 的测试准确率,说明 ImageNet 预训练特征具有很强的通用性,能够快速迁移适应新任务。

2. CPU 优化效果

通过减小 batch size(16)、减少训练轮数(10)、关闭多进程加载(num_workers=0),代码在 CPU 上能够稳定运行,适合没有 GPU 资源的开发者进行实验。

3. 进一步优化方向

  • 微调卷积层:解冻最后 1-2 个卷积层,联合训练分类层,可能进一步提升准确率
  • 数据增强:添加随机旋转、裁剪、颜色抖动等增强手段,减少过拟合
  • 学习率调度:使用 StepLR 或 ReduceLROnPlateau 动态调整学习率
  • 模型集成:结合多个预训练模型(如 VGG、ResNet)的预测结果

六、完整代码

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torchvision.models import AlexNet_Weights
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import time
import warnings

warnings.filterwarnings('ignore')

# ---------------------- 1. Configuration ----------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16  # CPU optimization: smaller batches run faster per step
EPOCHS = 10  # CPU optimization: fewer epochs for a quick run
LEARNING_RATE = 0.0008
NUM_CLASSES = 10
# Class names, index-aligned with CIFAR-10 integer labels
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# ---------------------- 2. Data preprocessing ----------------------
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),  # AlexNet expects 224x224 input
    transforms.RandomHorizontalFlip(p=0.5),  # augmentation: random horizontal flip
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet mean/std
                         std=[0.229, 0.224, 0.225])
])

# Test pipeline: same resize/normalize, but no augmentation
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# ---------------------- 3. Load datasets ----------------------
# CIFAR-10 is downloaded to ./data on first run
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train
)

test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test
)

# num_workers=0 on CPU: avoids multiprocess data-loading overhead
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0 if DEVICE.type == 'cpu' else 4
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0 if DEVICE.type == 'cpu' else 4
)


# ---------------------- 4. 数据集展示函数----------------------
def show_dataset_samples(dataset, num_samples=16, title="CIFAR-10 Training Dataset Samples"):
    """Plot a 4x4 grid of randomly drawn, de-normalized dataset samples."""
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    fig, axes = plt.subplots(4, 4, figsize=(12, 10))
    flat_axes = axes.flatten()

    for cell in range(num_samples):
        sample_idx = np.random.randint(0, len(dataset))
        tensor_img, label = dataset[sample_idx]

        # Tensor (CHW) -> NumPy (HWC), then undo ImageNet normalization
        pixels = tensor_img.permute(1, 2, 0).cpu().numpy()
        pixels = np.clip(pixels * std + mean, 0, 1)

        panel = flat_axes[cell]
        panel.imshow(pixels)
        panel.set_title(f"Class: {CLASSES[label]}")
        panel.axis('off')

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()


# ---------------------- 5. 构建AlexNet模型----------------------
def build_alexnet(num_classes=NUM_CLASSES):
    """Load ImageNet-pretrained AlexNet, freeze it, and swap in a new head.

    Only the replaced final fully-connected layer is trainable, so the
    optimizer updates just the classifier head.
    """
    net = models.alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
    # Freeze every pretrained parameter
    for p in net.parameters():
        p.requires_grad = False
    # Replace the 1000-way ImageNet head with a num_classes-way layer
    # (a freshly constructed nn.Linear has requires_grad=True by default)
    net.classifier[6] = nn.Linear(net.classifier[6].in_features, num_classes)
    return net.to(DEVICE)


# ---------------------- 6. 训练函数 ----------------------
def train_model(model, train_loader, test_loader, criterion, optimizer, epochs):
    """Train `model` and evaluate on the test set after every epoch.

    Saves the state_dict with the best test accuracy so far to
    'best_alexnet_cifar10.pth' and returns three per-epoch lists:
    (train losses, test losses, test accuracies in percent).
    """
    model.train()
    best_acc = 0.0
    train_loss_history = []
    test_loss_history = []
    test_acc_history = []

    for epoch in range(epochs):
        start_time = time.time()
        running_train_loss = 0.0

        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(DEVICE), targets.to(DEVICE)
            outputs = model(data)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per-sample
            running_train_loss += loss.item() * data.size(0)

        epoch_train_loss = running_train_loss / len(train_loader.dataset)
        train_loss_history.append(epoch_train_loss)

        # ---- evaluation phase ----
        model.eval()
        running_test_loss = 0.0
        test_correct = 0
        test_total = 0

        with torch.no_grad():
            for data, targets in test_loader:
                data, targets = data.to(DEVICE), targets.to(DEVICE)
                outputs = model(data)
                loss = criterion(outputs, targets)

                running_test_loss += loss.item() * data.size(0)
                # argmax over the class dimension; the legacy `.data`
                # attribute is unnecessary inside torch.no_grad()
                predicted = outputs.argmax(dim=1)
                test_total += targets.size(0)
                test_correct += (predicted == targets).sum().item()

        epoch_test_loss = running_test_loss / len(test_loader.dataset)
        epoch_test_acc = 100 * test_correct / test_total
        test_loss_history.append(epoch_test_loss)
        test_acc_history.append(epoch_test_acc)

        epoch_time = time.time() - start_time

        print(f"Epoch [{epoch + 1}/{epochs}] | "
              f"Train Loss: {epoch_train_loss:.4f} | "
              f"Test Loss: {epoch_test_loss:.4f} | "
              f"Test Acc: {epoch_test_acc:.2f}% | "
              f"Time: {epoch_time:.1f}s")

        if epoch_test_acc > best_acc:
            best_acc = epoch_test_acc
            torch.save(model.state_dict(), "best_alexnet_cifar10.pth")
            print(f"Best model saved! Current Best Acc: {best_acc:.2f}%")

        model.train()  # back to training mode for the next epoch

    return train_loss_history, test_loss_history, test_acc_history


# ---------------------- 7. 训练过程可视化 ----------------------
def plot_training_history(train_loss, test_loss, test_acc):
    """Draw side-by-side loss and accuracy curves over training epochs."""
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: train/test loss
    loss_ax.plot(range(1, len(train_loss) + 1), train_loss, label='Train Loss', linewidth=2, color='blue')
    loss_ax.plot(range(1, len(test_loss) + 1), test_loss, label='Test Loss', linewidth=2, color='red')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('AlexNet Training & Test Loss Curve')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Right panel: test accuracy
    acc_ax.plot(range(1, len(test_acc) + 1), test_acc, label='Test Accuracy', linewidth=2, color='green')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.set_title('AlexNet Test Accuracy Curve')
    acc_ax.legend()
    acc_ax.grid(True, alpha=0.3)

    plt.suptitle('AlexNet Training History', fontsize=16)
    plt.tight_layout()
    plt.show()


# ---------------------- 8. 分类结果展示函数 ----------------------
def show_classification_results(model, test_loader, num_samples=16):
    """Show one test batch as a 4x4 grid with true vs. predicted labels.

    Titles are green for correct predictions, red for wrong ones.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    model.eval()
    fig, axes = plt.subplots(4, 4, figsize=(14, 12))
    panels = axes.flatten()

    with torch.no_grad():
        data, targets = next(iter(test_loader))
        data, targets = data.to(DEVICE), targets.to(DEVICE)
        _, predicted = torch.max(model(data), 1)

        for i in range(num_samples):
            # CHW tensor -> HWC array, then undo normalization for display
            picture = data[i].cpu().permute(1, 2, 0).numpy()
            picture = np.clip(picture * std + mean, 0, 1)

            actual = CLASSES[targets[i].item()]
            guess = CLASSES[predicted[i].item()]
            title_color = 'green' if actual == guess else 'red'

            panels[i].imshow(picture)
            panels[i].set_title(f"True: {actual}\nPred: {guess}", color=title_color)
            panels[i].axis('off')

    plt.suptitle('AlexNet Classification Results on Test Set', fontsize=16)
    plt.tight_layout()
    plt.show()


# ---------------------- 9. 单张图像预测函数 ----------------------
def predict_single_image(image_path, model):
    """Classify a single image file and display it with the predicted class.

    Prints a warning and skips prediction when the file does not exist.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    try:
        image = Image.open(image_path).convert("RGB")
        batch = preprocess(image).unsqueeze(0).to(DEVICE)  # add batch dimension

        model.eval()
        with torch.no_grad():
            _, pred = torch.max(model(batch), 1)

        plt.figure(figsize=(6, 4))
        plt.imshow(image)
        plt.title(f"Predicted Class: {CLASSES[pred.item()]}", fontsize=14)
        plt.axis('off')
        plt.show()
    except FileNotFoundError:
        print(f"Warning: File '{image_path}' not found. Skip single image prediction.")


# ---------------------- 10. 主函数 ----------------------
if __name__ == "__main__":
    # 1. Show dataset samples
    print("=== Showing CIFAR-10 Dataset Samples ===")
    show_dataset_samples(train_dataset)

    # 2. Build the model (frozen pretrained backbone + new classifier head)
    print(f"\n=== Building AlexNet Model (Training on {DEVICE}) ===")
    model = build_alexnet()
    print("Model structure shown above (feature layers frozen, only last FC layer trained)")

    # 3. Loss and optimizer (only the new head's parameters are optimized)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.classifier[6].parameters(), lr=LEARNING_RATE)

    # 4. Start training (CPU-optimized settings)
    print("\n=== Starting Training (CPU Optimized) ===")
    print(f"Batch Size: {BATCH_SIZE}, Epochs: {EPOCHS}, LR: {LEARNING_RATE}")
    print("CPU训练速度较慢,请耐心等待...")

    train_loss, test_loss, test_acc = train_model(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        epochs=EPOCHS
    )

    # 5. Plot the training history curves
    print("\n=== Plotting Training History ===")
    plot_training_history(train_loss, test_loss, test_acc)

    # 6. Show classification results
    print("\n=== Showing Classification Results ===")
    show_classification_results(model, test_loader)

    # 7. Single-image prediction (skipped with a warning if the file is missing)
    print("\n=== Predicting Single Image ===")
    predict_single_image("test_image.jpg", model)

    # 8. Final summary
    best_acc = max(test_acc)
    print("\n=== Experiment Summary ===")
    print(f"Training Device: {DEVICE}")
    print(f"Best Test Accuracy: {best_acc:.2f}%")
    print(f"Trained Epochs: {EPOCHS}")
    print(f"Best model saved as 'best_alexnet_cifar10.pth'")
    print("Experiment Completed Successfully!")

参考文献

  1. PyTorch 官方文档:https://pytorch.org/docs/stable/index.html
  2. CIFAR-10 数据集:https://www.cs.toronto.edu/~kriz/cifar.html
相关推荐
AKAMAI8 小时前
无服务器计算架构的优势
人工智能·云计算
阿星AI工作室8 小时前
gemini3手势互动圣诞树保姆级教程来了!附提示词
前端·人工智能
刘一说8 小时前
时空大数据与AI融合:重塑物理世界的智能中枢
大数据·人工智能·gis
月亮月亮要去太阳8 小时前
基于机器学习的糖尿病预测
人工智能·机器学习
Oflycomm8 小时前
LitePoint 2025:以 Wi-Fi 8 与光通信测试推动下一代无线创新
人工智能·wifi模块·wifi7模块
机器之心8 小时前
「豆包手机」为何能靠超级Agent火遍全网,我们听听AI学者们怎么说
人工智能·openai
monster000w8 小时前
大模型微调过程
人工智能·深度学习·算法·计算机视觉·信息与通信
机器之心8 小时前
一手实测 | 智谱AutoGLM重磅开源: AI手机的「安卓时刻」正式到来
人工智能·openai
算家计算8 小时前
解禁H200却留有后手!美国这波“卖芯片”,是让步还是埋坑?
人工智能·资讯
GIS数据转换器8 小时前
综合安防数智管理平台
大数据·网络·人工智能·安全·无人机