MATLAB中进行深度学习网络训练的模型评估步骤

文章目录


前言

在 MATLAB 中进行深度学习网络训练后的模型评估是确保模型性能和可靠性的关键环节。以下是详细的评估步骤和方法。


环境配置

MATLAB下载安装教程:https://blog.csdn.net/tyatyatya/article/details/147879353
MATLAB下载地址链接:https://pan.quark.cn/s/364584a880f7

一、基础性能评估

  1. 分类准确率(Classification Accuracy)
c 复制代码
% 在测试集上进行预测
YPred = classify(net, imdsTest);  % 对图像数据
YPred = predict(net, XTest);      % 对数值数据

% 计算准确率
YTest = imdsTest.Labels;
accuracy = mean(YPred == YTest);
fprintf('测试集准确率: %.2f%%\n', accuracy*100);
  1. 混淆矩阵(Confusion Matrix)
c 复制代码
% 生成混淆矩阵
cm = confusionmat(YTest, YPred);

% 可视化混淆矩阵
figure
cmplot = confusionchart(cm, categories(YTest));
cmplot.Title = '混淆矩阵';
cmplot.RowSummary = 'row-normalized';  % 显示行归一化(召回率)
cmplot.ColumnSummary = 'column-normalized';  % 显示列归一化(精确率)
  1. 精确率、召回率与 F1 分数
c 复制代码
% 计算每个类别的精确率、召回率和F1分数
categories = unique(YTest);
metrics = table(categories, zeros(length(categories), 3), 'VariableNames', {'Category', 'Precision', 'Recall', 'F1Score'});

for i = 1:length(categories)
    truePositives = cm(i,i);
    falsePositives = sum(cm(:,i)) - truePositives;
    falseNegatives = sum(cm(i,:)) - truePositives;
    
    metrics.Precision(i) = truePositives / (truePositives + falsePositives);
    metrics.Recall(i) = truePositives / (truePositives + falseNegatives);
    metrics.F1Score(i) = 2 * (metrics.Precision(i) * metrics.Recall(i)) / (metrics.Precision(i) + metrics.Recall(i));
end

% 计算宏平均和微平均
macroPrecision = mean(metrics.Precision);
macroRecall = mean(metrics.Recall);
macroF1 = mean(metrics.F1Score);

microPrecision = sum(diag(cm)) / sum(sum(cm));
microRecall = microPrecision;  % 微平均精确率和召回率相等
microF1 = 2 * (microPrecision * microRecall) / (microPrecision + microRecall);

fprintf('宏平均 F1 分数: %.4f\n', macroF1);
fprintf('微平均 F1 分数: %.4f\n', microF1);

二、高级评估指标

  1. ROC 曲线与 AUC 值(二分类问题)
c 复制代码
% 获取预测概率
[YPred, scores] = classify(net, imdsTest, 'OutputAs', 'probabilities');

% 计算ROC曲线
figure
for i = 1:numel(categories)
    [x, y, t, auc] = perfcurve(YTest, scores(:,i), categories(i));
    plot(x, y, 'DisplayName', [categories(i), ': AUC = ', num2str(auc, '%.3f')])
end
title('ROC曲线')
xlabel('假阳性率 (FPR)')
ylabel('真阳性率 (TPR)')
legend
grid on
  1. 损失函数曲线分析
c 复制代码
% 绘制训练过程中的损失函数曲线
figure
plot(tr.TrainingLoss, 'b-', 'LineWidth', 2)
hold on
plot(tr.ValidationLoss, 'r-', 'LineWidth', 2)
title('训练与验证损失')
xlabel('训练轮次 (Epoch)')
ylabel('损失值')
legend('训练损失', '验证损失')
grid on
  1. 学习率调整分析
c 复制代码
% 绘制学习率随训练轮次的变化
figure
plot(tr.LearnRate, 'LineWidth', 2)
title('学习率调整')
xlabel('训练轮次 (Epoch)')
ylabel('学习率')
grid on

三、模型解释与可视化

  1. 类激活映射(Class Activation Mapping, CAM)
c 复制代码
% 计算并可视化类激活映射
I = imread('test_image.jpg');
[YPred, scores] = classify(net, I, 'OutputAs', 'probabilities');
cam = activation(net, I, 'last_conv_layer', 'OutputAs', 'image');  % 替换为实际最后卷积层名称

figure
subplot(1,2,1)
imshow(I)
title('原始图像')

subplot(1,2,2)
imshow(I)
hold on
h = imagesc(cam, 'AlphaData', cam);
colormap jet
axis off
title(['预测: ', string(YPred), ', 置信度: ', num2str(max(scores), '%.2f')])
colorbar
  1. 特征可视化
c 复制代码
% 可视化中间层特征
I = imread('test_image.jpg');
features = activation(net, I, 'conv2_1');  % 替换为实际层名称

% 可视化前16个特征图
figure
for i = 1:min(16, size(features, 3))
    subplot(4, 4, i)
    imshow(features(:,:,i), 'DisplayRange', [])
    title(['特征图 ', num2str(i)])
