CNN手写数字识别1——模型搭建与数据准备

模型搭建

我们这次使用LeNet模型,LeNet是一个经典的卷积神经网络(Convolutional Neural Network, CNN)架构,最初由Yann LeCun等人在1998年提出,用于手写数字识别任务

创建一个文件model.py。实现以下代码。

源码

python 复制代码
# 导入PyTorch库
import torch
# 从PyTorch库中导入神经网络模块
from torch import nn
# 从torchsummary库中导入summary函数,用于打印模型的结构和参数数量
from torchsummary import summary

# 定义LeNet类,它继承自nn.Module,是一个神经网络模型
class LeNet(nn.Module):
    # 初始化函数,定义模型的层次结构
    def __init__(self):
        # 调用父类的初始化函数
        super().__init__()
        # 第一个卷积层,输入通道为1,输出通道为6,卷积核大小为5x5,padding为2
        self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
        # Sigmoid激活函数
        self.sig = nn.Sigmoid()
        # 第一个平均池化层,池化窗口为2x2,步长为2
        self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)
        # 第二个卷积层,输入通道为6,输出通道为16,卷积核大小为5x5
        self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # 第二个平均池化层,池化窗口为2x2,步长为2
        self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)
        # Flatten层,用于将多维的输入一维化,以便输入到全连接层
        self.flatten = nn.Flatten()
        # 第一个全连接层,输入特征数为400,输出特征数为120
        self.f5 = nn.Linear(400, 120)
        # 第二个全连接层,输入特征数为120,输出特征数为84
        self.f6 = nn.Linear(120, 84)
        # 第三个全连接层,输入特征数为84,输出特征数为10(通常对应分类任务中的类别数)
        self.f7 = nn.Linear(84, 10)

    # 前向传播函数,定义数据通过网络的方式
    def forward(self, x):
        x = self.sig(self.c1(x))  # 通过第一个卷积层和Sigmoid激活函数
        x = self.s2(x)            # 通过第一个平均池化层
        x = self.sig(self.c3(x))  # 通过第二个卷积层和Sigmoid激活函数
        x = self.s4(x)            # 通过第二个平均池化层
        x = self.flatten(x)       # 通过Flatten层
        x = self.sig(self.f5(x))  # 通过第一个全连接层和Sigmoid激活函数
        x = self.sig(self.f6(x))  # 通过第二个全连接层和Sigmoid激活函数
        x = self.sig(self.f7(x))  # 通过第三个全连接层和Sigmoid激活函数
        return x

# 主函数
if __name__ == "__main__":
    # 自动检测是否有可用的GPU,如果有则使用GPU,否则使用CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 实例化LeNet模型,并将其移动到指定的设备上(GPU或CPU)
    model = LeNet().to(device)
    # 使用torchsummary的summary函数打印模型的结构和参数数量,输入形状为(1, 28, 28)
    print(summary(model, (1, 28, 28)))

源码解析

神经网络构建

LeNet网络主要由卷积层、池化层、激活函数和全连接层组成。有2个卷积层,2个池化层,3个全连接层。

  • 卷积层 ‌(nn.Conv2d):用于提取图像中的特征。这里有两个卷积层,第一个卷积层有6个输出通道,第二个卷积层有16个输出通道,卷积核大小都是5x5。
  • 激活函数 ‌(nn.Sigmoid):用于引入非线性,使得网络能够学习更复杂的模式。这里使用了Sigmoid激活函数。
  • 池化层 ‌(nn.AvgPool2d):用于降低特征图的尺寸,减少计算量,同时保留重要特征。这里使用了平均池化层,池化窗口大小为2*2,步长为2。
  • Flatten层 ‌(nn.Flatten):用于将多维的特征图展平成一维向量,以便输入到全连接层。
  • 全连接层 ‌(nn.Linear):用于分类任务,将特征向量映射到类别空间。这里有三个全连接层,分别将特征维度从400降到120,再从120降到84,最后从84降到10(对应10个类别)。

参数计算

