PyTorch中的LeNet-5入门

PyTorch中的LeNet-5入门

LeNet-5是一个经典的卷积神经网络(CNN)模型,由Yann LeCun等人在1998年提出。它在手写数字识别任务上取得了很好的性能,并被广泛应用于图像分类问题。本文将介绍如何使用PyTorch实现LeNet-5模型,并在MNIST手写数字数据集上进行训练和测试。

数据集介绍

MNIST是一个常用的手写数字识别数据集,包括60000个训练样本和10000个测试样本。每个样本是一个28x28的灰度图像,标签为0到9之间的数字。

网络结构

LeNet-5由7个层组成:两个卷积层、两个池化层和三个全连接层。具体结构如下:

markdown 复制代码
plaintextCopy code1. 卷积层:输入1通道,输出6通道,卷积核大小为5x5
2. 池化层:最大池化,池化窗口大小为2x2
3. 卷积层:输入6通道,输出16通道,卷积核大小为5x5
4. 池化层:最大池化,池化窗口大小为2x2
5. 全连接层:输入展平的向量,输出120维
6. 全连接层:输入120维,输出84维
7. 全连接层:输入84维,输出10维(输出类别的数量)

实现步骤

以下是实现LeNet-5模型的代码示例:

ini 复制代码
pythonCopy codeimport torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# 定义LeNet-5模型
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(-1, 16 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# 加载数据集并进行预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# 初始化模型和优化器
model = LeNet5()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
# 测试模型
model.eval()
total_correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total_correct += (predicted == labels).sum().item()
accuracy = total_correct / len(test_dataset)
print(f"Test Accuracy: {accuracy}")

总结

本文介绍了如何使用PyTorch实现LeNet-5模型,并在MNIST手写数字数据集上进行训练和测试。通过迭代训练和测试过程,我们可以获得模型在手写数字识别任务上的准确率。为了进一步提高模型性能,还可以尝试调整超参数和网络结构。希望本文能帮助初学者理解LeNet-5模型的基本原理和实现方式。

实际应用场景 - 图像分类

LeNet-5在图像分类中有广泛的应用,比如人脸识别、物体检测、手势识别等。下面以人脸识别为例,给出示例代码。

ini 复制代码
pythonCopy codeimport torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义LeNet-5模型
class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)
    
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(-1, 16 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# 数据预处理及加载
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.ImageFolder('train', transform=transform)
test_dataset = datasets.ImageFolder('test', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 初始化模型和优化器
model = LeNet5(num_classes=len(train_dataset.classes))
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
# 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def train(model, device, train_loader, optimizer, criterion):
    model.train()
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            test_loss += criterion(outputs, labels).item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = correct / len(test_loader.dataset)
    print('Test Loss: {:.4f}, Accuracy: {:.2f}%'.format(test_loss, 100. * accuracy))
# 模型训练和测试
epochs = 10
for epoch in range(epochs):
    train(model, device, train_loader, optimizer, criterion)
    test(model, device, test_loader)

上述示例代码可用于人脸识别任务,其中的​​train​​文件夹和​​test​​文件夹分别存放了训练集和测试集的图像数据。模型训练过程中的准确率和测试损失会输出到控制台。可以根据实际情况调整超参数、网络结构和训练集。

LeNet-5是深度学习领域中的经典模型,它在图像分类领域有不错的表现。然而,LeNet-5也存在一些缺点:

  1. 处理复杂图像的能力有限:由于LeNet-5的网络结构相对简单,模型的复杂度较低,因此它对于处理复杂图像的能力相对有限。在面对较为复杂的图像分类任务时,LeNet-5可能无法达到较高的准确性。
  2. 仅适用于小尺寸图像:LeNet-5最初设计用于处理手写数字识别任务,因此网络结构较小,主要处理28x28像素的小尺寸图像。对于更高分辨率、更大尺寸的图像,LeNet-5的网络结构可能过于简单,难以提取有效的特征,导致性能下降。
  3. 缺乏激活函数的变化:LeNet-5中使用的激活函数主要为sigmoid函数,在深度学习发展的后期被证明不是最优选择。相对而言,更现代的激活函数(如ReLU)在深度神经网络中具有更好的性能。 类似的模型通常是基于LeNet-5进行改进和扩展的,其中最著名的是AlexNet、VGGNet、GoogLeNet和ResNet等。这些模型在LeNet-5的基础上增加了网络的深度和复杂度,引入了更多的卷积层和全连接层,并使用更先进的激活函数和正则化方法。这些模型在大规模图像分类任务上表现出色,具有更高的准确性和更强的泛化能力。然而,这些模型的复杂度也带来了更大的计算和存储需求,增加了训练时间和模型容量。因此,在实际应用中需要根据具体任务的需求和资源预算进行选择。
相关推荐
万亿少女的梦1685 分钟前
基于Spring Boot的网络购物商城的设计与实现
java·spring boot·后端
开心工作室_kaic2 小时前
springboot485基于springboot的宠物健康顾问系统(论文+源码)_kaic
spring boot·后端·宠物
0zxm2 小时前
08 Django - Django媒体文件&静态文件&文件上传
数据库·后端·python·django·sqlite
刘大辉在路上9 小时前
突发!!!GitLab停止为中国大陆、港澳地区提供服务,60天内需迁移账号否则将被删除
git·后端·gitlab·版本管理·源代码管理
追逐时光者11 小时前
免费、简单、直观的数据库设计工具和 SQL 生成器
后端·mysql
初晴~11 小时前
【Redis分布式锁】高并发场景下秒杀业务的实现思路(集群模式)
java·数据库·redis·分布式·后端·spring·
盖世英雄酱5813611 小时前
InnoDB 的页分裂和页合并
数据库·后端
小_太_阳12 小时前
Scala_【2】变量和数据类型
开发语言·后端·scala·intellij-idea
直裾12 小时前
scala借阅图书保存记录(三)
开发语言·后端·scala