end
  1. 决策边界分析(二维数据)
c 复制代码
% 生成网格点
[x1Grid, x2Grid] = meshgrid(linspace(min(XTest(:,1)), max(XTest(:,1)), 100), ...
                            linspace(min(XTest(:,2)), max(XTest(:,2)), 100));
gridPoints = [x1Grid(:), x2Grid(:)];

% 预测网格点
YPredGrid = classify(net, gridPoints);

% 可视化决策边界
figure
gscatter(XTest(:,1), XTest(:,2), YTest)
hold on
contourf(x1Grid, x2Grid, reshape(YPredGrid, size(x1Grid)), 'Alpha', 0.3)
title('决策边界可视化')
legend('类别1', '类别2', '决策边界')

四、交叉验证与模型选择

  1. K 折交叉验证
c 复制代码
% 设置K折交叉验证
k = 5;
cv = cvpartition(height(tbl), 'KFold', k);

% 存储每折的准确率
accuracies = zeros(k, 1);

% 执行交叉验证
for i = 1:k
    idxTrain = training(cv, i);
    idxTest = test(cv, i);
    
    % 训练模型
    net = trainNetwork(imds(idxTrain), layers, options);
    
    % 评估模型
    YPred = classify(net, imds(idxTest));
    accuracies(i) = mean(YPred == imds.Labels(idxTest));
end

% 计算平均准确率和标准差
meanAccuracy = mean(accuracies);
stdAccuracy = std(accuracies);
fprintf('交叉验证准确率: %.2f%% ± %.2f%%\n', meanAccuracy*100, stdAccuracy*100);
  1. 模型比较
c 复制代码
% 比较不同模型架构
models = {'resnet18', 'resnet50', 'alexnet'};
results = table(models, zeros(length(models), 1), 'VariableNames', {'Model', 'Accuracy'});

for i = 1:length(models)
    % 加载预训练模型
    net = eval(models{i});
    
    % 修改网络结构
    % ... [省略网络修改代码] ...
    
    % 训练模型
    trainedNet = trainNetwork(imdsTrain, lgraph, options);
    
    % 评估模型
    YPred = classify(trainedNet, imdsTest);
    results.Accuracy(i) = mean(YPred == YTest);
end

% 显示比较结果
results = sortrows(results, 'Accuracy', 'descend');
disp(results);

五、部署前的优化

  1. 模型量化
c 复制代码
% 量化模型以减小尺寸和加速推理
quantizedNet = quantizeNetwork(net, 'WeightPrecision', 8, 'ActivationPrecision', 8);

% 评估量化模型
YPredQuantized = classify(quantizedNet, imdsTest);
accuracyQuantized = mean(YPredQuantized == YTest);
fprintf('量化模型准确率: %.2f%%\n', accuracyQuantized*100);
  1. 剪枝(Pruning)
c 复制代码
% 对模型进行剪枝
prunedNet = pruneNetwork(net, 'Percentage', 50);  % 剪枝50%的连接

% 微调剪枝后的模型
optionsFineTune = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.0001, ...
    'MaxEpochs', 3);
prunedNet = trainNetwork(imdsTrain, prunedNet, optionsFineTune);

% 评估剪枝模型
YPredPruned = classify(prunedNet, imdsTest);
accuracyPruned = mean(YPredPruned == YTest);
fprintf('剪枝模型准确率: %.2f%%\n', accuracyPruned*100);
相关推荐
Evand J14 分钟前
【MATLAB例程】线性卡尔曼滤波的程序,三维状态量和观测量,较为简单,可用于理解多维KF,附代码下载链接
开发语言·matlab
有Li33 分钟前
联合建模组织学和分子标记用于癌症分类|文献速递-深度学习医疗AI最新文献
人工智能·深度学习·分类
乌旭44 分钟前
开源GPU架构RISC-V VCIX的深度学习潜力测试:从RTL仿真到MNIST实战
人工智能·深度学习·stable diffusion·架构·aigc·midjourney·risc-v
立秋67892 小时前
从零开始:使用 PyTorch 构建深度学习网络
人工智能·pytorch·深度学习
21级的乐未央2 小时前
论文阅读(四):Agglomerative Transformer for Human-Object Interaction Detection
论文阅读·深度学习·计算机视觉·transformer
xiaobin889992 小时前
matlab官方免费下载安装超详细教程2025最新matlab安装教程(MATLAB R2024b)
java·开发语言·其他·matlab
Blossom.1183 小时前
基于区块链技术的供应链溯源系统:重塑信任与透明度
服务器·网络·人工智能·目标检测·机器学习·计算机视觉·区块链
冷崖3 小时前
网络编程-select(二)
网络·学习
埃菲尔铁塔_CV算法3 小时前
深度学习驱动下的目标检测技术:原理、算法与应用创新(二)
深度学习·算法·目标检测
KangkangLoveNLP4 小时前
Llama:开源的急先锋
人工智能·深度学习·神经网络·算法·机器学习·自然语言处理·llama