使用 PyTorch 实现并训练 VGGNet 用于 MNIST 分类

本文将展示如何使用 PyTorch 实现一个经典的 VGGNet 网络,并在 MNIST 数据集上进行训练和测试。我们将从模型构建开始,涵盖数据预处理、模型训练、评估、保存与加载模型,以及可视化预测结果等全过程。


1. VGGNet 模型的实现

首先,我们实现一个标准的 VGGNet 网络。VGGNet 是一个深度卷积神经网络,它由多个卷积层和全连接层组成,广泛应用于图像分类任务。

VGGNet 模型结构:
  • 卷积层 :VGGNet 采用了简单的结构,使用多个卷积层,每层卷积后跟一个 ReLU 激活函数和一个 最大池化 层。
  • 全连接层:经过卷积层提取特征后,VGGNet 会将特征图展平,并通过全连接层进行分类。
python 复制代码
import torch.nn as nn

class VGG(nn.Module):
    def __init__(self, num_classes=10, input_channels=1):
        """
        VGG 网络的初始化方法,包含卷积层和全连接层。

        参数:
        - num_classes (int): 分类的类别数量,默认 10 (适用于 MNIST)
        - input_channels (int): 输入图片的通道数,默认 1 (适用于灰度图像)
        """
        super(VGG, self).__init__()

        # 构建卷积层部分
        self.features = self._make_layers(input_channels)

        # 构建分类器部分
        self.classifier = self._make_classifier(num_classes)

    def _make_layers(self, input_channels):
        """
        构建卷积层部分,通过堆叠卷积层、ReLU 激活和池化层来构建特征提取部分

        参数:
        - input_channels (int): 输入图像的通道数,默认为 1(灰度图)

        返回:
        - features (nn.Sequential): 包含卷积层和池化层的神经网络模块
        """
        layers = []
        # 卷积块 1
        layers += self._conv_block(input_channels, 64)
        # 卷积块 2
        layers += self._conv_block(64, 128)
        # 卷积块 3
        layers += self._conv_block(128, 256)
        # 卷积块 4
        layers += self._conv_block(256, 512)

        # 将所有卷积块和池化层堆叠在一起
        return nn.Sequential(*layers)

    def _conv_block(self, in_channels, out_channels):
        """
        创建一个卷积块,包含两个卷积层和一个最大池化层

        参数:
        - in_channels (int): 输入通道数
        - out_channels (int): 输出通道数

        返回:
        - block (list): 卷积块 [卷积层 + ReLU + 卷积层 + ReLU + 最大池化层]
        """
        block = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ]
        return block

    def _make_classifier(self, num_classes):
        """
        构建全连接层部分,最后的输出层为分类层。

        参数:
        - num_classes (int): 分类类别数

        返回:
        - classifier (nn.Sequential): 包含全连接层和 Dropout 层的网络模块
        """
        return nn.Sequential(
            nn.Linear(512 * 1 * 1, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        """
        前向传播方法,输入图像通过卷积层提取特征后再通过全连接层进行分类。

        参数:
        - x (Tensor): 输入的图像数据

        返回:
        - x (Tensor): 分类结果
        """
        # 通过卷积层提取特征
        x = self.features(x)

        # 将特征图展平为一维向量
        x = x.view(x.size(0), -1)  # 这里将 4D 张量转换为 2D,保留 batch_size

        # 通过分类器进行最终分类
        x = self.classifier(x)

        return x

2. 训练模型

使用 PyTorch 实现的 VGGNet 网络后,我们需要对模型进行训练。在这个过程中,我们会使用 AdamW 优化器、交叉熵损失 以及 混合精度训练 来提升训练效率。

python 复制代码
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast


def get_data_loader(batch_size=64, num_workers=2):
    """ 获取 MNIST 数据加载器 """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_dataset = datasets.MNIST(root='D:/workspace/data', train=True, download=True, transform=transform)
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)


def initialize_model(device, num_classes=10):
    """ 初始化模型、优化器和损失函数 """
    model = VGG(num_classes=num_classes).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
    criterion = torch.nn.CrossEntropyLoss()
    return model, optimizer, criterion


def train_epoch(model, train_loader, device, criterion, optimizer, scaler):
    """ 训练一个 epoch,并返回该 epoch 的平均损失和准确率 """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    with tqdm(train_loader, desc="Training", unit="batch", ncols=100) as pbar:
        for data, target in pbar:
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)

            optimizer.zero_grad()

            # 混合精度训练
            with autocast():
                output = model(data)
                loss = criterion(output, target)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            # 更新进度条
            pbar.set_postfix(loss=running_loss / (total // len(data)), accuracy=100 * correct / total)

    return running_loss / len(train_loader), 100 * correct / total

3. 保存与加载模型

在训练完成后,我们将保存模型,并在后续的测试过程中加载模型以进行评估。

python 复制代码
def save_model(model, filepath='vggnet_mnist.pth'):
    """ 保存训练的模型到指定文件(覆盖之前的文件) """
    torch.save(model.state_dict(), filepath)
    print(f"Model saved to {filepath}")


def load_model(model_path='vggnet_mnist.pth', num_classes=10):
    """ 加载预训练模型 """
    model = VGG(num_classes=num_classes)
    model.load_state_dict(torch.load(model_path))
    return model

4. 评估模型与可视化结果

我们可以加载训练好的模型并对其在测试集上的表现进行评估。我们还可以通过 matplotlib 可视化前六张测试图像的预测结果。

python 复制代码
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def get_test_loader(batch_size=64, data_dir='D:/workspace/data'):
    """ 获取 MNIST 测试数据加载器 """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    test_dataset = datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


def evaluate_model(model, test_loader, device):
    """ 评估模型并返回准确率和前六张图片的预测与标签 """
    model.eval()
    correct = 0
    total = 0
    images, labels, preds = [], [], []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            # 记录前六张图片及其标签和预测
            if len(images) < 6:
                batch_size = data.size(0)
                for i in range(min(6 - len(images), batch_size)):
                    images.append(data[i].cpu())
                    labels.append(target[i].cpu())
                    preds.append(predicted[i].cpu())

    accuracy = 100 * correct / total
    return accuracy, images, labels, preds


def display_images(images, labels, preds):
    """ 可视化前六张图片及其真实标签和预测标签 """
    fig, axes = plt.subplots(2, 3, figsize=(10, 6))
    axes = axes.ravel()

    for i in range(6):
        axes[i].imshow(images[i][0].squeeze(), cmap='gray')  # MNIST 是单通道灰度图像
        axes[i].set_title(f"True: {labels[i].item()}, Pred: {preds[i].item()}")
        axes[i].axis('off')  # 不显示坐标轴

    plt.show()

5. 总结

通过以上步骤,我们成功实现并训练了一个 VGGNet 网络,并在 MNIST 数据集上进行了测试与评估。我们使用了混合精度训练来加速训练过程,并通过可视化展示了模型的预测效果。

这种方法可以推广到其他数据集和任务中,例如 CIFAR-10、CIFAR-100 或其他图像分类问题。

完整项目:

qxd-ljy/VGGNet-PyTorch: 使用PyTorch实现VGGNet进行MINST图像分类https://github.com/qxd-ljy/VGGNet-PyTorchVGGNet-PyTorch: 使用PyTorch实现VGGNet进行MINST图像分类https://gitee.com/qxdlll/vggnet-py-torch

相关推荐
AIGC大时代1 小时前
方法建议ChatGPT提示词分享
人工智能·深度学习·chatgpt·aigc·ai写作
数据馅2 小时前
window系统annaconda中同时安装paddle和pytorch环境
人工智能·pytorch·paddle
程序员一诺3 小时前
【深度学习】嘿马深度学习笔记第11篇:卷积神经网络,学习目标【附代码文档】
人工智能·python·深度学习·算法
姓学名生3 小时前
李沐vscode配置+github管理+FFmpeg视频搬运+百度API添加翻译字幕
vscode·python·深度学习·ffmpeg·github·视频
AI科技大本营3 小时前
Anthropic四大专家“会诊”:实现深度思考不一定需要多智能体,AI完美对齐比失控更可怕!...
人工智能·深度学习
Damon小智3 小时前
合合信息DocFlow产品解析与体验:人人可搭建的AI自动化单据处理工作流
图像处理·人工智能·深度学习·机器学习·ai·自动化·docflow
孤独且没人爱的纸鹤3 小时前
【机器学习】深入无监督学习分裂型层次聚类的原理、算法结构与数学基础全方位解读,深度揭示其如何在数据空间中构建层次化聚类结构
人工智能·python·深度学习·机器学习·支持向量机·ai·聚类
Galerkin码农选手3 小时前
寒武纪使用cnnl库函数实现卷积算子
pytorch
小深ai硬件分享5 小时前
Keras、TensorFlow、PyTorch框架对比及服务器配置揭秘
服务器·人工智能·深度学习
盼小辉丶12 小时前
TensorFlow深度学习实战——情感分析模型
深度学习·神经网络·tensorflow