一、前言
手写数字识别可替代人工录入,提升效率、减少人为误差;广泛应用于政务归档、智能阅卷、金融票据、物流分拣及人工智能教学领域。
采用卷积神经网络 实现图像分类,核心逻辑模拟人类视觉识别过程:自动提取图像特征 + 数字分类。无需人工设计特征,通过网络自主学习手写数字的边缘、笔画、轮廓等视觉特征;利用最大池化压缩数据、保留核心特征,最终通过全连接层整合特征,完成 0-9 十个数字的精准分类,实现端到端的手写数字识别。
二、网络结构
程序构建了轻量化浅层 CNN ,专为 28×28 像素灰度手写数字设计,结构简单易理解模型分为 输入层 → 特征提取层 → 全连接层 → 输出层 四部分,通过两层卷积完成特征提取,两层全连接层协同实现特征融合与数字分类:
- 输入层:接收 28×28×1 单通道灰度图像;
- 特征提取层 :2 组卷积 + ReLU 激活 + 最大池化(3×3 卷积核,16/32 个特征图,提取简单 / 复杂特征;2×2 最大池化,降维降噪、提升训练效率)。第一次卷积 + 池化:
28×28×1 → 28×28×16 → 14×14×16,第二次卷积 + 池化:14×14×16 → 14×14×32 → 7×7×32 - 全连接层 :全连接层:
7×7×32 = 1568个特征 → 压缩为128→ 最终10分类 - 输出层:10 神经元全连接层 + Softmax 归一化 + 分类层,输出 0-9 的分类概率,判定最终数字。
三、程序核心流程
- 加载 MATLAB 自带 MNIST 数据集,按 7:3 随机划分训练集(学习特征)和测试集(验证效果);
- 数据预处理:统一图像尺寸、像素归一化,适配网络输入;
- 采用 Adam 优化器训练模型,设置 10 轮训练、128 批次大小,实时监控训练进度;
- 测试集预测并计算准确率,可视化混淆矩阵、随机样本识别结果;
- 支持单张图片实时预测,直观展示识别效果。
四、程序代码
%% 1. 初始化环境
clear; clc; close all; % 清空工作区、命令行、关闭窗口
rng(1); % 固定随机种子,保证训练结果可复现
%% 2. 加载MNIST手写数字数据集(MATLAB自带)
% 数据集路径:0-9十个文件夹,每个文件夹内存放对应数字的手写图片
datasetPath = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset');
% 创建图像数据存储(自动按文件夹分类标签)
imds = imageDatastore(datasetPath, ...
'IncludeSubfolders',true, ... % 包含子文件夹
'LabelSource','foldernames'); % 标签来自文件夹名称
%% 3. 划分训练集和测试集(70%训练,30%测试)
imdsTrain, imdsTest = splitEachLabel(imds,0.7,'randomize');
%% 4. 数据预处理(统一图像尺寸+归一化)
% MNIST原始图像为28x28灰度图,无需调整大小,直接适配输入层
inputSize = 28 28 1; % 输入:28x28 单通道灰度图
numClasses = 10; % 分类数:0-9 共10个数字
%% 5. 构建卷积神经网络(CNN)核心结构
layers = [
% 输入层
imageInputLayer(inputSize)
% 第一组卷积+激活+池化
convolution2dLayer(3,16,'Padding','same') % 3x3卷积核,16个特征图
reluLayer % ReLU激活函数
maxPooling2dLayer(2,'Stride',2) % 2x2最大池化,步长2
% 第二组卷积+激活+池化
convolution2dLayer(3,32,'Padding','same') % 3x3卷积核,32个特征图
reluLayer
maxPooling2dLayer(2,'Stride',2)
% 全连接层+分类输出
fullyConnectedLayer(128) % 全连接层,128个神经元
reluLayer
fullyConnectedLayer(numClasses) % 输出层:10个神经元(对应0-9)
softmaxLayer % Softmax归一化概率
classificationLayer % 分类层
];
%% 6. 设置训练参数
options = trainingOptions('adam', ... % 优化器:Adam
'MaxEpochs',10, ... % 最大训练轮数
'MiniBatchSize',128, ... % 批次大小
'ValidationData',imdsTest, ... % 验证集
'ValidationFrequency',30, ... % 验证频率
'Verbose',true, ... % 打印训练日志
'Plots','training-progress'); % 显示训练进度图
%% 7. 训练CNN模型
fprintf('开始训练卷积神经网络...\n');
net = trainNetwork(imdsTrain,layers,options);
%% 8. 模型测试:预测测试集并计算准确率
fprintf('模型训练完成,开始测试...\n');
predictedLabels = classify(net,imdsTest); % 对测试集预测
testLabels = imdsTest.Labels; % 测试集真实标签
% 计算识别准确率
accuracy = mean(predictedLabels == testLabels);
fprintf('测试集识别准确率:%.2f%%\n',accuracy*100);
%% 9. 结果可视化
% 9.1 绘制混淆矩阵(查看各类数字识别效果)
figure;
confusionchart(testLabels,predictedLabels);
title('手写数字识别混淆矩阵');
% 9.2 随机显示10张测试图+预测结果
figure;
idx = randperm(length(imdsTest.Files),10); % 随机选10张图
for i = 1:10
subplot(2,5,i);
I = readimage(imdsTest,idx(i));
imshow(I);
title('真实:' char(testLabels(idx(i))) ' 预测:' char(predictedLabels(idx(i))));
end
sgtitle('手写数字识别结果可视化');
%% 10. 单张图片预测示例
fprintf('\n单张图片预测演示:\n');
testImg = readimage(imdsTest,2001); % 读取第2001张测试图
figure,imshow(testImg),title('图像测试数据集中第2001张图像');
predNum = classify(net,testImg); % 预测
figure,imshow(testImg);
title('预测数字:' char(predNum));
五、主要运行结果
1.随机10张图像的预测结果

2.指定图像的预测结果(图像测试数据集中第2001张图像)


从程序运行结果可以看出,基于CNN的手写数字测试集识别准确率可达98以上,结构轻量高效、训练速度快,该方法无复杂结构,容易理解和上手,完整覆盖数据处理、模型构建、训练、测试、可视化全流程,是理解 CNN 原理与图像分类的简单案例。