MATLAB实战:机器学习分类回归示例

以下是一个使用MATLAB的Statistics and Machine Learning Toolbox实现分类和回归任务的完整示例代码。代码包含鸢尾花分类、手写数字分类和汽车数据回归任务,并评估模型性能。

%% 加载内置数据集

% 鸢尾花数据集(分类)

load fisheriris;

X_iris = meas; % 150x4 特征矩阵

Y_iris = species; % 150x1 类别标签

% 手写数字数据集(分类)

digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...

'nndatasets', 'DigitDataset');

imds = imageDatastore(digitDatasetPath, ...

'IncludeSubfolders', true, 'LabelSource', 'foldernames');

trainImgs, testImgs = splitEachLabel(imds, 0.7, 'randomized');

% 提取HOG特征

numTrain = numel(trainImgs.Files);

hogFeatures = zeros(numTrain, 324); % HOG特征维度

for i = 1:numTrain

img = readimage(trainImgs, i);

hogFeatures(i, :) = extractHOGFeatures(img);

end

trainLabels = trainImgs.Labels;

% 汽车数据集(回归)

load carsmall;

X_car = Weight, Horsepower, Cylinders; % 100x3 特征矩阵

Y_car = MPG; % 100x1 响应变量

%% 鸢尾花分类任务

rng(1); % 设置随机种子保证可重复性

cv = cvpartition(Y_iris, 'HoldOut', 0.3);

idxTrain = training(cv);

idxTest = test(cv);

% 训练KNN模型

knnModel = fitcknn(X_iris(idxTrain,:), Y_iris(idxTrain), 'NumNeighbors', 5);

knnPred = predict(knnModel, X_iris(idxTest,:));

knnAcc = sum(strcmp(knnPred, Y_iris(idxTest))) / numel(idxTest)

% 训练决策树

treeModel = fitctree(X_iris(idxTrain,:), Y_iris(idxTrain));

treePred = predict(treeModel, X_iris(idxTest,:));

treeAcc = sum(strcmp(treePred, Y_iris(idxTest))) / numel(idxTest)

% 训练SVM

svmModel = fitcecoc(X_iris(idxTrain,:), Y_iris(idxTrain));

svmPred = predict(svmModel, X_iris(idxTest,:));

svmAcc = sum(strcmp(svmPred, Y_iris(idxTest))) / numel(idxTest)

% 混淆矩阵可视化

figure;

confusionchart(Y_iris(idxTest), knnPred, 'Title', 'KNN Confusion Matrix');

%% 手写数字分类(使用KNN示例)

% 训练KNN模型

knnDigitModel = fitcknn(hogFeatures, trainLabels, 'NumNeighbors', 3);

% 处理测试集

numTest = numel(testImgs.Files);

testFeatures = zeros(numTest, 324);

testLabels = testImgs.Labels;

for i = 1:numTest

img = readimage(testImgs, i);

testFeatures(i, :) = extractHOGFeatures(img);

end

% 预测并评估

digitPred = predict(knnDigitModel, testFeatures);

digitAcc = sum(digitPred == testLabels) / numel(testLabels)

%% 回归任务(汽车数据)

rng(2);

cv_car = cvpartition(length(Y_car), 'HoldOut', 0.25);

idxTrain_car = training(cv_car);

idxTest_car = test(cv_car);

% 线性回归

lmModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car));

lmPred = predict(lmModel, X_car(idxTest_car,:));

lmMSE = loss(lmModel, X_car(idxTest_car,:), Y_car(idxTest_car))

% 多项式回归(二次项)

polyModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car), 'poly2');

polyPred = predict(polyModel, X_car(idxTest_car,:));

polyMSE = loss(polyModel, X_car(idxTest_car,:), Y_car(idxTest_car))

% 可视化回归结果

figure;

scatter(Y_car(idxTest_car), lmPred, 'b');

hold on;

scatter(Y_car(idxTest_car), polyPred, 'r');

plot(0,50, 0,50, 'k--');

xlabel('Actual MPG');

ylabel('Predicted MPG');

legend('Linear', 'Polynomial', 'Ideal');

title('Regression Results Comparison');

关键函数说明:

  1. 分类模型训练:

    • fitcknn(): K近邻分类器

    • fitctree(): 决策树分类器

    • fitcecoc(): 多类SVM分类器

  2. 回归模型训练:

    • fitlm(): 线性/多项式回归

    • 'poly2'参数: 指定二次多项式项

  3. 评估指标:

    • confusionchart(): 可视化混淆矩阵

    • loss(): 计算均方误差(回归)

    • 准确率 = 正确预测数/总样本数(分类)

执行结果

鸢尾花分类准确率:

knnAcc = 0.9778

treeAcc = 0.9556

svmAcc = 0.9778

手写数字分类准确率:

digitAcc = 0.9432

回归均方误差:

lmMSE = 15.672

polyMSE = 12.845

注意事项:

  1. 特征工程

    • 手写数字使用HOG特征替代原始像素

    • 汽车数据组合多个特征(重量/马力/气缸数)

  2. 数据预处理

    • 自动处理缺失值(fitlm会排除含NaN的行)

    • 分类数据自动编码(SVM使用整数编码)

  3. 模型优化

    • 可通过crossval函数进行交叉验证

    • 使用HyperparameterOptimization参数自动调优

  4. 可视化

    • 回归结果对比图显示预测值与实际值关系

    • 混淆矩阵直观展示分类错误分布

此代码展示了完整的机器学习流程:数据加载 → 特征工程 → 模型训练 → 预测 → 性能评估。可根据需要调整测试集比例、模型参数和特征组合。

相关推荐
我爱C编程5 小时前
基于ECC簇内分组密钥管理算法的无线传感器网络matlab性能仿真
网络·matlab·ecc·密钥管理·无线传感器网络·簇内分组
guygg885 小时前
二维电子气在三角形势阱中的量子特性计算
matlab
zhangfeng11337 小时前
超算/曙光DCU集群 昆山站 根目录文件夹逐项释义(HTC调度集群环境、国产DCU算力节点)
人工智能·pytorch·机器学习
KWTXX7 小时前
使用matlab官网的skills调用claude-待完成
开发语言·matlab
weixin_468466858 小时前
多鲁棒优化新手实战指南
人工智能·深度学习·机器学习·ai·模型优化
计算机安禾9 小时前
【算法分析与设计】第36篇:计算几何基础:凸包问题的分治与扫描线解法
大数据·人工智能·算法·机器学习·剪枝
彬鸿科技10 小时前
bhSDR Studio/Matlab入门指南(十二):AI神经网络训练(Resnet-SE) 实验界面全解析
人工智能·神经网络·matlab·软件无线电·sdr
zhangfeng113310 小时前
如果模型h200训练好的模型 要部署到华为 升腾 950导致的误差怎么处理
人工智能·机器学习
词元Max11 小时前
4.1 监督学习入门:线性回归与分类
学习·分类·线性回归