MATLAB初学者入门(28)—— 有监督学习神经网络

有监督学习神经网络是用于执行分类和回归任务的强大工具,其中网络通过输入和目标输出对的训练集来学习数据的映射。MATLAB 提供了一个易于使用的框架,用于设计、训练和验证深度学习模型,包括多层感知器(MLP)、卷积神经网络(CNN)和循环神经网络(RNN)。

案例分析:使用 MATLAB 实现和训练一个多层感知器(MLP)进行数字识别

假设我们需要分类手写数字,这是一个典型的有监督学习问题,可以使用多层感知器(MLP)解决。

步骤 1: 准备数据

我们将使用 MATLAB 中预加载的手写数字数据集(MNIST)。

Matlab 复制代码
% 加载预置的 MNIST 数据集
[XTrain, YTrain, XTest, YTest] = digitTrain4DArrayData;
步骤 2: 定义神经网络架构

设计一个简单的 MLP,包括输入层、隐藏层和输出层。

Matlab 复制代码
layers = [
    imageInputLayer([28 28 1], 'Name', 'input', 'Normalization', 'none')

    % 第一个全连接层和ReLU激活函数
    fullyConnectedLayer(100, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    
    % 第二个全连接层和ReLU激活函数
    fullyConnectedLayer(50, 'Name', 'fc2')
    reluLayer('Name', 'relu2')

    % 输出层
    fullyConnectedLayer(10, 'Name', 'fc3')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'output')
];

% 查看网络架构
analyzeNetwork(layers);
步骤 3: 配置训练选项

设置训练算法(例如使用 SGD、Adam 等),指定迭代次数、学习率等。

Matlab 复制代码
options = trainingOptions('adam', ...
    'InitialLearnRate', 0.001, ...
    'MaxEpochs', 10, ...
    'MiniBatchSize', 128, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', {XTest, YTest}, ...
    'ValidationFrequency', 30, ...
    'Verbose', false, ...
    'Plots', 'training-progress');
步骤 4: 训练神经网络

使用准备好的数据和配置训练神经网络。

Matlab 复制代码
net = trainNetwork(XTrain, YTrain, layers, options);
步骤 5: 评估网络性能

在测试集上评估训练好的网络性能。

Matlab 复制代码
YPred = classify(net, XTest);
accuracy = sum(YPred == YTest) / numel(YTest);
disp(['Test Accuracy: ', num2str(accuracy)]);

案例分析:使用MATLAB实现卷积神经网络(CNN)进行图像分类

假设我们的任务是分类来自一个更复杂的图像数据集,例如CIFAR-10,这是一个常用的包含60000张32x32彩色图像的数据集,涵盖10个类别。

步骤 1: 准备数据

加载CIFAR-10数据集,并进行适当的预处理。

Matlab 复制代码
[XTrain, YTrain, XTest, YTest] = cifar10Data;

% 数据预处理
XTrain = rescale(XTrain);  % 归一化
XTest = rescale(XTest);
步骤 2: 定义卷积神经网络架构

为CIFAR-10数据集设计一个适当的CNN结构。

Matlab 复制代码
layers = [
    imageInputLayer([32 32 3], 'Name', 'input')
    
    convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv1')
    batchNormalizationLayer('Name', 'bn1')
    reluLayer('Name', 'relu1')
    
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool1')
    
    convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv2')
    batchNormalizationLayer('Name', 'bn2')
    reluLayer('Name', 'relu2')
    
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool2')
    
    convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv3')
    reluLayer('Name', 'relu3')
    
    fullyConnectedLayer(64, 'Name', 'fc1')
    dropoutLayer(0.5, 'Name', 'dropout1')
    fullyConnectedLayer(10, 'Name', 'fc2')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'output')
];

% 查看网络架构
analyzeNetwork(layers);
步骤 3: 配置训练选项

设置训练参数,如优化器、学习率、批次大小等。

Matlab 复制代码
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...
    'MaxEpochs', 30, ...
    'MiniBatchSize', 64, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', {XTest, YTest}, ...
    'ValidationFrequency', 10, ...
    'Verbose', true, ...
    'Plots', 'training-progress');
步骤 4: 训练网络

训练卷积神经网络。

Matlab 复制代码
net = trainNetwork(XTrain, YTrain, layers, options);
步骤 5: 评估网络性能

在测试集上评估训练好的网络性能,计算准确率。

Matlab 复制代码
YPred = classify(net, XTest);
accuracy = mean(YPred == YTest);
disp(['Test Accuracy: ', num2str(accuracy)]);

案例分析:使用MATLAB实现LSTM网络进行时间序列预测

假设我们要预测金融市场的未来趋势,这是一个典型的时间序列预测问题,可以通过使用LSTM网络来解决。

步骤 1: 准备数据

对于时间序列预测任务,首先需要准备和预处理数据,包括标准化和创建适合于LSTM训练的数据结构。

Matlab 复制代码
% 假设已有加载数据
load exampleFinancialSeries.mat
data = DataTable.Price;

% 数据标准化
data = (data - mean(data)) / std(data);

% 创建时间序列训练数据
numTimeStepsTrain = floor(0.9 * numel(data));
dataTrain = data(1:numTimeStepsTrain+1);
dataTest = data(numTimeStepsTrain+1:end);

