LSTM时序预测 | Python实现LSTM长短期记忆神经网络时间序列预测

本文内容:Python实现LSTM长短期记忆神经网络时间序列预测,使用的数据集为 AirPassengers

目录

数据集简介

1.步骤一

2.步骤二

3.步骤三

4.步骤四

数据集简介

AirPassengers 数据集的来源可以追溯到经典的统计和时间序列分析文献。原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入,这本书在时间序列分析领域非常著名

1.训练结果

2.步骤一

安装darts库:

复制代码
pip install darts

3.步骤二

部分代码如下:

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import shutil
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models import RNNModel, ExponentialSmoothing, BlockRNNModel
from darts.metrics import mape, mae, mse, rmse
from darts.utils.statistics import check_seasonality, plot_acf
from darts.datasets import AirPassengersDataset, SunspotsDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseries

import warnings

warnings.filterwarnings("ignore")
import logging

logging.disable(logging.CRITICAL)

####################数据准备##########################
# Read data:
series = AirPassengersDataset().load()  #原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入

# Create training and validation sets:
train, val = series.split_after(pd.Timestamp("19590101")) ##可以填写具体的日期,也可以填写比例

# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)

# create month and year covariate series
year_series = datetime_attribute_timeseries(
    pd.date_range(start=series.start_time(), freq=series.freq_str, periods=1000),
    attribute="year",
    one_hot=False,
)
year_series = Scaler().fit_transform(year_series)
month_series = datetime_attribute_timeseries(
    year_series, attribute="month", one_hot=True
)
covariates = year_series.stack(month_series)
cov_train, cov_val = covariates.split_after(pd.Timestamp("19590101"))

####################构建模型##########################
my_model = RNNModel(
    model="LSTM",
    hidden_dim=20,
    dropout=0,
    batch_size=16,
    n_epochs=300,
    optimizer_kwargs={"lr": 1e-3},
    model_name="Air_RNN",
    log_tensorboard=True,
    random_state=42,
    training_length=20,
    input_chunk_length=14,
    force_reset=True,
    save_checkpoints=True,
)


my_model.fit(
    train_transformed,
    future_covariates=covariates,
    val_series=val_transformed,
    val_future_covariates=covariates,
    verbose=True,
)

完整代码下载地址:下载地址

相关推荐
2301_773553621 分钟前
怎么删除MongoDB中不再使用的账号
jvm·数据库·python
qq_342295822 分钟前
SQL报表星型模型优化_事实表索引设计
jvm·数据库·python
u0109147606 分钟前
SQL优化多表关联中的字符串连接字段_建立前缀索引提升JOIN
jvm·数据库·python
2301_7775993710 分钟前
Oracle环境下的设置主键与自增列指南_特定语法与可视化配置
jvm·数据库·python
a95114164211 分钟前
Golang怎么用go get添加依赖_Golang如何在项目中引入第三方库【入门】
jvm·数据库·python
m0_6765443818 分钟前
Golang怎么解决nil pointer错误_Golang如何排查和修复空指针引用崩溃【避坑】
jvm·数据库·python
John.Lewis26 分钟前
Python小课(6)基础语法⑤
开发语言·python
2301_7775993727 分钟前
如何优化宝塔面板的服务器内存使用_调整MySQL内存占用
jvm·数据库·python
财经资讯数据_灵砚智能28 分钟前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年4月22日
人工智能·python·信息可视化·自然语言处理·ai编程
csgo打的菜又爱玩32 分钟前
7.DispatcherResourceManagerComponentFactory解析.md
开发语言·python·flink