LeNet-5 是一个经典的卷积神经网络(CNN)架构,最初由 Yann LeCun 提出,用于手写数字识别任务,特别是针对 MNIST 数据集。在 MATLAB 中实现 LeNet-5 并用于识别 MNIST 数据集,可以通过 MATLAB 的深度学习工具箱(Deep Learning Toolbox)来完成。
实现 LeNet-5 的步骤和代码:
1. 准备 MNIST 数据集
MATLAB 提供了内置的 MNIST 数据集加载功能,可以直接使用。
matlab
% 加载 MNIST 数据集
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', 'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[imdsTrain, imdsTest] = splitEachLabel(imds, 0.9, 'randomized');
2. 定义 LeNet-5 架构
LeNet-5 的架构包括多个卷积层、池化层和全连接层。以下是 MATLAB 中的实现:
matlab
layers = [
imageInputLayer([28 28 1], 'Name', 'input') % 输入层,图像大小为 28x28,灰度图像
convolution2dLayer(5, 6, 'Padding', 'same', 'Name', 'conv_1') % 第一个卷积层
batchNormalizationLayer('Name', 'BN_1') % 批量归一化层
reluLayer('Name', 'relu_1') % ReLU 激活层
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool_1') % 最大池化层
convolution2dLayer(5, 16, 'Padding', 'same', 'Name', 'conv_2') % 第二个卷积层
batchNormalizationLayer('Name', 'BN_2') % 批量归一化层
reluLayer('Name', 'relu_2') % ReLU 激活层
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool_2') % 最大池化层
fullyConnectedLayer(120, 'Name', 'fc_1') % 全连接层
reluLayer('Name', 'relu_fc_1') % ReLU 激活层
fullyConnectedLayer(84, 'Name', 'fc_2') % 全连接层
reluLayer('Name', 'relu_fc_2') % ReLU 激活层
fullyConnectedLayer(10, 'Name', 'fc_3') % 输出层,10 个类别
softmaxLayer('Name', 'softmax') % Softmax 层
classificationLayer('Name', 'output') % 分类输出层
];
3. 配置训练选项
使用 trainingOptions 函数配置训练参数,例如学习率、迭代次数等。
matlab
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.01, ...
'MaxEpochs',20, ...
'MiniBatchSize', 128, ...
'Shuffle','every-epoch', ...
'ValidationData',imdsTest, ...
'ValidationFrequency',30, ...
'Verbose',false, ...
'Plots','training-progress');
4. 训练模型
使用 trainNetwork 函数训练网络。
matlab
net = trainNetwork(imdsTrain, layers, options);
5. 测试模型
使用测试集评估模型性能。
matlab
YTest = classify(net, imdsTest);
accuracy = sum(YTest == imdsTest.Labels) / numel(imdsTest.Labels);
disp(['测试集准确率: ', num2str(accuracy * 100), '%']);
6. 可视化结果
可以随机选择一些测试图像,显示预测结果。
matlab
figure;
for i = 1:10
idx = randi(numel(imdsTest.Files));
img = readimage(imdsTest, idx);
label = classify(net, img);
subplot(2, 5, i);
imshow(img);
title(char(label));
end
参考代码 对LeNet-5的matlab实现,识别MINST手写数字集 www.youwenfan.com/contentcsp/59637.html
完整代码
将以上代码片段组合起来,即可完成 LeNet-5 的实现和训练。
matlab
% 加载数据集
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', 'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[imdsTrain, imdsTest] = splitEachLabel(imds, 0.9, 'randomized');
% 定义网络架构
layers = [
imageInputLayer([28 28 1], 'Name', 'input')
convolution2dLayer(5, 6, 'Padding', 'same', 'Name', 'conv_1')
batchNormalizationLayer('Name', 'BN_1')
reluLayer('Name', 'relu_1')
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool_1')
convolution2dLayer(5, 16, 'Padding', 'same', 'Name', 'conv_2')
batchNormalizationLayer('Name', 'BN_2')
reluLayer('Name', 'relu_2')
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool_2')
fullyConnectedLayer(120, 'Name', 'fc_1')
reluLayer('Name', 'relu_fc_1')
fullyConnectedLayer(84, 'Name', 'fc_2')
reluLayer('Name', 'relu_fc_2')
fullyConnectedLayer(10, 'Name', 'fc_3')
softmaxLayer('Name', 'softmax')
classificationLayer('Name', 'output')
];
% 配置训练选项
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.01, ...
'MaxEpochs',20, ...
'MiniBatchSize', 128, ...
'Shuffle','every-epoch', ...
'ValidationData',imdsTest, ...
'ValidationFrequency',30, ...
'Verbose',false, ...
'Plots','training-progress');
% 训练网络
net = trainNetwork(imdsTrain, layers, options);
% 测试网络
YTest = classify(net, imdsTest);
accuracy = sum(YTest == imdsTest.Labels) / numel(imdsTest.Labels);
disp(['测试集准确率: ', num2str(accuracy * 100), '%']);
% 可视化结果
figure;
for i = 1:10
idx = randi(numel(imdsTest.Files));
img = readimage(imdsTest, idx);
label = classify(net, img);
subplot(2, 5, i);
imshow(img);
title(char(label));
end