从代码中可以看到每一层神经网络都有自己的参数,这里面通道数,卷积核大小,步长和感受野,一定程度上可以当做超参数人为自由设定,其他的参数都需要事先根据输入数据进行计算。

首先,假设输入的图像数据大小都是28*28*1,即宽28个像素,高28个像素,由于是灰度图所以色彩通道只有1。

在第一个卷积层c1,卷积核大小是5*5,卷积核个数是6个,步长默认是1,填充是2,这些是我们人为设定的。可可以得出输出通道数=卷积核个数=6,经过这一层的输出数据通道数为6,尺寸通过公式计算:

公式里面O是输出的宽/高,IN是输入的宽/高,P是填充,F是卷积核的宽/高或者感受野的宽/高,S是步长。

即输出图像的宽高是(28+2*2-5)/1+1=28。输出数据量是28*28*6。可以看到卷积层可以有效提升数据的通道数。

在第一个池化层s2,感受野是2*2,步长是2,填充默认是0,可以计算出输出图像的宽高是(28+2*0-2)/2+1=14。可以看到经过池化层之后数据的特征量明显减少,此时输出的数据量是14*14*6(池化不会改变通道数)。

在第二个卷积层c3,卷积核大小是5*5,步长默认是1,填充默认是0,有16个卷积核,也就是通道数增加到了16。根据公式可以计算出输出图像的宽高是(14+2*0-5)/1+1=10。输出数据量是10*10*16。

在第二个池化层s4,感受野是2*2,步长为2,那么输出数据宽高就是(10+2*0-2)/2+1=5。输出数据量是5*5*16。

到全连接层的第一层就比较关键了,因为这里的参数有一个"输入特征数",也就是刚才算的5*5*16=400。如果前面的计算不对,到了这一步模型是会报错的,因此每一层的输出特征量都需要计算出来。

后面的全连接层就比较好写了,输入的特征量是上一层全连接层的神经元个数。

前向传播

前向传播定义了数据通过网络的方式。对于输入x,它首先通过第一个卷积层和Sigmoid激活函数,然后通过第一个平均池化层;接着通过第二个卷积层和Sigmoid激活函数,再通过第二个平均池化层;最后通过Flatten层将多维特征展平成一维向量,并依次通过三个全连接层和Sigmoid激活函数得到最终输出。

验证

在主函数中,我们首先检测是否有可用的GPU,并将模型移动到合适的计算设备上(GPU或CPU)。然后,我们使用torchsummary的summary函数打印模型的结构和参数数量,以便了解模型的复杂度和计算需求。

从图中可以看出,池化层是不包含参数的,整个模型的大部分参数都在全连接层(48120+10164+850 = 59134,将近6万个参数在伺候全连接层)。

数据准备

FashionMNIST是一个流行的数据集,包含了10种类别的70,000个灰度图像,通常用于计算机视觉和机器学习的教学与研究。我们这次通过远程下载的方式来获取数据。

另外创建一个plot.python。用来下载数据和预览数据。这部分代码可写可不写,模型训练的时候还会重新加载数据。

源码

python 复制代码
# 导入必要的库和模块
from torchvision import transforms  # 用于图像预处理的变换
from torchvision.datasets import FashionMNIST  # 导入FashionMNIST数据集
import torch.utils.data as Data  # 用于数据加载的实用工具
import numpy as np  # 导入NumPy库,用于数值计算
import matplotlib.pyplot as plt

# 准备训练数据
train_data = FashionMNIST(
    root='./data',  # 数据集存储的根目录
    train=True,  # 指定为训练数据集
    transform=transforms.Compose([  # 图像预处理步骤
        transforms.Resize(size=224),  # 将图像大小调整为224x224
        transforms.ToTensor()  # 将图像转换为PyTorch张量
    ]),
    download=True  # 如果数据集不存在,则下载
)

# 创建数据加载器
train_loader = Data.DataLoader(
    dataset=train_data,  # 指定数据集
    batch_size=64,  # 每个批次的大小
    shuffle=True,  # 在每个epoch开始时打乱数据
    num_workers=0  # 使用0个工作线程(对于Windows系统,有时需要设置为0以避免多进程问题)
)


