一、前言
手写数字识别可替代人工录入,提升效率、减少人为误差;广泛应用于政务归档、智能阅卷、金融票据、物流分拣及人工智能教学领域。
采用卷积神经网络 实现图像分类,核心逻辑模拟人类视觉识别过程:自动提取图像特征 + 数字分类。无需人工设计特征,通过网络自主学习手写数字的边缘、笔画、轮廓等视觉特征;利用最大池化压缩数据、保留核心特征,最终通过全连接层整合特征,完成 0-9 十个数字的精准分类,实现端到端的手写数字识别。
二、网络结构
程序构建了轻量化浅层 CNN ,专为 28×28 像素灰度手写数字设计,结构简单易理解模型分为 输入层 → 特征提取层 → 全连接层 → 输出层 四部分,通过两层卷积完成特征提取,两层全连接层协同实现特征融合与数字分类:
- 输入层:接收 28×28×1 单通道灰度图像;
- 特征提取层 :2 组卷积 + ReLU 激活 + 最大池化(3×3 卷积核,16/32 个特征图,提取简单 / 复杂特征;2×2 最大池化,降维降噪、提升训练效率)。第一次卷积 + 池化:
28×28×1 → 28×28×16 → 14×14×16,第二次卷积 + 池化:14×14×16 → 14×14×32 → 7×7×32 - 全连接层 :全连接层:
7×7×32 = 1568个特征 → 压缩为128→ 最终10分类 - 输出层:10 神经元全连接层 + Softmax 归一化 + 分类层,输出 0-9 的分类概率,判定最终数字。
三、程序核心流程
- 加载 MATLAB 自带 MNIST 数据集,按 7:3 随机划分训练集(学习特征)和测试集(验证效果);
- 数据预处理:统一图像尺寸、像素归一化,适配网络输入;
- 采用 Adam 优化器训练模型,设置 10 轮训练、128 批次大小,实时监控训练进度;
- 测试集预测并计算准确率,可视化混淆矩阵、随机样本识别结果;
- 支持单张图片实时预测,直观展示识别效果。
四、程序代码
%% 1. 初始化环境
clear; clc; close all; % 清空工作区、命令行、关闭窗口
rng(1); % 固定随机种子,保证训练结果可复现
%% 2. 加载MNIST手写数字数据集(MATLAB自带)
% 数据集路径:0-9十个文件夹,每个文件夹内存放对应数字的手写图片
datasetPath = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset');
% 创建图像数据存储(自动按文件夹分类标签)
imds = imageDatastore(datasetPath, ...
'IncludeSubfolders',true, ... % 包含子文件夹
'LabelSource','foldernames'); % 标签来自文件夹名称
%% 3. 划分训练集和测试集(70%训练,30%测试)
imdsTrain, imdsTest\] = splitEachLabel(imds,0.7,'randomize'); %% 4. 数据预处理(统一图像尺寸+归一化) % MNIST原始图像为28x28灰度图,无需调整大小,直接适配输入层 inputSize = \[28 28 1\]; % 输入:28x28 单通道灰度图 numClasses = 10; % 分类数:0-9 共10个数字 %% 5. 构建卷积神经网络(CNN)核心结构 layers = \[ % 输入层 imageInputLayer(inputSize) % 第一组卷积+激活+池化 convolution2dLayer(3,16,'Padding','same') % 3x3卷积核,16个特征图 reluLayer % ReLU激活函数 maxPooling2dLayer(2,'Stride',2) % 2x2最大池化,步长2 % 第二组卷积+激活+池化 convolution2dLayer(3,32,'Padding','same') % 3x3卷积核,32个特征图 reluLayer maxPooling2dLayer(2,'Stride',2) % 全连接层+分类输出 fullyConnectedLayer(128) % 全连接层,128个神经元 reluLayer fullyConnectedLayer(numClasses) % 输出层:10个神经元(对应0-9) softmaxLayer % Softmax归一化概率 classificationLayer % 分类层 \]; %% 6. 设置训练参数 options = trainingOptions('adam', ... % 优化器:Adam 'MaxEpochs',10, ... % 最大训练轮数 'MiniBatchSize',128, ... % 批次大小 'ValidationData',imdsTest, ... % 验证集 'ValidationFrequency',30, ... % 验证频率 'Verbose',true, ... % 打印训练日志 'Plots','training-progress'); % 显示训练进度图 %% 7. 训练CNN模型 fprintf('开始训练卷积神经网络...\\n'); net = trainNetwork(imdsTrain,layers,options); %% 8. 模型测试:预测测试集并计算准确率 fprintf('模型训练完成,开始测试...\\n'); predictedLabels = classify(net,imdsTest); % 对测试集预测 testLabels = imdsTest.Labels; % 测试集真实标签 % 计算识别准确率 accuracy = mean(predictedLabels == testLabels); fprintf('测试集识别准确率:%.2f%%\\n',accuracy\*100); %% 9. 结果可视化 % 9.1 绘制混淆矩阵(查看各类数字识别效果) figure; confusionchart(testLabels,predictedLabels); title('手写数字识别混淆矩阵'); % 9.2 随机显示10张测试图+预测结果 figure; idx = randperm(length(imdsTest.Files),10); % 随机选10张图 for i = 1:10 subplot(2,5,i); I = readimage(imdsTest,idx(i)); imshow(I); title(\['真实:' char(testLabels(idx(i))) ' 预测:' char(predictedLabels(idx(i)))\]); end sgtitle('手写数字识别结果可视化'); %% 10. 单张图片预测示例 fprintf('\\n单张图片预测演示:\\n'); testImg = readimage(imdsTest,2001); % 读取第2001张测试图 figure,imshow(testImg),title('图像测试数据集中第2001张图像'); predNum = classify(net,testImg); % 预测 figure,imshow(testImg); title(\['预测数字:' char(predNum)\]); ### 五、主要运行结果 1.随机10张图像的预测结果  2.指定图像的预测结果(图像测试数据集中第2001张图像)   从程序运行结果可以看出,基于CNN的手写数字测试集识别准确率可达98以上,结构轻量高效、训练速度快,该方法无复杂结构,容易理解和上手,完整覆盖数据处理、模型构建、训练、测试、可视化全流程,是理解 CNN 原理与图像分类的简单案例。 ### **撰写博客不易,如果你觉得本文对你的研究和学习有过帮助,请点赞、关注,欢迎转发!谢谢大家!**