MATLAB初学者入门(28)—— 有监督学习神经网络

有监督学习神经网络是用于执行分类和回归任务的强大工具,其中网络通过输入和目标输出对的训练集来学习数据的映射。MATLAB 提供了一个易于使用的框架,用于设计、训练和验证深度学习模型,包括多层感知器(MLP)、卷积神经网络(CNN)和循环神经网络(RNN)。

案例分析:使用 MATLAB 实现和训练一个多层感知器(MLP)进行数字识别

假设我们需要分类手写数字,这是一个典型的有监督学习问题,可以使用多层感知器(MLP)解决。

步骤 1: 准备数据

我们将使用 MATLAB 中预加载的手写数字数据集(MNIST)。

Matlab 复制代码
% 加载预置的 MNIST 数据集
[XTrain, YTrain, XTest, YTest] = digitTrain4DArrayData;
步骤 2: 定义神经网络架构

设计一个简单的 MLP,包括输入层、隐藏层和输出层。

Matlab 复制代码
layers = [
    imageInputLayer([28 28 1], 'Name', 'input', 'Normalization', 'none')

    % 第一个全连接层和ReLU激活函数
    fullyConnectedLayer(100, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    
    % 第二个全连接层和ReLU激活函数
    fullyConnectedLayer(50, 'Name', 'fc2')
    reluLayer('Name', 'relu2')

    % 输出层
    fullyConnectedLayer(10, 'Name', 'fc3')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'output')
];

% 查看网络架构
analyzeNetwork(layers);
步骤 3: 配置训练选项

设置训练算法(例如使用 SGD、Adam 等),指定迭代次数、学习率等。

Matlab 复制代码
options = trainingOptions('adam', ...
    'InitialLearnRate', 0.001, ...
    'MaxEpochs', 10, ...
    'MiniBatchSize', 128, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', {XTest, YTest}, ...
    'ValidationFrequency', 30, ...
    'Verbose', false, ...
    'Plots', 'training-progress');
步骤 4: 训练神经网络

使用准备好的数据和配置训练神经网络。

Matlab 复制代码
net = trainNetwork(XTrain, YTrain, layers, options);
步骤 5: 评估网络性能

在测试集上评估训练好的网络性能。

Matlab 复制代码
YPred = classify(net, XTest);
accuracy = sum(YPred == YTest) / numel(YTest);
disp(['Test Accuracy: ', num2str(accuracy)]);

案例分析:使用MATLAB实现卷积神经网络(CNN)进行图像分类

假设我们的任务是分类来自一个更复杂的图像数据集,例如CIFAR-10,这是一个常用的包含60000张32x32彩色图像的数据集,涵盖10个类别。

步骤 1: 准备数据

加载CIFAR-10数据集,并进行适当的预处理。

Matlab 复制代码
[XTrain, YTrain, XTest, YTest] = cifar10Data;

% 数据预处理
XTrain = rescale(XTrain);  % 归一化
XTest = rescale(XTest);
步骤 2: 定义卷积神经网络架构

为CIFAR-10数据集设计一个适当的CNN结构。

Matlab 复制代码
layers = [
    imageInputLayer([32 32 3], 'Name', 'input')
    
    convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv1')
    batchNormalizationLayer('Name', 'bn1')
    reluLayer('Name', 'relu1')
    
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool1')
    
    convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv2')
    batchNormalizationLayer('Name', 'bn2')
    reluLayer('Name', 'relu2')
    
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool2')
    
    convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv3')
    reluLayer('Name', 'relu3')
    
    fullyConnectedLayer(64, 'Name', 'fc1')
    dropoutLayer(0.5, 'Name', 'dropout1')
    fullyConnectedLayer(10, 'Name', 'fc2')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'output')
];

% 查看网络架构
analyzeNetwork(layers);
步骤 3: 配置训练选项

设置训练参数,如优化器、学习率、批次大小等。

Matlab 复制代码
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...
    'MaxEpochs', 30, ...
    'MiniBatchSize', 64, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', {XTest, YTest}, ...
    'ValidationFrequency', 10, ...
    'Verbose', true, ...
    'Plots', 'training-progress');
步骤 4: 训练网络

训练卷积神经网络。

Matlab 复制代码
net = trainNetwork(XTrain, YTrain, layers, options);
步骤 5: 评估网络性能

在测试集上评估训练好的网络性能,计算准确率。

Matlab 复制代码
YPred = classify(net, XTest);
accuracy = mean(YPred == YTest);
disp(['Test Accuracy: ', num2str(accuracy)]);

案例分析:使用MATLAB实现LSTM网络进行时间序列预测

假设我们要预测金融市场的未来趋势,这是一个典型的时间序列预测问题,可以通过使用LSTM网络来解决。

步骤 1: 准备数据

对于时间序列预测任务,首先需要准备和预处理数据,包括标准化和创建适合于LSTM训练的数据结构。

Matlab 复制代码
% 假设已有加载数据
load exampleFinancialSeries.mat
data = DataTable.Price;

% 数据标准化
data = (data - mean(data)) / std(data);

% 创建时间序列训练数据
numTimeStepsTrain = floor(0.9 * numel(data));
dataTrain = data(1:numTimeStepsTrain+1);
dataTest = data(numTimeStepsTrain+1:end);

