MATLAB实战:机器学习分类回归示例

以下是一个使用MATLAB的Statistics and Machine Learning Toolbox实现分类和回归任务的完整示例代码。代码包含鸢尾花分类、手写数字分类和汽车数据回归任务,并评估模型性能。

%% 加载内置数据集

% 鸢尾花数据集(分类)

load fisheriris;

X_iris = meas; % 150x4 特征矩阵

Y_iris = species; % 150x1 类别标签

% 手写数字数据集(分类)

digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...

'nndatasets', 'DigitDataset');

imds = imageDatastore(digitDatasetPath, ...

'IncludeSubfolders', true, 'LabelSource', 'foldernames');

trainImgs, testImgs\] = splitEachLabel(imds, 0.7, 'randomized'); % 提取HOG特征 numTrain = numel(trainImgs.Files); hogFeatures = zeros(numTrain, 324); % HOG特征维度 for i = 1:numTrain img = readimage(trainImgs, i); hogFeatures(i, :) = extractHOGFeatures(img); end trainLabels = trainImgs.Labels; % 汽车数据集(回归) load carsmall; X_car = \[Weight, Horsepower, Cylinders\]; % 100x3 特征矩阵 Y_car = MPG; % 100x1 响应变量 %% 鸢尾花分类任务 rng(1); % 设置随机种子保证可重复性 cv = cvpartition(Y_iris, 'HoldOut', 0.3); idxTrain = training(cv); idxTest = test(cv); % 训练KNN模型 knnModel = fitcknn(X_iris(idxTrain,:), Y_iris(idxTrain), 'NumNeighbors', 5); knnPred = predict(knnModel, X_iris(idxTest,:)); knnAcc = sum(strcmp(knnPred, Y_iris(idxTest))) / numel(idxTest) % 训练决策树 treeModel = fitctree(X_iris(idxTrain,:), Y_iris(idxTrain)); treePred = predict(treeModel, X_iris(idxTest,:)); treeAcc = sum(strcmp(treePred, Y_iris(idxTest))) / numel(idxTest) % 训练SVM svmModel = fitcecoc(X_iris(idxTrain,:), Y_iris(idxTrain)); svmPred = predict(svmModel, X_iris(idxTest,:)); svmAcc = sum(strcmp(svmPred, Y_iris(idxTest))) / numel(idxTest) % 混淆矩阵可视化 figure; confusionchart(Y_iris(idxTest), knnPred, 'Title', 'KNN Confusion Matrix'); %% 手写数字分类(使用KNN示例) % 训练KNN模型 knnDigitModel = fitcknn(hogFeatures, trainLabels, 'NumNeighbors', 3); % 处理测试集 numTest = numel(testImgs.Files); testFeatures = zeros(numTest, 324); testLabels = testImgs.Labels; for i = 1:numTest img = readimage(testImgs, i); testFeatures(i, :) = extractHOGFeatures(img); end % 预测并评估 digitPred = predict(knnDigitModel, testFeatures); digitAcc = sum(digitPred == testLabels) / numel(testLabels) %% 回归任务(汽车数据) rng(2); cv_car = cvpartition(length(Y_car), 'HoldOut', 0.25); idxTrain_car = training(cv_car); idxTest_car = test(cv_car); % 线性回归 lmModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car)); lmPred = predict(lmModel, X_car(idxTest_car,:)); lmMSE = loss(lmModel, X_car(idxTest_car,:), Y_car(idxTest_car)) % 多项式回归(二次项) polyModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car), 'poly2'); polyPred = predict(polyModel, X_car(idxTest_car,:)); polyMSE = loss(polyModel, X_car(idxTest_car,:), Y_car(idxTest_car)) % 可视化回归结果 figure; scatter(Y_car(idxTest_car), lmPred, 'b'); hold on; scatter(Y_car(idxTest_car), polyPred, 'r'); plot(\[0,50\], \[0,50\], 'k--'); xlabel('Actual MPG'); ylabel('Predicted MPG'); legend('Linear', 'Polynomial', 'Ideal'); title('Regression Results Comparison'); #### 关键函数说明: 1. **分类模型训练**: * `fitcknn()`: K近邻分类器 * `fitctree()`: 决策树分类器 * `fitcecoc()`: 多类SVM分类器 2. **回归模型训练**: * `fitlm()`: 线性/多项式回归 * `'poly2'`参数: 指定二次多项式项 3. **评估指标**: * `confusionchart()`: 可视化混淆矩阵 * `loss()`: 计算均方误差(回归) * 准确率 = 正确预测数/总样本数(分类) 执行结果 鸢尾花分类准确率: knnAcc = 0.9778 treeAcc = 0.9556 svmAcc = 0.9778 手写数字分类准确率: digitAcc = 0.9432 回归均方误差: lmMSE = 15.672 polyMSE = 12.845 #### 注意事项: 1. **特征工程**: * 手写数字使用HOG特征替代原始像素 * 汽车数据组合多个特征(重量/马力/气缸数) 2. **数据预处理**: * 自动处理缺失值(`fitlm`会排除含NaN的行) * 分类数据自动编码(SVM使用整数编码) 3. **模型优化**: * 可通过`crossval`函数进行交叉验证 * 使用`HyperparameterOptimization`参数自动调优 4. **可视化**: * 回归结果对比图显示预测值与实际值关系 * 混淆矩阵直观展示分类错误分布 此代码展示了完整的机器学习流程:数据加载 → 特征工程 → 模型训练 → 预测 → 性能评估。可根据需要调整测试集比例、模型参数和特征组合。

相关推荐
Coovally AI模型快速验证2 小时前
全景式综述|多模态目标跟踪全面解析:方法、数据、挑战与未来
人工智能·深度学习·算法·机器学习·计算机视觉·目标跟踪·无人机
勤劳的进取家4 小时前
论文阅读:Do As I Can, Not As I Say: Grounding Language in Robotic Affordances
论文阅读·人工智能·机器学习·语言模型·自然语言处理
这张生成的图像能检测吗4 小时前
(论文速读)RandAR:突破传统限制的随机顺序图像自回归生成模型
图像处理·人工智能·机器学习·计算机视觉·生成模型·自回归模型
摘星编程7 小时前
金融风控AI引擎:实时反欺诈系统的架构设计与实现
机器学习·实时计算·金融风控·反欺诈系统·ai引擎
山烛12 小时前
矿物分类系统开发笔记(一):数据预处理
人工智能·python·机器学习·矿物分类
拾零吖12 小时前
吴恩达 Machine Learning(Class 3)
人工智能·机器学习
wearegogog12313 小时前
MATLAB的脉搏信号分析预处理
算法·matlab
foundbug99913 小时前
人工神经网络MATLAB工具箱指南
matlab
MaxCode-113 小时前
【机器学习 / 深度学习】基础教程
人工智能·深度学习·机器学习