% 准备 LSTM 输入
XTrain = dataTrain(1:end-1);
YTrain = dataTrain(2:end);
步骤 2: 定义LSTM网络架构

创建一个包含LSTM层的网络架构,适用于时间序列数据的特征。

Matlab 复制代码
layers = [
    sequenceInputLayer(1, 'Name', 'input')
    lstmLayer(50, 'OutputMode', 'sequence', 'Name', 'lstm')
    fullyConnectedLayer(1, 'Name', 'fc')
    regressionLayer('Name', 'output')
];

% 查看网络架构
analyzeNetwork(layers);
步骤 3: 配置训练选项

设置训练参数,确保模型在训练时的效率和效果。

Matlab 复制代码
options = trainingOptions('adam', ...
    'MaxEpochs', 100, ...
    'MiniBatchSize', 20, ...
    'GradientThreshold', 1, ...
    'InitialLearnRate', 0.005, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropPeriod', 125, ...
    'LearnRateDropFactor', 0.2, ...
    'Verbose', 0, ...
    'Plots', 'training-progress');
步骤 4: 训练LSTM网络

使用配置的参数和数据训练网络。

Matlab 复制代码
net = trainNetwork(XTrain', YTrain', layers, options);
步骤 5: 评估网络性能

使用训练好的网络在测试集上进行预测,并评估其预测性能。

Matlab 复制代码
net = predictAndUpdateState(net, XTrain');
[net, YPred] = predictAndUpdateState(net, YTrain(end));

% 预测未来步骤
numFutureSteps = 20;
for i = 2:numFutureSteps
    [net, YPred(:, i)] = predictAndUpdateState(net, YPred(:, i-1), 'ExecutionEnvironment', 'cpu');
end

% 可视化预测结果
figure;
subplot(2,1,1);
plot(dataTrain(end-100:end));
hold on;
idx = numTimeStepsTrain:(numTimeStepsTrain+numFutureSteps);
plot(idx, [data(numTimeStepsTrain) YPred], '.-');
hold off;
legend(["Observed" "Forecast"]);
title("Forecast");
ylabel("Cases");
xlabel("Month");

结论

(1)设计并训练了一个基本的多层感知器(MLP)来识别手写数字。这个过程展示了使用 MATLAB 进行神经网络训练的完整流程,包括数据预处理、网络架构设计、训练配置设置以及性能评估。在实际应用中,网络的性能大量依赖于所选的架构、训练算法和超参数的调整。更深的网络或更复杂的结构(如卷积神经网络)可能会在处理图像或序列数据时表现更好。MATLAB 的深度学习工具箱提供了强大的工具和函数,帮助研究人员和工程师优化这些参数,以实现更高效和精准的模型。

(2)卷积神经网络(CNN)是图像分类任务中的黄金标准,能够有效地从图像数据中学习高级特征。通过MATLAB的深度学习工具箱,我们可以轻松设计、训练并验证CNN模型。在设计CNN时,层数、过滤器大小、批归一化和Dropout等都是重要的因素,需要根据具体任务进行调整。此外,实际应用中可能还需要处理过拟合、调整学习率和使用数据增强等问题来进一步提高模型的泛化能力和性能。针对特定的应用,如视频分析或自然语言处理,我们还可以探索使用循环神经网络(RNN)或其变体,如LSTM和GRU,这些网络结构特别适用于处理序列数据。

(3)LSTM网络是解决复杂时间序列预测问题的有效工具,能够学习和记住长期依赖关系。通过MATLAB的深度学习工具箱,我们可以轻松设计、训练并评估这样的网络。在实际应用中,LSTM的参数调整对模型的性能至关重要,可能需要多次实验以找到最优的网络结构和训练配置。此外,对于更复杂的序列预测任务,可以考虑使用更高级的LSTM变体或其他类型的循环网络。

相关推荐
菜狗woc1 分钟前
opencv-python的简单练习
人工智能·python·opencv
十年一梦实验室5 分钟前
【C++】sophus : sim_details.hpp 实现了矩阵函数 W、其导数,以及其逆 (十七)
开发语言·c++·线性代数·矩阵
15年网络推广青哥5 分钟前
国际抖音TikTok矩阵运营的关键要素有哪些?
大数据·人工智能·矩阵
最爱番茄味14 分钟前
Python实例之函数基础打卡篇
开发语言·python
weixin_3875456424 分钟前
探索 AnythingLLM:借助开源 AI 打造私有化智能知识库
人工智能
Oneforlove_twoforjob1 小时前
【Java基础面试题033】Java泛型的作用是什么?
java·开发语言
engchina1 小时前
如何在 Python 中忽略烦人的警告?
开发语言·人工智能·python
向宇it1 小时前
【从零开始入门unity游戏开发之——C#篇24】C#面向对象继承——万物之父(object)、装箱和拆箱、sealed 密封类
java·开发语言·unity·c#·游戏引擎
诚丞成2 小时前
计算世界之安生:C++继承的文水和智慧(上)
开发语言·c++
Smile灬凉城6662 小时前
反序列化为啥可以利用加号绕过php正则匹配
开发语言·php