LSTM时序预测 | Python实现LSTM长短期记忆神经网络时间序列预测

本文内容:Python实现LSTM长短期记忆神经网络时间序列预测,使用的数据集为 AirPassengers

目录

数据集简介

1.步骤一

2.步骤二

3.步骤三

4.步骤四

数据集简介

AirPassengers 数据集的来源可以追溯到经典的统计和时间序列分析文献。原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入,这本书在时间序列分析领域非常著名

1.训练结果

2.步骤一

安装darts库:

复制代码
pip install darts

3.步骤二

部分代码如下:

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import shutil
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models import RNNModel, ExponentialSmoothing, BlockRNNModel
from darts.metrics import mape, mae, mse, rmse
from darts.utils.statistics import check_seasonality, plot_acf
from darts.datasets import AirPassengersDataset, SunspotsDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseries

import warnings

warnings.filterwarnings("ignore")
import logging

logging.disable(logging.CRITICAL)

####################数据准备##########################
# Read data:
series = AirPassengersDataset().load()  #原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入

# Create training and validation sets:
train, val = series.split_after(pd.Timestamp("19590101")) ##可以填写具体的日期,也可以填写比例

# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)

# create month and year covariate series
year_series = datetime_attribute_timeseries(
    pd.date_range(start=series.start_time(), freq=series.freq_str, periods=1000),
    attribute="year",
    one_hot=False,
)
year_series = Scaler().fit_transform(year_series)
month_series = datetime_attribute_timeseries(
    year_series, attribute="month", one_hot=True
)
covariates = year_series.stack(month_series)
cov_train, cov_val = covariates.split_after(pd.Timestamp("19590101"))

####################构建模型##########################
my_model = RNNModel(
    model="LSTM",
    hidden_dim=20,
    dropout=0,
    batch_size=16,
    n_epochs=300,
    optimizer_kwargs={"lr": 1e-3},
    model_name="Air_RNN",
    log_tensorboard=True,
    random_state=42,
    training_length=20,
    input_chunk_length=14,
    force_reset=True,
    save_checkpoints=True,
)


my_model.fit(
    train_transformed,
    future_covariates=covariates,
    val_series=val_transformed,
    val_future_covariates=covariates,
    verbose=True,
)

完整代码下载地址:下载地址

相关推荐
屋檐上的大修勾几秒前
AI算力开放-yolov8适配 mmyolo大疆无人机
开发语言·python
郑州光合科技余经理5 分钟前
开发实战:海外版同城o2o生活服务平台核心模块设计
开发语言·git·python·架构·uni-app·生活·智慧城市
Kratzdisteln5 分钟前
【Python】Flask 2
开发语言·python·flask
程序员三藏11 分钟前
单元测试详解
自动化测试·软件测试·python·测试工具·职场和发展·单元测试·测试用例
540_54015 分钟前
ADVANCE Day44
人工智能·python·深度学习
AI科技星17 分钟前
电场起源的几何革命:变化的引力场产生电场方程的第一性原理推导、验证与统一性意义
开发语言·人工智能·线性代数·算法·机器学习·数学建模
九河云20 分钟前
华为云SDRS跨Region双活:筑牢证券核心系统零中断防线
大数据·人工智能·安全·机器学习·华为云
好好学操作系统24 分钟前
flash_attn ImportError undefined symbol:
开发语言·python
CCPC不拿奖不改名24 分钟前
面向对象编程:继承与多态+面试习题
开发语言·数据结构·python·学习·面试·职场和发展
WangUnionpub25 分钟前
2026 国自然基金申请全周期时间线
大数据·人工智能·物联网·机器学习·计算机视觉