基于MATLAB的CNN图像分类算法实现

基于MATLAB深度学习工具箱(Deep Learning Toolbox)的卷积神经网络(CNN)图像分类完整实现流程


一、数据准备与预处理
1. 数据集加载

使用MATLAB内置的imageDatastore函数加载图像数据集,支持自定义路径或直接调用公开数据集(如CIFAR-10、MNIST):

matlab 复制代码
% 加载CIFAR-10数据集(示例)
[XTrain, YTrain] = digitTrain4DArrayData; % 内置手写数字数据集
[XTest, YTest] = digitTest4DArrayData;

% 或自定义数据集(需按文件夹分类)
dataFolder = 'path/to/custom_dataset';
imds = imageDatastore(dataFolder, ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

关键点

  • 图像需按类别存放在子文件夹中,文件夹名即为标签。
  • 对于灰度图像,需转换为RGB格式(如cat(3, X, X, X))。
2. 数据预处理
  • 归一化:将像素值缩放到[0,1]或[-1,1]范围。
  • 数据增强:通过旋转、平移等操作增加数据多样性,防止过拟合:
matlab 复制代码
augmenter = imageDataAugmenter(...
    'RandRotation', [-20, 20], ...  % 随机旋转±20度
    'RandXReflection', true, ...    % 随机水平翻转
    'RandYReflection', true);       % 随机垂直翻转

augimdsTrain = augmentedImageDatastore([32, 32], imdsTrain, 'DataAugmentation', augmenter);

关键点:输入尺寸需与网络输入层匹配(如AlexNet需227×227)。


二、CNN模型构建
1. 自定义网络架构

通过layerGraph定义网络结构,典型CNN包含卷积层、池化层、全连接层:

matlab 复制代码
layers = [
    imageInputLayer([32 32 3]) % 输入层(高度×宽度×通道)
    
    % 卷积块1
    convolution2dLayer(3, 16, 'Padding', 'same') % 3×3卷积核,输出16通道
    batchNormalizationLayer % 批归一化
    reluLayer % ReLU激活
    maxPooling2dLayer(2, 'Stride', 2) % 2×2最大池化
    
    % 卷积块2
    convolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    
    % 全连接层
    fullyConnectedLayer(10) % 输出层(类别数)
    softmaxLayer % Softmax概率转换
    classificationLayer]; % 分类输出

优化建议

  • 添加Dropout层(如dropoutLayer(0.5))进一步防止过拟合。
  • 使用残差连接(ResNet结构)提升深层网络性能。
2. 迁移学习(推荐)

对于小数据集,直接使用预训练模型(如ResNet-50)进行微调:

matlab 复制代码
net = resnet50; % 加载预训练模型
lgraph = layerGraph(net);

% 替换最后的全连接层和输出层
newFCLayer = fullyConnectedLayer(numClasses, 'WeightLearnRateFactor', 10);
newOutputLayer = classificationLayer('Name', 'new_classoutput');
lgraph = replaceLayer(lgraph, 'fc1000', newFCLayer);
lgraph = replaceLayer(lgraph, 'ClassificationLayer_fc1000', newOutputLayer);

优势:迁移学习可显著减少训练时间并提升准确率。


三、模型训练与优化
1. 训练选项配置

通过trainingOptions设置优化器、学习率、早停等参数:

matlab 复制代码
options = trainingOptions('adam', ...  % 优化器(sgdm, rmsprop)
    'MaxEpochs', 20, ...
    'MiniBatchSize', 64, ...
    'InitialLearnRate', 0.001, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', imdsValidation, ...  % 验证集
    'ValidationFrequency', 30, ...
    'Verbose', false, ...
    'Plots', 'training-progress', ...  % 实时显示训练曲线
    'ExecutionEnvironment', 'auto');  % 自动选择CPU/GPU加速

关键参数

  • LearnRateSchedule: 设置学习率衰减策略(如'piecewise')。
  • L2Regularization: 添加L2正则化(默认0.0005)。
2. 模型训练

调用trainNetwork启动训练:

matlab 复制代码
net = trainNetwork(augimdsTrain, lgraph, options);

训练监控

  • 实时显示损失(Loss)和准确率(Accuracy)曲线。
  • 若启用验证集,可观察验证损失防止过拟合。

四、模型评估与测试
1. 测试集预测
matlab 复制代码
YPred = classify(net, imdsTest);
YTest = imdsTest.Labels;
accuracy = sum(YPred == YTest) / numel(YTest);
fprintf('测试集准确率: %.2f%%
', accuracy*100);
2. 混淆矩阵分析
matlab 复制代码
figure;
cm = confusionchart(YTest, YPred);
cm.Title = '混淆矩阵';
cm.ColumnSummary = 'column-normalized'; % 列归一化
cm.RowSummary = 'row-normalized';       % 行归一化

解读:对角线元素为正确分类比例,非对角线反映误分类情况。

3. ROC曲线与AUC
matlab 复制代码
[X, Y, T, AUC] = perfcurve(YTest, scores(:,2), 'cat'); % 假设二分类
figure; plot(X, Y); title(sprintf('ROC曲线 (AUC=%.2f)', AUC));

五、模型部署与应用
1. 新图像分类
matlab 复制代码
newImg = imread('test_image.jpg');
newImgResized = imresize(newImg, [32, 32]); % 调整尺寸
label = classify(net, newImgResized);
imshow(newImg); title(sprintf('预测类别: %s', char(label)));
2. 导出为ONNX格式
matlab 复制代码
net = exportNetworkToONNX(net, 'cnn_model.onnx');

应用场景:部署到移动端或嵌入式设备(需转换框架如TensorRT)。

参考代码 利用MATLAB实现一个基于CNN的图像分类算法 www.youwenfan.com/contentcsq/50981.html

六、完整代码示例
matlab 复制代码
%% 数据准备
dataFolder = 'path/to/dataset';
imds = imageDatastore(dataFolder, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[imdsTrain, imdsTest] = splitEachLabel(imds, 0.8, 'randomized');

% 数据增强
augmenter = imageDataAugmenter('RandRotation', [-15,15], 'RandXReflection', true);
augimdsTrain = augmentedImageDatastore([224,224], imdsTrain, 'DataAugmentation', augmenter);

%% 模型构建(迁移学习)
net = alexnet;
lgraph = layerGraph(net);
newFCLayer = fullyConnectedLayer(10, 'WeightLearnRateFactor', 10);
newOutputLayer = classificationLayer('Name', 'output');
lgraph = replaceLayer(lgraph, 'fc7', newFCLayer);
lgraph = replaceLayer(lgraph, 'ClassificationLayer_fc7', newOutputLayer);

%% 训练配置
options = trainingOptions('adam', ...
    'MaxEpochs', 15, ...
    'MiniBatchSize', 64, ...
    'InitialLearnRate', 0.001, ...
    'ValidationData', imdsTest, ...
    'ValidationFrequency', 30, ...
    'Plots', 'training-progress', ...
    'ExecutionEnvironment', 'gpu');

%% 训练模型
net = trainNetwork(augimdsTrain, lgraph, options);

%% 评估模型
YPred = classify(net, imdsTest);
accuracy = sum(YPred == imdsTest.Labels)/numel(imdsTest.Labels);
fprintf('准确率: %.2f%%
', accuracy*100);

%% 可视化混淆矩阵
figure;
cm = confusionchart(imdsTest.Labels, YPred);
cm.Title = '混淆矩阵';
cm.ColumnSummary = 'column-normalized';

七、扩展应用
  1. 目标检测:结合YOLO或Faster R-CNN实现实时检测。
  2. 风格迁移:使用预训练网络提取特征进行图像风格化。
  3. 多模态融合:结合CNN与RNN处理视频分类任务。

八、常见问题解决
  • 内存不足 :减小MiniBatchSize或使用'ExecutionEnvironment', 'cpu'
  • 过拟合:增加Dropout层、数据增强或使用正则化。
  • 低准确率:检查数据预处理、网络深度或尝试迁移学习。
相关推荐
爱吃泡芙的小白白3 小时前
CNN参数量计算全解析:从基础公式到前沿优化
人工智能·神经网络·cnn·参数量
Faker66363aaa5 小时前
指纹过滤器缺陷检测与分类 —— 基于MS-RCNN_X101-64x4d_FPN_1x_COCO模型的实现与分析_1
人工智能·目标跟踪·分类
Evand J6 小时前
【MATLAB例程】TOA和TDOA混合定位,适用于二维平面的高精度定位。锚点数量、位置、测量噪声可自行调节
开发语言·matlab·定位·tdoa
Loacnasfhia97 小时前
面部表情识别与分类_YOLOv10n与MobileNetV4融合方案详解
yolo·分类·数据挖掘
t198751289 小时前
基于MATLAB的HOG+GLCM特征提取与SVM分类实现
支持向量机·matlab·分类
Loacnasfhia99 小时前
贝类海产品物种识别与分类_---_基于YOLOv10n与特征金字塔共享卷积的改进方法
yolo·分类·数据挖掘
机器学习之心10 小时前
Bayes-TCN+SHAP分析贝叶斯优化深度学习多变量分类预测可解释性分析!Matlab完整代码
深度学习·matlab·分类·贝叶斯优化深度学习
想进部的张同学10 小时前
week1-day5-CNN卷积补充感受野-CUDA 一、CUDA 编程模型基础 1.1 CPU vs GPU 架构线程索引与向量乘法
人工智能·神经网络·cnn
机器学习之心10 小时前
TCN+SHAP分析深度学习多变量分类预测可解释性分析!Matlab完整代码
深度学习·matlab·分类·多变量分类预测可解释性分析