% 准备 LSTM 输入
XTrain = dataTrain(1:end-1);
YTrain = dataTrain(2:end);
步骤 2: 定义LSTM网络架构

创建一个包含LSTM层的网络架构,适用于时间序列数据的特征。

Matlab 复制代码
layers = [
    sequenceInputLayer(1, 'Name', 'input')
    lstmLayer(50, 'OutputMode', 'sequence', 'Name', 'lstm')
    fullyConnectedLayer(1, 'Name', 'fc')
    regressionLayer('Name', 'output')
];

% 查看网络架构
analyzeNetwork(layers);
步骤 3: 配置训练选项

设置训练参数,确保模型在训练时的效率和效果。

Matlab 复制代码
options = trainingOptions('adam', ...
    'MaxEpochs', 100, ...
    'MiniBatchSize', 20, ...
    'GradientThreshold', 1, ...
    'InitialLearnRate', 0.005, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropPeriod', 125, ...
    'LearnRateDropFactor', 0.2, ...
    'Verbose', 0, ...
    'Plots', 'training-progress');
步骤 4: 训练LSTM网络

使用配置的参数和数据训练网络。

Matlab 复制代码
net = trainNetwork(XTrain', YTrain', layers, options);
步骤 5: 评估网络性能

使用训练好的网络在测试集上进行预测,并评估其预测性能。

Matlab 复制代码
net = predictAndUpdateState(net, XTrain');
[net, YPred] = predictAndUpdateState(net, YTrain(end));

% 预测未来步骤
numFutureSteps = 20;
for i = 2:numFutureSteps
    [net, YPred(:, i)] = predictAndUpdateState(net, YPred(:, i-1), 'ExecutionEnvironment', 'cpu');
end

% 可视化预测结果
figure;
subplot(2,1,1);
plot(dataTrain(end-100:end));
hold on;
idx = numTimeStepsTrain:(numTimeStepsTrain+numFutureSteps);
plot(idx, [data(numTimeStepsTrain) YPred], '.-');
hold off;
legend(["Observed" "Forecast"]);
title("Forecast");
ylabel("Cases");
xlabel("Month");

结论

(1)设计并训练了一个基本的多层感知器(MLP)来识别手写数字。这个过程展示了使用 MATLAB 进行神经网络训练的完整流程,包括数据预处理、网络架构设计、训练配置设置以及性能评估。在实际应用中,网络的性能大量依赖于所选的架构、训练算法和超参数的调整。更深的网络或更复杂的结构(如卷积神经网络)可能会在处理图像或序列数据时表现更好。MATLAB 的深度学习工具箱提供了强大的工具和函数,帮助研究人员和工程师优化这些参数,以实现更高效和精准的模型。

(2)卷积神经网络(CNN)是图像分类任务中的黄金标准,能够有效地从图像数据中学习高级特征。通过MATLAB的深度学习工具箱,我们可以轻松设计、训练并验证CNN模型。在设计CNN时,层数、过滤器大小、批归一化和Dropout等都是重要的因素,需要根据具体任务进行调整。此外,实际应用中可能还需要处理过拟合、调整学习率和使用数据增强等问题来进一步提高模型的泛化能力和性能。针对特定的应用,如视频分析或自然语言处理,我们还可以探索使用循环神经网络(RNN)或其变体,如LSTM和GRU,这些网络结构特别适用于处理序列数据。

(3)LSTM网络是解决复杂时间序列预测问题的有效工具,能够学习和记住长期依赖关系。通过MATLAB的深度学习工具箱,我们可以轻松设计、训练并评估这样的网络。在实际应用中,LSTM的参数调整对模型的性能至关重要,可能需要多次实验以找到最优的网络结构和训练配置。此外,对于更复杂的序列预测任务,可以考虑使用更高级的LSTM变体或其他类型的循环网络。

相关推荐
残月只会敲键盘4 分钟前
面相小白的php反序列化漏洞原理剖析
开发语言·php
ac-er88887 分钟前
PHP弱类型安全问题
开发语言·安全·php
ac-er88887 分钟前
PHP网络爬虫常见的反爬策略
开发语言·爬虫·php
翔云API11 分钟前
PHP静默活体识别API接口应用场景与集成方案
人工智能
爱吃喵的鲤鱼17 分钟前
linux进程的状态之环境变量
linux·运维·服务器·开发语言·c++
浊酒南街18 分钟前
吴恩达深度学习笔记:卷积神经网络(Foundations of Convolutional Neural Networks)4.9-4.10
人工智能·深度学习·神经网络·cnn
Tony聊跨境33 分钟前
独立站SEO类型及优化:来检查这些方面你有没有落下
网络·人工智能·tcp/ip·ip
懒惰才能让科技进步39 分钟前
从零学习大模型(十二)-----基于梯度的重要性剪枝(Gradient-based Pruning)
人工智能·深度学习·学习·算法·chatgpt·transformer·剪枝
DARLING Zero two♡43 分钟前
关于我、重生到500年前凭借C语言改变世界科技vlog.16——万字详解指针概念及技巧
c语言·开发语言·科技
Gu Gu Study1 小时前
【用Java学习数据结构系列】泛型上界与通配符上界
java·开发语言