# 遍历数据加载器
for step, (b_x, b_y) in enumerate(train_loader):
    if step > 0:  # 只处理第一个批次的数据
        break
# 将PyTorch张量转换为NumPy数组
batch_x = b_x.squeeze().numpy()  # 移除批次维度(如果可能),并转换为NumPy数组
batch_y = b_y.numpy()  # 将标签转换为NumPy数组

# 获取数据集中的类别标签
class_label = train_data.classes  # 这是一个包含所有类别名称的列表
# 打印类别标签
print(class_label)  # 输出类别标签列表

# 设置图形的大小
plt.figure(figsize=(12, 5))


# 遍历batch_y中的每一个元素,即每一个样本的标签
for ii in np.arange(len(batch_y)):
    # 创建子图,4行16列,第ii+1个子图
    # 这里假设一个批次有64个样本,因此用4x16的布局来显示它们
    plt.subplot(4, 16, ii + 1)

    # 显示图像
    # batch_x[ii, :, :]表示第ii个样本的图像数据
    # cmap=plt.cm.gray指定使用灰度色彩映射
    plt.imshow(batch_x[ii, :, :], cmap=plt.cm.gray)

    # 设置标题为对应的类别标签
    # class_label[batch_y[ii]]根据标签索引获取类别名称
    # size=10设置标题字体大小
    plt.title(class_label[batch_y[ii]], size=10)

    # 关闭坐标轴显示
    plt.axis("off")

# 调整子图之间的间距
# wspace=0.05设置子图之间的宽度间距
plt.subplots_adjust(wspace=0.05)

# 显示图形
plt.show()

源码解析

下载和加载数据

首先准备训练数据。FashionMNIST数据集将被下载到指定的根目录,并进行图像预处理

为了高效地加载数据,我们使用PyTorch的DataLoader来创建数据加载器。

dataloader的参数解释如下:

  • dataset:指定要加载的数据集。
  • batch_size:每个批次加载的样本数。
  • shuffle:是否在每个epoch开始时打乱数据。
  • num_workers:加载数据时使用的工作线程数。在Windows系统上,有时需要设置为0以避免多进程问题。

展示数据

我们遍历数据加载器,但只处理第一个批次的数据(为了简化示例)。使用squeeze()方法移除批次维度(如果可能),并将PyTorch张量转换为NumPy数组,以便使用matplotlib进行可视化。随后使用matplotlib的subplot()方法创建子图,并在每个子图中显示一个图像样本。我们使用灰度色彩映射(cmap=plt.cm.gray)来显示图像。最后,使用plt.show()方法显示图形。

相关推荐
带娃的IT创业者15 分钟前
机器学习实战(8):降维技术——主成分分析(PCA)
人工智能·机器学习·分类·聚类
调皮的芋头39 分钟前
iOS各个证书生成细节
人工智能·ios·app·aigc
flying robot3 小时前
人工智能基础之数学基础:01高等数学基础
人工智能·机器学习
Moutai码农3 小时前
机器学习-生命周期
人工智能·python·机器学习·数据挖掘
188_djh3 小时前
# 10分钟了解DeepSeek,保姆级部署DeepSeek到WPS,实现AI赋能
人工智能·大语言模型·wps·ai技术·ai应用·deepseek·ai知识
Jackilina_Stone3 小时前
【DL】浅谈深度学习中的知识蒸馏 | 输出层知识蒸馏
人工智能·深度学习·机器学习·蒸馏
bug404_4 小时前
分布式大语言模型服务引擎vLLM论文解读
人工智能·分布式·语言模型
Logout:4 小时前
[AI]docker封装包含cuda cudnn的paddlepaddle PaddleOCR
人工智能·docker·paddlepaddle
OJAC近屿智能4 小时前
苹果新品今日发布,AI手机市场竞争加剧,近屿智能专注AI人才培养
大数据·人工智能·ai·智能手机·aigc·近屿智能
代码猪猪傻瓜coding5 小时前
关于 形状信息提取的说明
人工智能·python·深度学习