MATLAB实战:机器学习分类回归示例

以下是一个使用MATLAB的Statistics and Machine Learning Toolbox实现分类和回归任务的完整示例代码。代码包含鸢尾花分类、手写数字分类和汽车数据回归任务,并评估模型性能。

%% 加载内置数据集

% 鸢尾花数据集(分类)

load fisheriris;

X_iris = meas; % 150x4 特征矩阵

Y_iris = species; % 150x1 类别标签

% 手写数字数据集(分类)

digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...

'nndatasets', 'DigitDataset');

imds = imageDatastore(digitDatasetPath, ...

'IncludeSubfolders', true, 'LabelSource', 'foldernames');

trainImgs, testImgs\] = splitEachLabel(imds, 0.7, 'randomized'); % 提取HOG特征 numTrain = numel(trainImgs.Files); hogFeatures = zeros(numTrain, 324); % HOG特征维度 for i = 1:numTrain img = readimage(trainImgs, i); hogFeatures(i, :) = extractHOGFeatures(img); end trainLabels = trainImgs.Labels; % 汽车数据集(回归) load carsmall; X_car = \[Weight, Horsepower, Cylinders\]; % 100x3 特征矩阵 Y_car = MPG; % 100x1 响应变量 %% 鸢尾花分类任务 rng(1); % 设置随机种子保证可重复性 cv = cvpartition(Y_iris, 'HoldOut', 0.3); idxTrain = training(cv); idxTest = test(cv); % 训练KNN模型 knnModel = fitcknn(X_iris(idxTrain,:), Y_iris(idxTrain), 'NumNeighbors', 5); knnPred = predict(knnModel, X_iris(idxTest,:)); knnAcc = sum(strcmp(knnPred, Y_iris(idxTest))) / numel(idxTest) % 训练决策树 treeModel = fitctree(X_iris(idxTrain,:), Y_iris(idxTrain)); treePred = predict(treeModel, X_iris(idxTest,:)); treeAcc = sum(strcmp(treePred, Y_iris(idxTest))) / numel(idxTest) % 训练SVM svmModel = fitcecoc(X_iris(idxTrain,:), Y_iris(idxTrain)); svmPred = predict(svmModel, X_iris(idxTest,:)); svmAcc = sum(strcmp(svmPred, Y_iris(idxTest))) / numel(idxTest) % 混淆矩阵可视化 figure; confusionchart(Y_iris(idxTest), knnPred, 'Title', 'KNN Confusion Matrix'); %% 手写数字分类(使用KNN示例) % 训练KNN模型 knnDigitModel = fitcknn(hogFeatures, trainLabels, 'NumNeighbors', 3); % 处理测试集 numTest = numel(testImgs.Files); testFeatures = zeros(numTest, 324); testLabels = testImgs.Labels; for i = 1:numTest img = readimage(testImgs, i); testFeatures(i, :) = extractHOGFeatures(img); end % 预测并评估 digitPred = predict(knnDigitModel, testFeatures); digitAcc = sum(digitPred == testLabels) / numel(testLabels) %% 回归任务(汽车数据) rng(2); cv_car = cvpartition(length(Y_car), 'HoldOut', 0.25); idxTrain_car = training(cv_car); idxTest_car = test(cv_car); % 线性回归 lmModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car)); lmPred = predict(lmModel, X_car(idxTest_car,:)); lmMSE = loss(lmModel, X_car(idxTest_car,:), Y_car(idxTest_car)) % 多项式回归(二次项) polyModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car), 'poly2'); polyPred = predict(polyModel, X_car(idxTest_car,:)); polyMSE = loss(polyModel, X_car(idxTest_car,:), Y_car(idxTest_car)) % 可视化回归结果 figure; scatter(Y_car(idxTest_car), lmPred, 'b'); hold on; scatter(Y_car(idxTest_car), polyPred, 'r'); plot(\[0,50\], \[0,50\], 'k--'); xlabel('Actual MPG'); ylabel('Predicted MPG'); legend('Linear', 'Polynomial', 'Ideal'); title('Regression Results Comparison'); #### 关键函数说明: 1. **分类模型训练**: * `fitcknn()`: K近邻分类器 * `fitctree()`: 决策树分类器 * `fitcecoc()`: 多类SVM分类器 2. **回归模型训练**: * `fitlm()`: 线性/多项式回归 * `'poly2'`参数: 指定二次多项式项 3. **评估指标**: * `confusionchart()`: 可视化混淆矩阵 * `loss()`: 计算均方误差(回归) * 准确率 = 正确预测数/总样本数(分类) 执行结果 鸢尾花分类准确率: knnAcc = 0.9778 treeAcc = 0.9556 svmAcc = 0.9778 手写数字分类准确率: digitAcc = 0.9432 回归均方误差: lmMSE = 15.672 polyMSE = 12.845 #### 注意事项: 1. **特征工程**: * 手写数字使用HOG特征替代原始像素 * 汽车数据组合多个特征(重量/马力/气缸数) 2. **数据预处理**: * 自动处理缺失值(`fitlm`会排除含NaN的行) * 分类数据自动编码(SVM使用整数编码) 3. **模型优化**: * 可通过`crossval`函数进行交叉验证 * 使用`HyperparameterOptimization`参数自动调优 4. **可视化**: * 回归结果对比图显示预测值与实际值关系 * 混淆矩阵直观展示分类错误分布 此代码展示了完整的机器学习流程:数据加载 → 特征工程 → 模型训练 → 预测 → 性能评估。可根据需要调整测试集比例、模型参数和特征组合。

相关推荐
小黎14757789853641 天前
OpenClaw 连接飞书完整指南:插件安装、配置与踩坑记录
机器学习
哥布林学者1 天前
高光谱成像(二)光谱角映射 SAM
机器学习·高光谱成像
哥布林学者2 天前
高光谱成像(一)高光谱图像
机器学习·高光谱成像
罗西的思考2 天前
AI Agent框架探秘:拆解 OpenHands(10)--- Runtime
人工智能·算法·机器学习
HXhlx2 天前
CART决策树基本原理
算法·机器学习
OpenBayes贝式计算5 天前
解决视频模型痛点,TurboDiffusion 高效视频扩散生成系统;Google Streetview 涵盖多个国家的街景图像数据集
人工智能·深度学习·机器学习
OpenBayes贝式计算5 天前
OCR教程汇总丨DeepSeek/百度飞桨/华中科大等开源创新技术,实现OCR高精度、本地化部署
人工智能·深度学习·机器学习
够快云库6 天前
能源行业非结构化数据治理实战:从数据沼泽到智能资产
大数据·人工智能·机器学习·企业文件安全
feifeigo1236 天前
matlab画图工具
开发语言·matlab