基于MATLAB的CNN图像分类算法实现

基于MATLAB深度学习工具箱(Deep Learning Toolbox)的卷积神经网络(CNN)图像分类完整实现流程


一、数据准备与预处理
1. 数据集加载

使用MATLAB内置的imageDatastore函数加载图像数据集,支持自定义路径或直接调用公开数据集(如CIFAR-10、MNIST):

matlab 复制代码
% 加载CIFAR-10数据集(示例)
[XTrain, YTrain] = digitTrain4DArrayData; % 内置手写数字数据集
[XTest, YTest] = digitTest4DArrayData;

% 或自定义数据集(需按文件夹分类)
dataFolder = 'path/to/custom_dataset';
imds = imageDatastore(dataFolder, ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

关键点

  • 图像需按类别存放在子文件夹中,文件夹名即为标签。
  • 对于灰度图像,需转换为RGB格式(如cat(3, X, X, X))。
2. 数据预处理
  • 归一化:将像素值缩放到[0,1]或[-1,1]范围。
  • 数据增强:通过旋转、平移等操作增加数据多样性,防止过拟合:
matlab 复制代码
augmenter = imageDataAugmenter(...
    'RandRotation', [-20, 20], ...  % 随机旋转±20度
    'RandXReflection', true, ...    % 随机水平翻转
    'RandYReflection', true);       % 随机垂直翻转

augimdsTrain = augmentedImageDatastore([32, 32], imdsTrain, 'DataAugmentation', augmenter);

关键点:输入尺寸需与网络输入层匹配(如AlexNet需227×227)。


二、CNN模型构建
1. 自定义网络架构

通过layerGraph定义网络结构,典型CNN包含卷积层、池化层、全连接层:

matlab 复制代码
layers = [
    imageInputLayer([32 32 3]) % 输入层(高度×宽度×通道)
    
    % 卷积块1
    convolution2dLayer(3, 16, 'Padding', 'same') % 3×3卷积核,输出16通道
    batchNormalizationLayer % 批归一化
    reluLayer % ReLU激活
    maxPooling2dLayer(2, 'Stride', 2) % 2×2最大池化
    
    % 卷积块2
    convolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    
    % 全连接层
    fullyConnectedLayer(10) % 输出层(类别数)
    softmaxLayer % Softmax概率转换
    classificationLayer]; % 分类输出

优化建议

  • 添加Dropout层(如dropoutLayer(0.5))进一步防止过拟合。
  • 使用残差连接(ResNet结构)提升深层网络性能。
2. 迁移学习(推荐)

对于小数据集,直接使用预训练模型(如ResNet-50)进行微调:

matlab 复制代码
net = resnet50; % 加载预训练模型
lgraph = layerGraph(net);

% 替换最后的全连接层和输出层
newFCLayer = fullyConnectedLayer(numClasses, 'WeightLearnRateFactor', 10);
newOutputLayer = classificationLayer('Name', 'new_classoutput');
lgraph = replaceLayer(lgraph, 'fc1000', newFCLayer);
lgraph = replaceLayer(lgraph, 'ClassificationLayer_fc1000', newOutputLayer);

优势:迁移学习可显著减少训练时间并提升准确率。


三、模型训练与优化
1. 训练选项配置

通过trainingOptions设置优化器、学习率、早停等参数:

matlab 复制代码
options = trainingOptions('adam', ...  % 优化器(sgdm, rmsprop)
    'MaxEpochs', 20, ...
    'MiniBatchSize', 64, ...
    'InitialLearnRate', 0.001, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', imdsValidation, ...  % 验证集
    'ValidationFrequency', 30, ...
    'Verbose', false, ...
    'Plots', 'training-progress', ...  % 实时显示训练曲线
    'ExecutionEnvironment', 'auto');  % 自动选择CPU/GPU加速

关键参数

  • LearnRateSchedule: 设置学习率衰减策略(如'piecewise')。
  • L2Regularization: 添加L2正则化(默认0.0005)。
2. 模型训练

调用trainNetwork启动训练:

matlab 复制代码
net = trainNetwork(augimdsTrain, lgraph, options);

训练监控

  • 实时显示损失(Loss)和准确率(Accuracy)曲线。
  • 若启用验证集,可观察验证损失防止过拟合。

四、模型评估与测试
1. 测试集预测
matlab 复制代码
YPred = classify(net, imdsTest);
YTest = imdsTest.Labels;
accuracy = sum(YPred == YTest) / numel(YTest);
fprintf('测试集准确率: %.2f%%
', accuracy*100);
2. 混淆矩阵分析
matlab 复制代码
figure;
cm = confusionchart(YTest, YPred);
cm.Title = '混淆矩阵';
cm.ColumnSummary = 'column-normalized'; % 列归一化
cm.RowSummary = 'row-normalized';       % 行归一化

解读:对角线元素为正确分类比例,非对角线反映误分类情况。

3. ROC曲线与AUC
matlab 复制代码
[X, Y, T, AUC] = perfcurve(YTest, scores(:,2), 'cat'); % 假设二分类
figure; plot(X, Y); title(sprintf('ROC曲线 (AUC=%.2f)', AUC));

五、模型部署与应用
1. 新图像分类
matlab 复制代码
newImg = imread('test_image.jpg');
newImgResized = imresize(newImg, [32, 32]); % 调整尺寸
label = classify(net, newImgResized);
imshow(newImg); title(sprintf('预测类别: %s', char(label)));
2. 导出为ONNX格式
matlab 复制代码
net = exportNetworkToONNX(net, 'cnn_model.onnx');

应用场景:部署到移动端或嵌入式设备(需转换框架如TensorRT)。

参考代码 利用MATLAB实现一个基于CNN的图像分类算法 www.youwenfan.com/contentcsq/50981.html

六、完整代码示例
matlab 复制代码
%% 数据准备
dataFolder = 'path/to/dataset';
imds = imageDatastore(dataFolder, 'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[imdsTrain, imdsTest] = splitEachLabel(imds, 0.8, 'randomized');

% 数据增强
augmenter = imageDataAugmenter('RandRotation', [-15,15], 'RandXReflection', true);
augimdsTrain = augmentedImageDatastore([224,224], imdsTrain, 'DataAugmentation', augmenter);

%% 模型构建(迁移学习)
net = alexnet;
lgraph = layerGraph(net);
newFCLayer = fullyConnectedLayer(10, 'WeightLearnRateFactor', 10);
newOutputLayer = classificationLayer('Name', 'output');
lgraph = replaceLayer(lgraph, 'fc7', newFCLayer);
lgraph = replaceLayer(lgraph, 'ClassificationLayer_fc7', newOutputLayer);

%% 训练配置
options = trainingOptions('adam', ...
    'MaxEpochs', 15, ...
    'MiniBatchSize', 64, ...
    'InitialLearnRate', 0.001, ...
    'ValidationData', imdsTest, ...
    'ValidationFrequency', 30, ...
    'Plots', 'training-progress', ...
    'ExecutionEnvironment', 'gpu');

%% 训练模型
net = trainNetwork(augimdsTrain, lgraph, options);

%% 评估模型
YPred = classify(net, imdsTest);
accuracy = sum(YPred == imdsTest.Labels)/numel(imdsTest.Labels);
fprintf('准确率: %.2f%%
', accuracy*100);

%% 可视化混淆矩阵
figure;
cm = confusionchart(imdsTest.Labels, YPred);
cm.Title = '混淆矩阵';
cm.ColumnSummary = 'column-normalized';

七、扩展应用
  1. 目标检测:结合YOLO或Faster R-CNN实现实时检测。
  2. 风格迁移:使用预训练网络提取特征进行图像风格化。
  3. 多模态融合:结合CNN与RNN处理视频分类任务。

八、常见问题解决
  • 内存不足 :减小MiniBatchSize或使用'ExecutionEnvironment', 'cpu'
  • 过拟合:增加Dropout层、数据增强或使用正则化。
  • 低准确率:检查数据预处理、网络深度或尝试迁移学习。
相关推荐
foundbug9991 天前
无人机离散系统模型预测控制(MPC)MATLAB实现
开发语言·matlab·无人机
SomeB1oody1 天前
【Python深度学习】1.2. 多层感知器MLP(人工神经网络)实现非线性分类理论
开发语言·人工智能·python·深度学习·机器学习·分类
AI浩1 天前
PaveSync:用于路面病害分析与分类的统一综合数据集
人工智能·机器学习·分类·数据挖掘
代码改善世界2 天前
【matlab初阶】matlab入门知识
android·java·matlab
工业机器视觉设计和实现2 天前
自己的初心,在bpnet基础上自研cnn
人工智能·神经网络·cnn
youcans_2 天前
【FOC-MBD】(20)矢量空间脉宽调制 (SVPWM)输出
stm32·单片机·嵌入式硬件·matlab·代码生成
Three~stone2 天前
MATLAB vs Python 两者区别和安装教程
开发语言·python·matlab
Dev7z2 天前
基于MATLAB与SVM实现河道水面漂浮物的自动检测与识别
人工智能·支持向量机·matlab
MoRanzhi12032 天前
scikit-learn 决策树分类详解:从原理、可视化到剪枝实战掌握 DecisionTreeClassifier
python·决策树·机器学习·数学建模·分类·scikit-learn·剪枝
再一次等风来2 天前
近场声全息(NAH)仿真实现:从阵列实值信号到波数域重建
算法·matlab·信号处理·近场声全息·nah