基于CNN-LSTM的时序预测MATLAB实战

卷积神经网络(CNN)用于提取时间序列数据中的局部空间特征,通过卷积层和池化层的堆叠,CNN能够有效捕获数据中的短期模式和局部依赖关系。长短时记忆网络(LSTM)用于处理时间序列数据,特别擅长捕捉数据中的长期依赖关系,LSTM通过引入门控机制和记忆单元来解决长期依赖问题。基于CNN-LSTM的时序预测结合上述两种网络的优点,同时考虑了时空特征,模型通常能够获得更高的预测精度。

一、算法原理

1.1 CNN原理

卷积神经网络具有局部连接、权值共享和空间相关等特性。卷积神经网络结构包含卷积层、激活层和池化层。

(a)二维卷积层将滑动卷积滤波器应用于输入。该层通过沿输入垂直和水平方向 移动滤波器对输入进行卷积,并计算权重与输入的点积,然后加入一个偏置项。具体表达式为:

卷积层的功能是对输入数据进行特征提取,其内部包含多个卷积核,也称为感受野。将输入图像和卷积核做卷积运算,可以增强原始信号特征的同时降低噪声,卷积运算的具体过程如图所示。

卷积运算的具体过程

(2)激活函数

在卷积神经网络中,常用的激活函数包括 Sigmoid 函数、Tanh 函数、Swish 函数和 Relu 函数。Relu 函数解决了 Sigmoid 函数和 Tanh 函数梯度消失的问题, 提高了模型收敛的速度,受到的广泛学者的欢迎。

(3)池化层

池化层又称为下采样层,池化层分为平均池化层和最大池化层。其中,最大池化层通过将输入分为矩形池化区域,并计算每个区域的最大值来执行下采样, 而平均池化层则是计算池化区域的平均值来执行下采样。池化层的池化过程如图所示。

池化过程示意图

1.2 LSTM原理

LSTM采用循环神经网络( Recurrent Neural Network,RNN )架构[8],它是专门为从序列中学习长期依赖关系而设计的。LSTM可以使用4个组件:输入门、输出门、遗忘门和具有自循环连接的单元来移除或添加块状态的信息。

LSTM网络结构

设输入序列共有 k 个时间步,LSTM 门控机制 结构为遗忘门、输入门和输出门,xt携带网络输入值 作为向量引入系统,ht 通过隐含层对 LSTM 细胞进 行输出,ct携带着 LSTM 细胞状态进行运算。LSTM 运算规则如下:

计算后保留 ct与 ht,用于下一时间步的计算;最后一步计算完成后,将隐藏层向量 hk作为输出与本组序列对应的预测值对比,得出损失函数值,依据梯度下降算法,优化权重和偏置参数,以此训练出迭代次数范围内最精确的网络参数。

1.3 CNN-LSTM框架

以时序预测为例,本次使用的CNN-LSTM框架如图所示。

CNN-LSTM框架

二、具体模型框架

通过layerGraph来组合CNN和LSTM层,并使用trainNetwork函数进行模型训练,使用Dropout技术来减少模型的过拟合现象。

对LSTM参数进行设置,确定数据输入的特征维度,即时间步长的特征数量。numhidden_units1numhidden_units2numhidden_units3分别代表不同LSTM层中的隐含单元数。例如,可以设置为50、20、100,表示不同层的神经元数量。Dropout层用于减少过拟合,例如设置为0.3,表示在训练过程中随机丢弃30%的神经元。

使用Adam优化算法进行训练,这是一种改进的梯度优化算法,能够解决传统SGD算法的一些问题。训练完成后使用测试数据集进行预测,并采用如均方误差(MSE)、均方根误差(RMSE)、平均绝对误差(MAE)等评价指标来定量评估模型性能。

Matlab 复制代码
clc
clear
load('Train.mat')
load('Test.mat')

