LSTM时序预测 | Python实现LSTM长短期记忆神经网络时间序列预测

本文内容:Python实现LSTM长短期记忆神经网络时间序列预测,使用的数据集为 AirPassengers

目录

数据集简介

1.步骤一

2.步骤二

3.步骤三

4.步骤四

数据集简介

AirPassengers 数据集的来源可以追溯到经典的统计和时间序列分析文献。原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入,这本书在时间序列分析领域非常著名

1.训练结果

2.步骤一

安装darts库:

复制代码
pip install darts

3.步骤二

部分代码如下:

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import shutil
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models import RNNModel, ExponentialSmoothing, BlockRNNModel
from darts.metrics import mape, mae, mse, rmse
from darts.utils.statistics import check_seasonality, plot_acf
from darts.datasets import AirPassengersDataset, SunspotsDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseries

import warnings

warnings.filterwarnings("ignore")
import logging

logging.disable(logging.CRITICAL)

####################数据准备##########################
# Read data:
series = AirPassengersDataset().load()  #原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入

# Create training and validation sets:
train, val = series.split_after(pd.Timestamp("19590101")) ##可以填写具体的日期,也可以填写比例

# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)

# create month and year covariate series
year_series = datetime_attribute_timeseries(
    pd.date_range(start=series.start_time(), freq=series.freq_str, periods=1000),
    attribute="year",
    one_hot=False,
)
year_series = Scaler().fit_transform(year_series)
month_series = datetime_attribute_timeseries(
    year_series, attribute="month", one_hot=True
)
covariates = year_series.stack(month_series)
cov_train, cov_val = covariates.split_after(pd.Timestamp("19590101"))

####################构建模型##########################
my_model = RNNModel(
    model="LSTM",
    hidden_dim=20,
    dropout=0,
    batch_size=16,
    n_epochs=300,
    optimizer_kwargs={"lr": 1e-3},
    model_name="Air_RNN",
    log_tensorboard=True,
    random_state=42,
    training_length=20,
    input_chunk_length=14,
    force_reset=True,
    save_checkpoints=True,
)


my_model.fit(
    train_transformed,
    future_covariates=covariates,
    val_series=val_transformed,
    val_future_covariates=covariates,
    verbose=True,
)

完整代码下载地址:下载地址

相关推荐
程序员爱钓鱼6 分钟前
Python编程实战——Python实用工具与库:Pandas数据处理
后端·python·ipython
程序员爱钓鱼10 分钟前
Python编程实战——Python实用工具与库:Numpy基础
后端·python·面试
程序员霸哥哥12 分钟前
从零搭建PyTorch计算机视觉模型
人工智能·pytorch·python·计算机视觉
机器学习之心31 分钟前
MATLAB遗传算法优化RVFL神经网络回归预测(随机函数链接神经网络)
神经网络·matlab·回归
晚秋大魔王44 分钟前
基于python的jlink单片机自动化批量烧录工具
前端·python·单片机
胖哥真不错1 小时前
Python基于PyTorch实现多输入多输出进行CNN卷积神经网络回归预测项目实战
pytorch·python·毕业设计·课程设计·毕设·多输入多输出·cnn卷积神经网络回归预测
程序员-小李1 小时前
基于PyTorch的动物识别模型训练与应用实战
人工智能·pytorch·python
闲人编程4 小时前
Python在网络安全中的应用:编写一个简单的端口扫描器
网络·python·web安全·硬件·端口·codecapsule·扫描器
大大dxy大大6 小时前
机器学习实现逻辑回归-癌症分类预测
机器学习·分类·逻辑回归
武子康6 小时前
AI研究-119 DeepSeek-OCR PyTorch FlashAttn 2.7.3 推理与部署 模型规模与资源详细分析
人工智能·深度学习·机器学习·ai·ocr·deepseek·deepseek-ocr