LSTM时序预测 | Python实现LSTM长短期记忆神经网络时间序列预测

本文内容:Python实现LSTM长短期记忆神经网络时间序列预测,使用的数据集为 AirPassengers

目录

数据集简介

1.步骤一

2.步骤二

3.步骤三

4.步骤四

数据集简介

AirPassengers 数据集的来源可以追溯到经典的统计和时间序列分析文献。原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入,这本书在时间序列分析领域非常著名

1.训练结果

2.步骤一

安装darts库:

复制代码
pip install darts

3.步骤二

部分代码如下:

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import shutil
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models import RNNModel, ExponentialSmoothing, BlockRNNModel
from darts.metrics import mape, mae, mse, rmse
from darts.utils.statistics import check_seasonality, plot_acf
from darts.datasets import AirPassengersDataset, SunspotsDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseries

import warnings

warnings.filterwarnings("ignore")
import logging

logging.disable(logging.CRITICAL)

####################数据准备##########################
# Read data:
series = AirPassengersDataset().load()  #原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入

# Create training and validation sets:
train, val = series.split_after(pd.Timestamp("19590101")) ##可以填写具体的日期,也可以填写比例

# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)

# create month and year covariate series
year_series = datetime_attribute_timeseries(
    pd.date_range(start=series.start_time(), freq=series.freq_str, periods=1000),
    attribute="year",
    one_hot=False,
)
year_series = Scaler().fit_transform(year_series)
month_series = datetime_attribute_timeseries(
    year_series, attribute="month", one_hot=True
)
covariates = year_series.stack(month_series)
cov_train, cov_val = covariates.split_after(pd.Timestamp("19590101"))

####################构建模型##########################
my_model = RNNModel(
    model="LSTM",
    hidden_dim=20,
    dropout=0,
    batch_size=16,
    n_epochs=300,
    optimizer_kwargs={"lr": 1e-3},
    model_name="Air_RNN",
    log_tensorboard=True,
    random_state=42,
    training_length=20,
    input_chunk_length=14,
    force_reset=True,
    save_checkpoints=True,
)


my_model.fit(
    train_transformed,
    future_covariates=covariates,
    val_series=val_transformed,
    val_future_covariates=covariates,
    verbose=True,
)

完整代码下载地址:下载地址

相关推荐
哥布林学者10 分钟前
高光谱拼接算法(一)扫推式成像和航带拼接算法
机器学习·高光谱成像
X1A0RAN1 小时前
解决Pycharm中部分文件或文件夹被隐藏不展示问题
ide·python·pycharm
MomentYY1 小时前
第 3 篇:让 Agent 学会分工,LangGraph 构建多 Agent系统
人工智能·python·agent
程序员Jelena1 小时前
Python 代码是什么?—— 从字节到执行的完整解析
python
测试员周周1 小时前
【Appium 系列】第13节-混合测试执行器 — API + UI 的协同执行
开发语言·人工智能·python·功能测试·ui·appium·pytest
malog_2 小时前
大语言模型后训练全解析
人工智能·深度学习·机器学习·ai·语言模型
用户8356290780512 小时前
Python 操作 PowerPoint OLE 对象
后端·python
小江的记录本2 小时前
【Java基础】Java 8-21新特性:JDK21 LTS:虚拟线程、模式匹配switch、结构化并发、序列集合(附《思维导图》+《面试高频考点清单》)
java·数据库·python·mysql·spring·面试·maven
枫叶林FYL2 小时前
【强化学习】3 双系统持续强化学习:快速迁移与元知识整合架构手册
人工智能·机器学习·架构
张登杰踩3 小时前
DINOv2 with Registers 系列模型详解:Giant 版本规格、Register Token 机制与使用指南
python·numpy