LSTM时序预测 | Python实现LSTM长短期记忆神经网络时间序列预测

本文内容:Python实现LSTM长短期记忆神经网络时间序列预测,使用的数据集为 AirPassengers

目录

数据集简介

1.步骤一

2.步骤二

3.步骤三

4.步骤四

数据集简介

AirPassengers 数据集的来源可以追溯到经典的统计和时间序列分析文献。原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入,这本书在时间序列分析领域非常著名

1.训练结果

2.步骤一

安装darts库:

复制代码
pip install darts

3.步骤二

部分代码如下:

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import shutil
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models import RNNModel, ExponentialSmoothing, BlockRNNModel
from darts.metrics import mape, mae, mse, rmse
from darts.utils.statistics import check_seasonality, plot_acf
from darts.datasets import AirPassengersDataset, SunspotsDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseries

import warnings

warnings.filterwarnings("ignore")
import logging

logging.disable(logging.CRITICAL)

####################数据准备##########################
# Read data:
series = AirPassengersDataset().load()  #原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入

# Create training and validation sets:
train, val = series.split_after(pd.Timestamp("19590101")) ##可以填写具体的日期,也可以填写比例

# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)

# create month and year covariate series
year_series = datetime_attribute_timeseries(
    pd.date_range(start=series.start_time(), freq=series.freq_str, periods=1000),
    attribute="year",
    one_hot=False,
)
year_series = Scaler().fit_transform(year_series)
month_series = datetime_attribute_timeseries(
    year_series, attribute="month", one_hot=True
)
covariates = year_series.stack(month_series)
cov_train, cov_val = covariates.split_after(pd.Timestamp("19590101"))

####################构建模型##########################
my_model = RNNModel(
    model="LSTM",
    hidden_dim=20,
    dropout=0,
    batch_size=16,
    n_epochs=300,
    optimizer_kwargs={"lr": 1e-3},
    model_name="Air_RNN",
    log_tensorboard=True,
    random_state=42,
    training_length=20,
    input_chunk_length=14,
    force_reset=True,
    save_checkpoints=True,
)


my_model.fit(
    train_transformed,
    future_covariates=covariates,
    val_series=val_transformed,
    val_future_covariates=covariates,
    verbose=True,
)

完整代码下载地址:下载地址

相关推荐
机器学习之心1 小时前
贝叶斯优化+卷积神经网络+多目标优化+多属性决策!BO-CNN+NSGAII+熵权TOPSIS,附实验报告!
人工智能·神经网络·cnn·多目标优化·多属性决策
NotFound4861 小时前
c++如何通过解析二进制PE文件头检测程序是否开启了DEP保护机制【进阶】
jvm·数据库·python
zhangchaoxies1 小时前
PHP源码能否在NAS设备上运行_NAS部署PHP源码可行性【教程】
jvm·数据库·python
2301_764150562 小时前
如何在 Laravel 中正确保存嵌套动态表单数据(主服务 + 子服务)
jvm·数据库·python
2401_832635582 小时前
如何进行SQL安全基线评估_定期核对数据库安全配置
jvm·数据库·python
zhangchaoxies2 小时前
HTML怎么实现键盘操作全站导航_HTML全局快捷键说明面板【方法】
jvm·数据库·python
vegetablec2 小时前
如何用 location.reload(true) 强制浏览器从服务器刷新页面
jvm·数据库·python
2301_814809862 小时前
如何让导航栏的下落动画效果更缓慢?
jvm·数据库·python
杜子不疼.2 小时前
Python多模态AI开发指南:让AI同时理解文字、图片和语音
开发语言·人工智能·python
InfinteJustice2 小时前
如何加固SQL通信安全_启用SSL加密确保数据传输安全
jvm·数据库·python