本文来源公众号**"程序员学长"**,仅用于学术分享,侵权删,干货满满。
原文链接:快速学会一个算法,RNN
今天给大家分享一个超强的算法模型,RNN
循环神经网络(Recurrent Neural Network, RNN)是一种专门用于**「处理序列数据的神经网络」**。
由于其能够处理不同长度的输入序列,并保持过去信息的能力,它广泛应用于自然语言处理、语音识别和时间序列预测等领域。
RNN的算法原理
RNN 的核心思想是使用循环的连接结构来保持对之前处理过的信息的记忆。
这种记忆通过隐藏层的状态来表达,每个时间步的隐藏状态都依赖于前一时间步的隐藏状态和当前时间步的输入。这种结构使得 RNN 能够捕获时间序列数据中的动态变化特性。
RNN 的问题
循环神经网络(RNN)虽然在处理序列数据方面具有明显优势,但在实际应用中遇到了几个关键问题,特别是梯度消失和梯度爆炸问题。这些问题直接影响了网络的训练效率和性能,进而催生了长短时记忆网络(LSTM)和门控循环单元(GRU)这两种更为高效的RNN变体。
-
梯度消失
在 RNN 中,当网络层较多或者处理的序列数据较长时,由于梯度在反向传播过程中反复乘以小于1的数(如激活函数的导数),导致梯度逐渐变小,最终接近于零。这会使得网络中的权重无法有效更新,特别是序列前端的权重,从而难以捕捉到序列中早期的重要信息。
-
梯度爆炸
与梯度消失相反,梯度爆炸是指在反向传播过程中梯度逐渐变得非常大,这通常发生在权重值较大时。梯度爆炸会导致网络权重的大幅波动,使得训练过程变得不稳定,甚至导致数值计算溢出。
-
难以捕捉长期依赖
由于梯度消失和梯度爆炸的问题,「标准的 RNN 在处理长序列时难以学习到输入序列中的长距离依赖关系」。这意味着网络难以记忆并利用序列中早期的信息来影响后续的输出,这对于许多需要理解整个输入序列上下文的任务来说是一个大问题,如语言翻译、文本生成等。
为了解决这些问题,研究者们开发了 LSTM 和 GRU 这两种特殊类型的RNN。
RNN 变体
LSTM
LSTM(Long Short-Term Memory)是一种特殊类型的循环神经网络(RNN),「专门设计用来解决传统 RNN 在处理序列数据时面临的长期依赖问题」。
LSTM 的关键特征是其维持细胞状态的能力,「细胞状态充当可以存储长序列信息的记忆单元」。这使得 LSTM 能够随着时间的推移选择性地记住或忘记信息,使它们非常适合上下文和远程依赖性至关重要的任务。
LSTM 的核心组件
LSTM 的关键在于其内部状态(cell state)和三个重要的门控机制:输入门、遗忘门和输出门。这些门控制着信息的流入、更新和流出,使 LSTM 能够在必要时保存信息跨越多个时间步,或者丢弃不再需要的信息。
GRU
门控循环单元(Gated Recurrent Unit, GRU)旨在简化长短时记忆网络(LSTM)的结构,同时保持对长期依赖信息的捕捉能力。
GRU 对比 LSTM 的主要区别在于**「其结构更简单,参数更少,这使得 GRU 在某些情况下训练更快,计算效率更高。」**
GRU的核心组件
GRU 将 LSTM 中的三个门控合并为两个门控,即更新门(update gate)和重置门(reset gate)。
这两个门控决定了信息是如何在单元间传递的,帮助网络捕捉时间序列中的长距离依赖。
案例分享
下面我们来使用 RNN、GRU 和 LSTM 进行苹果股价的预测。
python
import yfinance as yf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout,SimpleRNN,GRU
# 获取苹果公司的股票数据
data = yf.download('AAPL', start='2018-01-01', end='2023-01-01')
# 使用收盘价
close_prices = data['Close'].values.reshape(-1, 1)
# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
close_prices = scaler.fit_transform(close_prices)
# 划分数据集为训练集和测试集
split = int(0.8 * len(close_prices))
train = close_prices[:split]
test = close_prices[split:]
# 创建序列数据集
def create_dataset(data, steps):
X, y = [], []
for i in range(len(data) - steps):
X.append(data[i:(i + steps), 0])
y.append(data[i + steps, 0])
return np.array(X), np.array(y)
steps = 60
X_train, y_train = create_dataset(train, steps)
X_test, y_test = create_dataset(test, steps)
# 重塑输入以符合 RNN 模型的期望格式 [样本数, 时间步, 特征数]
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))
# 构建 LSTM 模型
model = Sequential()
model.add(LSTM(units=50, return_sequences=True, input_shape=(X_train.shape[1], 1)))
model.add(LSTM(units=50))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(X_train, y_train, epochs=50, batch_size=32, verbose=1)
# 构建 RNN 模型
model_rnn = Sequential()
model_rnn.add(SimpleRNN(units=50, return_sequences=True, input_shape=(X_train.shape[1], 1)))
model_rnn.add(SimpleRNN(units=50))
model_rnn.add(Dense(1))
model_rnn.compile(optimizer='adam', loss='mean_squared_error')
model_rnn.fit(X_train, y_train, epochs=50, batch_size=32, verbose=1)
# 构建 GRU 模型
model_gru = Sequential()
model_gru.add(GRU(units=50, return_sequences=True, input_shape=(X_train.shape[1], 1)))
model_gru.add(GRU(units=50))
model_gru.add(Dense(1))
model_gru.compile(optimizer='adam', loss='mean_squared_error')
model_gru.fit(X_train, y_train, epochs=50, batch_size=32, verbose=1)
接下来,我们来看一下预测的结果。
python
# LSTM 预测
predicted_stock_price_lstm = model.predict(X_test)
predicted_stock_price_lstm = scaler.inverse_transform(predicted_stock_price_lstm)
# RNN 预测
predicted_stock_price_rnn = model_rnn.predict(X_test)
predicted_stock_price_rnn = scaler.inverse_transform(predicted_stock_price_rnn)
# GRU 预测
predicted_stock_price_gru = model_gru.predict(X_test)
predicted_stock_price_gru = scaler.inverse_transform(predicted_stock_price_gru)
# 绘图
plt.figure(figsize=(14, 5))
plt.plot(real_stock_price, color='red', label='Real Apple Stock Price')
plt.plot(predicted_stock_price_lstm, color='blue', label='Predicted Apple Stock Price (LSTM)')
plt.plot(predicted_stock_price_rnn, color='green', label='Predicted Apple Stock Price (RNN)')
plt.plot(predicted_stock_price_gru, color='purple', label='Predicted Apple Stock Price (GRU)')
plt.title('Apple Stock Price Prediction Comparison')
plt.xlabel('Time')
plt.ylabel('Apple Stock Price')
plt.legend()
plt.show()
THE END !
文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。