% LSTM 层设置,参数设置
inputSize = size(Train_xNorm{1},1);   %数据输入x的特征维度
outputSize = 1;  %数据输出y的维度  
numhidden_units1=50;
numhidden_units2= 20;
numhidden_units3=100;
%
opts = trainingOptions('adam', ...
    'MaxEpochs',10, ...
    'GradientThreshold',1,...
    'ExecutionEnvironment','cpu',...
    'InitialLearnRate',0.001, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',2, ...   %2个epoch后学习率更新
    'LearnRateDropFactor',0.5, ...
    'Shuffle','once',...  % 时间序列长度
    'SequenceLength',k,...
    'MiniBatchSize',24,...
    'Verbose',0);
%% lstm

layers = [ ...
    
    sequenceInputLayer([inputSize,1,1],'name','input')   %输入层设置
    sequenceFoldingLayer('name','fold')
    convolution2dLayer([2,1],10,'Stride',[1,1],'name','conv1')
    batchNormalizationLayer('name','batchnorm1')
    reluLayer('name','relu1')
    maxPooling2dLayer([1,3],'Stride',1,'Padding','same','name','maxpool')
    sequenceUnfoldingLayer('name','unfold')
    flattenLayer('name','flatten')
    lstmLayer(numhidden_units1,'Outputmode','sequence','name','hidden1') 
    dropoutLayer(0.3,'name','dropout_1')
    lstmLayer(numhidden_units2,'Outputmode','last','name','hidden2') 
    dropoutLayer(0.3,'name','drdiopout_2')
    fullyConnectedLayer(outputSize,'name','fullconnect')   % 全连接层设置(影响输出维度)(cell层出来的输出层) %
    tanhLayer('name','softmax')
    regressionLayer('name','output')];

lgraph = layerGraph(layers)
lgraph = connectLayers(lgraph,'fold/miniBatchSize','unfold/miniBatchSize');
plot(lgraph)
%
% 网络训练
tic

net = trainNetwork(Train_xNorm,Train_yNorm,lgraph,opts);

%% 测试
figure
Predict_Ynorm = net.predict(Test_xNorm);
Predict_Y  = mapminmax('reverse',Predict_Ynorm',yopt);
Predict_Y = Predict_Y';

plot(Predict_Y,'g-')
hold on 

plot(Test_y);    
legend('预测值','实际值')

通过比较预测值和实际值可以评估模型的性能,通常使用图表来展示预测结果与实际结果的对比:

基于CNN-LSTM的时序预测模型结合了CNN的特征提取能力和LSTM的时序建模能力,使其在处理具有复杂空间和时间依赖性的时间序列数据方面表现出色。在MATLAB中实现这一模型,可以利用其强大的内置函数和深度学习工具箱,进行有效的模型构建、训练和评估。

相关推荐
秀儿还能再秀27 分钟前
神经网络(系统性学习四):深度学习——卷积神经网络(CNN)
人工智能·深度学习·机器学习·cnn·学习笔记
studyer_domi3 小时前
matlab蜗轮蜗杆设计优化问题
开发语言·matlab
铖铖的花嫁3 小时前
基于CNN+RNNs(LSTM, GRU)的红点位置检测(pytorch)
cnn·gru·lstm
爱研究的小牛4 小时前
AIVA 技术浅析(四):捕捉音乐作品中的长期依赖关系
人工智能·rnn·深度学习·aigc·lstm
micro_xx5 小时前
Matlab 深度学习工具箱 案例学习与测试————求二阶微分方程
深度学习·学习·matlab
AI浩5 小时前
ShuffleNet:一种为移动设备设计的极致高效的卷积神经网络
人工智能·神经网络·cnn
Evand J9 小时前
【MATLAB蓝牙定位代码】三维平面定位设计,通过N个蓝牙锚点实现对未知位置的精准定位
开发语言·matlab·平面
Matlab程序猿小助手11 小时前
【MATLAB源码-第222期】基于matlab的改进蚁群算法三维栅格地图路径规划,加入精英蚁群策略。包括起点终点,障碍物,着火点,楼梯。
开发语言·人工智能·算法·matlab·机器人·无人机
卧式纯绿11 小时前
自动驾驶3D目标检测综述(三)
人工智能·python·深度学习·目标检测·3d·cnn·自动驾驶