MATLAB实现,包含数据预处理、CNN模型构建、训练、评估和实时推理等功能,专门用于道路特征提取任务。
一、程序结构
MATLAB_CNN_Road_Feature_Extraction/
├── main.m # 主程序入口
├── data_preprocessing.m # 数据预处理
├── cnn_model_construction.m # CNN模型构建
├── training_pipeline.m # 训练流程
├── evaluation_metrics.m # 评估指标
├── realtime_inference.m # 实时推理
├── visualization.m # 可视化工具
├── dataset/ # 数据集
│ ├── images/ # 道路图像
│ ├── labels/ # 分割标签
│ └── annotations/ # 标注文件
└── models/ # 训练好的模型
二、核心代码实现
2.1 主程序 (main.m)
matlab
%% CNN道路特征提取主程序
% 作者:AI Assistant
% 日期:2024年
% 功能:使用CNN进行道路特征提取(车道线、边界、标志等)
clear; clc; close all;
% 设置随机种子确保可重复性
rng(42);
%% 1. 配置参数
config = struct();
config.dataDir = './dataset';
config.imageSize = [256, 256, 3]; % 输入图像尺寸
config.batchSize = 16;
config.learningRate = 0.001;
config.maxEpochs = 50;
config.validationSplit = 0.2;
config.modelType = 'unet'; % 'unet', 'segnet', 'deeplabv3'
config.device = 'gpu'; % 'gpu', 'cpu'
fprintf('=== CNN道路特征提取系统 ===\n');
fprintf('配置参数:\n');
fprintf(' 图像尺寸: %d x %d x %d\n', config.imageSize(1), config.imageSize(2), config.imageSize(3));
fprintf(' 批次大小: %d\n', config.batchSize);
fprintf(' 训练轮数: %d\n', config.maxEpochs);
fprintf(' 模型类型: %s\n', config.modelType);
%% 2. 数据预处理
fprintf('\n1. 数据预处理...\n');
[dataStore, labelStore] = data_preprocessing(config);
%% 3. 构建CNN模型
fprintf('2. 构建CNN模型...\n');
model = cnn_model_construction(config);
%% 4. 训练模型
fprintf('3. 训练模型...\n');
[trainedModel, trainingInfo] = training_pipeline(model, dataStore, labelStore, config);
%% 5. 评估模型
fprintf('4. 评估模型性能...\n');
metrics = evaluation_metrics(trainedModel, dataStore, labelStore, config);
%% 6. 可视化结果
fprintf('5. 可视化结果...\n');
visualization(trainedModel, dataStore, metrics, config);
%% 7. 保存模型
savePath = './models/road_feature_extractor.mat';
save(savePath, 'trainedModel', 'config', 'metrics');
fprintf('模型已保存到: %s\n', savePath);
%% 8. 实时推理演示(可选)
demo_realtime_inference = true;
if demo_realtime_inference
fprintf('\n6. 实时推理演示...\n');
realtime_inference(trainedModel, config);
end
fprintf('\n=== 程序执行完成 ===\n');
2.2 数据预处理 (data_preprocessing.m)
matlab
function [dataStore, labelStore] = data_preprocessing(config)
% 数据预处理函数
% 输入:
% config - 配置参数
% 输出:
% dataStore - 图像数据存储
% labelStore - 标签数据存储
fprintf(' 加载数据集...\n');
% 检查数据集是否存在
if ~exist(config.dataDir, 'dir')
error('数据集目录不存在: %s', config.dataDir);
end
% 创建图像数据存储
imageDir = fullfile(config.dataDir, 'images');
if ~exist(imageDir, 'dir')
error('图像目录不存在: %s', imageDir);
end
% 创建像素标签数据存储
labelDir = fullfile(config.dataDir, 'labels');
if ~exist(labelDir, 'dir')
error('标签目录不存在: %s', labelDir);
end
% 获取文件列表
imageFiles = dir(fullfile(imageDir, '*.jpg'));
if isempty(imageFiles)
imageFiles = dir(fullfile(imageDir, '*.png'));
end
if isempty(imageFiles)
error('未找到图像文件');
end
fprintf(' 找到 %d 张图像\n', length(imageFiles));
% 创建数据存储
dataStore = imageDatastore(imageDir, ...
'FileExtensions', {'.jpg', '.png'}, ...
'ReadFcn', @(filename) imread(filename));
labelStore = pixelLabelDatastore(labelDir, ...
'FileExtensions', {'.png'}, ...
'ReadFcn', @(filename) imread(filename));
% 数据增强
augmenter = imageDataAugmenter( ...
'RandRotation', [-10, 10], ...
'RandXReflection', true, ...
'RandYReflection', false, ...
'RandXScale', [0.9, 1.1], ...
'RandYScale', [0.9, 1.1]);
% 创建增强数据存储
augmentedDataStore = augmentedImageDatastore(config.imageSize(1:2), ...
dataStore, labelStore, ...
'DataAugmentation', augmenter, ...
'ColorPreprocessing', 'gray2rgb');
% 分割训练集和验证集
numFiles = length(imageFiles);
idx = randperm(numFiles);
valIdx = idx(1:round(config.validationSplit * numFiles));
trainIdx = idx(round(config.validationSplit * numFiles)+1:end);
% 创建训练集和验证集
trainingData = subset(augmentedDataStore, trainIdx);
validationData = subset(augmentedDataStore, valIdx);
fprintf(' 训练集: %d 张图像\n', length(trainIdx));
fprintf(' 验证集: %d 张图像\n', length(valIdx));
% 更新数据存储
dataStore.train = trainingData;
dataStore.val = validationData;
dataStore.test = validationData; % 使用验证集作为测试集
end
2.3 CNN模型构建 (cnn_model_construction.m)
matlab
function model = cnn_model_construction(config)
% 构建CNN模型
% 输入:
% config - 配置参数
% 输出:
% model - CNN模型
fprintf(' 构建 %s 模型...\n', config.modelType);
% 根据选择的模型类型构建
switch config.modelType
case 'unet'
model = unet_model(config);
case 'segnet'
model = segnet_model(config);
case 'deeplabv3'
model = deeplabv3_model(config);
otherwise
error('不支持的模型类型: %s', config.modelType);
end
% 显示模型结构
analyzeNetwork(model);
end
%% U-Net模型
function model = unet_model(config)
% 构建U-Net模型
% 输入层
layers = [
imageInputLayer(config.imageSize, 'Name', 'input', 'Normalization', 'zerocenter')
% 编码器部分
convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv1_1')
batchNormalizationLayer('Name', 'bn1_1')
reluLayer('Name', 'relu1_1')
convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv1_2')
batchNormalizationLayer('Name', 'bn1_2')
reluLayer('Name', 'relu1_2')
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool1')
convolution2dLayer(3, 128, 'Padding', 'same', 'Name', 'conv2_1')
batchNormalizationLayer('Name', 'bn2_1')
reluLayer('Name', 'relu2_1')
convolution2dLayer(3, 128, 'Padding', 'same', 'Name', 'conv2_2')
batchNormalizationLayer('Name', 'bn2_2')
reluLayer('Name', 'relu2_2')
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool2')
convolution2dLayer(3, 256, 'Padding', 'same', 'Name', 'conv3_1')
batchNormalizationLayer('Name', 'bn3_1')
reluLayer('Name', 'relu3_1')
convolution2dLayer(3, 256, 'Padding', 'same', 'Name', 'conv3_2')
batchNormalizationLayer('Name', 'bn3_2')
reluLayer('Name', 'relu3_2')
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool3')
% 桥接层
convolution2dLayer(3, 512, 'Padding', 'same', 'Name', 'conv4_1')
batchNormalizationLayer('Name', 'bn4_1')
reluLayer('Name', 'relu4_1')
convolution2dLayer(3, 512, 'Padding', 'same', 'Name', 'conv4_2')
batchNormalizationLayer('Name', 'bn4_2')
reluLayer('Name', 'relu4_2')
% 解码器部分
transposedConv2dLayer(2, 256, 'Stride', 2, 'Name', 'deconv1')
concatenationLayer(3, 2, 'Name', 'concat1')
convolution2dLayer(3, 256, 'Padding', 'same', 'Name', 'conv5_1')
batchNormalizationLayer('Name', 'bn5_1')
reluLayer('Name', 'relu5_1')
convolution2dLayer(3, 256, 'Padding', 'same', 'Name', 'conv5_2')
batchNormalizationLayer('Name', 'bn5_2')
reluLayer('Name', 'relu5_2')
transposedConv2dLayer(2, 128, 'Stride', 2, 'Name', 'deconv2')
concatenationLayer(3, 2, 'Name', 'concat2')
convolution2dLayer(3, 128, 'Padding', 'same', 'Name', 'conv6_1')
batchNormalizationLayer('Name', 'bn6_1')
reluLayer('Name', 'relu6_1')
convolution2dLayer(3, 128, 'Padding', 'same', 'Name', 'conv6_2')
batchNormalizationLayer('Name', 'bn6_2')
reluLayer('Name', 'relu6_2')
transposedConv2dLayer(2, 64, 'Stride', 2, 'Name', 'deconv3')
concatenationLayer(3, 2, 'Name', 'concat3')
convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv7_1')
batchNormalizationLayer('Name', 'bn7_1')
reluLayer('Name', 'relu7_1')
convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv7_2')
batchNormalizationLayer('Name', 'bn7_2')
reluLayer('Name', 'relu7_2')
% 输出层
convolution2dLayer(1, 4, 'Padding', 'same', 'Name', 'final_conv') % 4类:背景、车道线、道路边界、交通标志
softmaxLayer('Name', 'softmax')
pixelClassificationLayer('Name', 'output')
];
% 创建层图
lgraph = layerGraph(layers);
% 连接拼接层
lgraph = connectLayers(lgraph, 'pool3', 'concat1/in2');
lgraph = connectLayers(lgraph, 'pool2', 'concat2/in2');
lgraph = connectLayers(lgraph, 'pool1', 'concat3/in2');
% 转换为dlnetwork对象
model = dlnetwork(lgraph);
end
%% SegNet模型
function model = segnet_model(config)
% 构建SegNet模型
layers = [
imageInputLayer(config.imageSize, 'Name', 'input')
% 编码器
convolution2dLayer(7, 64, 'Padding', 'same', 'Name', 'conv1')
batchNormalizationLayer('Name', 'bn1')
reluLayer('Name', 'relu1')
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool1')
convolution2dLayer(7, 128, 'Padding', 'same', 'Name', 'conv2')
batchNormalizationLayer('Name', 'bn2')
reluLayer('Name', 'relu2')
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool2')
convolution2dLayer(7, 256, 'Padding', 'same', 'Name', 'conv3')
batchNormalizationLayer('Name', 'bn3')
reluLayer('Name', 'relu3')
maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool3')
% 解码器
maxUnpooling2dLayer('Name', 'unpool1')
convolution2dLayer(7, 128, 'Padding', 'same', 'Name', 'deconv1')
batchNormalizationLayer('Name', 'bn4')
reluLayer('Name', 'relu4')
maxUnpooling2dLayer('Name', 'unpool2')
convolution2dLayer(7, 64, 'Padding', 'same', 'Name', 'deconv2')
batchNormalizationLayer('Name', 'bn5')
reluLayer('Name', 'relu5')
maxUnpooling2dLayer('Name', 'unpool3')
convolution2dLayer(7, 4, 'Padding', 'same', 'Name', 'deconv3')
softmaxLayer('Name', 'softmax')
pixelClassificationLayer('Name', 'output')
];
model = dlnetwork(layerGraph(layers));
end
%% DeepLabV3+模型
function model = deeplabv3_model(config)
% 构建DeepLabV3+模型
% 使用预训练的ResNet-50作为骨干网络
imageSize = config.imageSize(1:2);
numClasses = 4;
% 创建DeepLabV3+网络
lgraph = deeplabv3plusLayers(imageSize, numClasses, 'resnet50');
% 转换为dlnetwork对象
model = dlnetwork(lgraph);
end
2.4 训练流程 (training_pipeline.m)
matlab
function [trainedModel, trainingInfo] = training_pipeline(model, dataStore, labelStore, config)
% 训练CNN模型
% 输入:
% model - CNN模型
% dataStore - 数据存储
% labelStore - 标签存储
% config - 配置参数
% 输出:
% trainedModel - 训练好的模型
% trainingInfo - 训练信息
fprintf(' 开始训练...\n');
% 设置训练选项
options = trainingOptions('adam', ...
'InitialLearnRate', config.learningRate, ...
'MaxEpochs', config.maxEpochs, ...
'MiniBatchSize', config.batchSize, ...
'Shuffle', 'every-epoch', ...
'ValidationData', dataStore.val, ...
'ValidationFrequency', 10, ...
'Plots', 'training-progress', ...
'Verbose', true, ...
'ExecutionEnvironment', config.device, ...
'GradientThreshold', 1, ...
'L2Regularization', 0.0001, ...
'OutputFcn', @(info) displayTrainingProgress(info));
% 训练模型
[trainedModel, trainingInfo] = trainNetwork(dataStore.train, model, options);
fprintf(' 训练完成!\n');
fprintf(' 最终训练准确率: %.2f%%\n', trainingInfo.FinalValidationAccuracy);
end
function stop = displayTrainingProgress(info)
% 显示训练进度
persistent startTime
if info.State == "start"
startTime = tic;
fprintf(' 训练开始...\n');
stop = false;
elseif info.State == "iteration"
elapsedTime = toc(startTime);
fprintf(' 迭代 %d: 损失 = %.4f, 准确率 = %.2f%%, 耗时 = %.1f秒\n', ...
info.Iteration, info.Loss, info.Accuracy, elapsedTime);
stop = false;
elseif info.State == "done"
fprintf(' 训练完成!总耗时: %.1f秒\n', toc(startTime));
stop = false;
end
end
2.5 评估指标 (evaluation_metrics.m)
matlab
function metrics = evaluation_metrics(model, dataStore, labelStore, config)
% 评估模型性能
% 输入:
% model - 训练好的模型
% dataStore - 数据存储
% labelStore - 标签存储
% config - 配置参数
% 输出:
% metrics - 评估指标
fprintf(' 评估模型性能...\n');
% 获取测试数据
testData = dataStore.test;
% 进行预测
predictions = semanticseg(testData, model, 'MiniBatchSize', config.batchSize);
% 计算混淆矩阵
confusionMatrix = confusionmat(labelStore, predictions);
% 计算各类别指标
numClasses = size(confusionMatrix, 1);
metrics.classNames = {'Background', 'Lane', 'Boundary', 'Sign'};
metrics.confusionMatrix = confusionMatrix;
% 计算每个类别的精度、召回率和F1分数
metrics.precision = zeros(numClasses, 1);
metrics.recall = zeros(numClasses, 1);
metrics.f1Score = zeros(numClasses, 1);
metrics.iou = zeros(numClasses, 1);
for i = 1:numClasses
TP = confusionMatrix(i, i);
FP = sum(confusionMatrix(:, i)) - TP;
FN = sum(confusionMatrix(i, :)) - TP;
TN = sum(confusionMatrix(:)) - TP - FP - FN;
% 精度 (Precision)
if TP + FP > 0
metrics.precision(i) = TP / (TP + FP);
else
metrics.precision(i) = 0;
end
% 召回率 (Recall)
if TP + FN > 0
metrics.recall(i) = TP / (TP + FN);
else
metrics.recall(i) = 0;
end
% F1分数
if metrics.precision(i) + metrics.recall(i) > 0
metrics.f1Score(i) = 2 * (metrics.precision(i) * metrics.recall(i)) / ...
(metrics.precision(i) + metrics.recall(i));
else
metrics.f1Score(i) = 0;
end
% IoU (Intersection over Union)
if TP + FP + FN > 0
metrics.iou(i) = TP / (TP + FP + FN);
else
metrics.iou(i) = 0;
end
end
% 计算平均指标
metrics.meanPrecision = mean(metrics.precision);
metrics.meanRecall = mean(metrics.recall);
metrics.meanF1Score = mean(metrics.f1Score);
metrics.meanIoU = mean(metrics.iou);
% 计算像素准确率
totalPixels = sum(confusionMatrix(:));
correctPixels = sum(diag(confusionMatrix));
metrics.pixelAccuracy = correctPixels / totalPixels;
% 显示结果
fprintf(' 评估结果:\n');
fprintf(' 像素准确率: %.2f%%\n', metrics.pixelAccuracy * 100);
fprintf(' 平均精度: %.2f%%\n', metrics.meanPrecision * 100);
fprintf(' 平均召回率: %.2f%%\n', metrics.meanRecall * 100);
fprintf(' 平均F1分数: %.2f%%\n', metrics.meanF1Score * 100);
fprintf(' 平均IoU: %.2f%%\n', metrics.meanIoU * 100);
% 显示各类别指标
fprintf(' 各类别指标:\n');
for i = 1:numClasses
fprintf(' %s: Precision=%.2f%%, Recall=%.2f%%, F1=%.2f%%, IoU=%.2f%%\n', ...
metrics.classNames{i}, ...
metrics.precision(i)*100, ...
metrics.recall(i)*100, ...
metrics.f1Score(i)*100, ...
metrics.iou(i)*100);
end
end
2.6 可视化工具 (visualization.m)
matlab
function visualization(model, dataStore, metrics, config)
% 可视化结果
% 输入:
% model - 训练好的模型
% dataStore - 数据存储
% metrics - 评估指标
% config - 配置参数
fprintf(' 生成可视化结果...\n');
% 创建可视化目录
visDir = './visualization';
if ~exist(visDir, 'dir')
mkdir(visDir);
end
% 1. 显示混淆矩阵
figure('Name', '混淆矩阵', 'NumberTitle', 'off');
confusionchart(metrics.confusionMatrix, metrics.classNames);
title('道路特征分类混淆矩阵');
saveas(gcf, fullfile(visDir, 'confusion_matrix.png'));
% 2. 显示样本预测结果
figure('Name', '样本预测结果', 'NumberTitle', 'off', 'Position', [100, 100, 1200, 800]);
% 获取测试数据
testData = dataStore.test;
testImages = readall(testData);
% 选择5个样本进行可视化
numSamples = min(5, length(testImages));
sampleIndices = randperm(length(testImages), numSamples);
for i = 1:numSamples
subplot(3, numSamples, i);
imshow(testImages{sampleIndices(i)});
title(['原图 ' num2str(i)]);
subplot(3, numSamples, i + numSamples);
[prediction, scores] = semanticseg(testImages{sampleIndices(i)}, model);
imshow(label2rgb(prediction));
title(['预测 ' num2str(i)]);
subplot(3, numSamples, i + 2*numSamples);
% 显示置信度热力图
scoreMap = max(scores, [], 3);
imshow(scoreMap, []);
colormap('jet');
colorbar;
title(['置信度 ' num2str(i)]);
end
saveas(gcf, fullfile(visDir, 'sample_predictions.png'));
% 3. 显示各类别分割结果
figure('Name', '类别分割结果', 'NumberTitle', 'off', 'Position', [100, 100, 1000, 600]);
% 选择一个样本
sampleIdx = sampleIndices(1);
originalImage = testImages{sampleIdx};
prediction = semanticseg(originalImage, model);
% 分离各个类别
for classIdx = 1:4
subplot(2, 2, classIdx);
% 创建二值掩码
mask = prediction == classIdx;
% 显示掩码
imshow(mask);
title(metrics.classNames{classIdx});
% 在原图上叠加掩码
if any(mask(:))
subplot(2, 2, 4);
hold on;
h = imshow(originalImage);
set(h, 'AlphaData', 0.7);
overlay = imoverlay(originalImage, mask, 'red');
imshow(overlay);
title('叠加显示');
end
end
saveas(gcf, fullfile(visDir, 'class_segmentation.png'));
% 4. 显示性能指标柱状图
figure('Name', '性能指标', 'NumberTitle', 'off');
metrics_names = {'Pixel Accuracy', 'Mean Precision', 'Mean Recall', 'Mean F1', 'Mean IoU'};
metrics_values = [metrics.pixelAccuracy, metrics.meanPrecision, metrics.meanRecall, ...
metrics.meanF1Score, metrics.meanIoU];
bar(metrics_values * 100);
set(gca, 'XTickLabel', metrics_names, 'XTick', 1:5);
ylabel('百分比 (%)');
title('道路特征提取性能指标');
grid on;
% 在柱子上显示数值
for i = 1:length(metrics_values)
text(i, metrics_values(i)*100 + 1, sprintf('%.1f%%', metrics_values(i)*100), ...
'HorizontalAlignment', 'center', 'VerticalAlignment', 'bottom');
end
saveas(gcf, fullfile(visDir, 'performance_metrics.png'));
fprintf(' 可视化结果已保存到: %s\n', visDir);
end
2.7 实时推理 (realtime_inference.m)
matlab
function realtime_inference(model, config)
% 实时推理演示
% 输入:
% model - 训练好的模型
% config - 配置参数
fprintf(' 启动实时推理演示...\n');
% 创建视频捕获对象
vidObj = VideoReader('test_video.mp4'); % 替换为您的视频文件
if ~hasFrame(vidObj)
fprintf(' 警告:未找到视频文件,使用摄像头输入\n');
vidObj = webcam; % 使用摄像头
end
% 创建显示窗口
figure('Name', '实时道路特征提取', 'NumberTitle', 'off', 'Position', [100, 100, 1200, 600]);
while hasFrame(vidObj)
% 读取帧
if isa(vidObj, 'VideoReader')
frame = readFrame(vidObj);
else
frame = snapshot(vidObj);
end
% 调整图像大小
frame_resized = imresize(frame, config.imageSize(1:2));
% 进行预测
start_time = tic;
prediction = semanticseg(frame_resized, model);
inference_time = toc(start_time);
% 创建可视化
subplot(1, 3, 1);
imshow(frame_resized);
title('原始图像');
subplot(1, 3, 2);
imshow(label2rgb(prediction));
title('分割结果');
subplot(1, 3, 3);
% 在原图上叠加分割结果
overlay = imoverlay(frame_resized, prediction > 0, 'green');
imshow(overlay);
title(sprintf('实时检测 (%.1f FPS)', 1/inference_time));
% 显示检测到的特征数量
uniqueLabels = unique(prediction(:));
featureText = '';
for i = 1:length(uniqueLabels)
if uniqueLabels(i) > 0
className = getClassName(uniqueLabels(i));
area = sum(prediction(:) == uniqueLabels(i));
featureText = [featureText sprintf('%s: %d pixels\n', className, area)];
end
end
annotation('textbox', [0.7, 0.1, 0.25, 0.2], 'String', featureText, ...
'FitBoxToText', 'on', 'BackgroundColor', 'white', 'EdgeColor', 'black');
drawnow;
% 按ESC退出
if waitforbuttonpress
break;
end
end
% 清理
if isa(vidObj, 'webcam')
clear vidObj;
end
fprintf(' 实时推理演示结束\n');
end
function className = getClassName(label)
% 根据标签获取类别名称
switch label
case 1
className = 'Lane';
case 2
className = 'Boundary';
case 3
className = 'Sign';
otherwise
className = 'Unknown';
end
end
三、数据集准备
3.1 数据集结构
dataset/
├── images/
│ ├── 000001.jpg
│ ├── 000002.jpg
│ └── ...
├── labels/
│ ├── 000001.png
│ ├── 000002.png
│ └── ...
└── annotations/
├── 000001.txt
├── 000002.txt
└── ...
3.2 标签格式
matlab
% 创建示例标签
% 标签图像应为单通道灰度图,像素值表示类别:
% 0: 背景
% 1: 车道线
% 2: 道路边界
% 3: 交通标志
% 创建示例标签图像
labelImage = zeros(256, 256, 'uint8');
% 绘制车道线
labelImage(100:120, :) = 1; % 水平车道线
% 绘制道路边界
labelImage(:, 50:60) = 2; % 左侧边界
labelImage(:, 200:210) = 2; % 右侧边界
% 绘制交通标志
labelImage(50:70, 100:120) = 3; % 交通标志区域
% 保存标签
imwrite(labelImage, 'dataset/labels/example.png');
参考代码 CNN实现道路特征提取 www.youwenfan.com/contentcsv/79562.html
四、运行说明
4.1 快速开始
- 下载或创建道路图像数据集
- 运行
main.m开始训练 - 查看
visualization/目录中的结果
4.2 自定义配置
matlab
% 修改配置参数
config.modelType = 'deeplabv3'; % 更换模型
config.learningRate = 0.0001; % 调整学习率
config.maxEpochs = 100; % 增加训练轮数
config.imageSize = [512, 512, 3]; % 使用更大图像
4.3 使用预训练模型
matlab
% 加载预训练模型
load('./models/road_feature_extractor.mat', 'trainedModel', 'config');
% 对新图像进行预测
newImage = imread('test_road.jpg');
prediction = semanticseg(newImage, trainedModel);
% 显示结果
figure;
subplot(1,2,1), imshow(newImage), title('原始图像');
subplot(1,2,2), imshow(label2rgb(prediction)), title('分割结果');
五、性能优化
- 使用GPU加速 :确保
config.device = 'gpu' - 数据增强:增加更多数据增强策略
- 模型剪枝:减少模型参数量以提高推理速度
- 量化:使用int8量化减少模型大小
- 多尺度训练:使用不同尺寸的图像训练提高鲁棒性