神经网络11-TFT模型的简单示例

Temporal Fusion Transformer (TFT) 是一种用于时间序列预测的深度学习模型,它结合了Transformer架构的优点和专门为时间序列设计的一些优化技术。TFT尤其擅长处理多变量时间序列数据,并且能够捕捉到长期依赖关系,同时通过自注意力机制有效地处理时序特征。TFT的工作原理主要由以下几个部分组成:

1. 输入数据处理

  • 输入特征:TFT的输入是一个多变量时间序列,每个样本包含多个特征(如10个特征,每个特征有240个时间步)。每个时间步的特征值可以是连续的(如温度、股价等),也可以是分类的(如星期几、节假日等)。
  • 静态和时间序列特征:TFT区分了静态特征(例如个体ID、地点)和动态特征(例如时间步上的温度)。静态特征在模型中用于增强个体的预测性能,而动态特征则帮助模型捕捉到随时间变化的模式。

2. Encoder-Decoder架构

TFT采用了编码器-解码器(Encoder-Decoder)的架构,这个架构原本用于序列到序列的任务(如机器翻译),但是在TFT中做了调整:

  • 编码器(Encoder):输入的时间序列通过编码器进行处理,编码器包括一个由自注意力机制和GRU(门控循环单元)组成的结构。自注意力机制能够帮助模型捕捉不同时间步之间的依赖关系,而GRU有助于捕捉短期的时间依赖性。
  • 解码器(Decoder):解码器根据编码器输出的特征以及未来的已知信息来生成预测。解码器可以直接预测下一个时间步的值。

3. 自注意力机制(Self-Attention)

自注意力机制在TFT中用于捕捉时间序列中各个时间步之间的长短期依赖关系。它通过计算每个时间步和其他时间步之间的相关性,自动地给出不同时间步的权重。这样,模型可以根据时间序列中的重要性自适应地调整权重。

4. 门控机制(Gating Mechanisms)

TFT采用了多个门控机制(例如:GRN(Gated Residual Network)和变量选择网络)来控制信息流,避免不必要的计算,并且使得模型更加灵活:

  • 变量选择网络:自动选择哪些输入特征对于当前时间步的预测更为重要,从而提升了模型的性能和可解释性。
  • 门控残差单元(GRN):通过加权处理动态特征和静态特征来提供更丰富的信息。

5. 多尺度时间步(Multi-Scale Temporal Fusion)

TFT不仅利用了全局时间步的特征,还通过多尺度处理能够同时捕捉长期和短期的模式。例如,通过多个不同的时间尺度对历史信息进行融合,从而提升了模型的预测精度。

6. 预测头(Forecasting Head)

在解码器的顶部,TFT有一个预测头,它根据模型输出的时间序列特征来进行实际的预测。它生成一个未来时间步的预测值,通常用于回归任务或二分类任务。

7. 自解释性(Interpretability)

TFT具有一定的可解释性,特别是通过注意力机制变量选择网络,可以分析哪些特征对于模型的预测最重要。这对于需要模型透明度和决策依据的场景非常有用。

8.简单例子:预测股票价格

假设我们有一个简单的时间序列数据集,包含时间步(例如每天的股票价格),并且我们希望预测未来几天的股票价格。假设我们有以下结构的时间序列数据:

  • 每天的股票价格(连续特征)。
  • 每天的交易量(连续特征)。
  • 每天的节假日信息(分类特征,例如是否为节假日)。

我们的目标是基于过去的几个时间步的数据预测未来的股票价格。

模型步骤

  1. 数据预处理 :我们将数据准备为适合TFT的格式。TFT需要有时间步特征静态特征目标变量
  2. 模型构建:使用TFT模型进行训练和预测。
  3. 预测:用训练好的模型预测未来的股票价格。

示例代码

假设我们使用pytorch-forecasting这个库来实现TFT模型。这个库为时间序列任务提供了简化的API。

安装必要的库

首先,安装相关库:

bash 复制代码
pip install pytorch-forecasting pytorch-lightning

构建简单的TFT模型

python 复制代码
import pandas as pd
import numpy as np
import torch
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting import Baseline

# 生成模拟数据
np.random.seed(42)

# 创建一个简单的时间序列数据集
n_samples = 1000
time_idx = np.tile(np.arange(1, 101), n_samples // 100)  # 100天的周期重复n_samples次
stock_price = np.sin(time_idx * 0.1) + np.random.normal(0, 0.1, len(time_idx))  # 模拟的股票价格
volume = np.random.normal(1000, 100, len(time_idx))  # 模拟的交易量
is_holiday = (time_idx % 7 == 0).astype(int)  # 假设每7天是一个节假日

# 创建DataFrame
data = pd.DataFrame({
    "time_idx": time_idx,
    "stock_price": stock_price,
    "volume": volume,
    "is_holiday": is_holiday
})

# 创建训练和验证集
max_encoder_length = 60  # 用60个时间步预测未来
max_prediction_length = 10  # 预测未来10个时间步

# 将数据转换为TimeSeriesDataSet
training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= 80],  # 训练集使用前80天的数据
    time_idx="time_idx",
    target="stock_price",
    group_ids=["is_holiday"],  # 对应的静态特征是节假日
    min_encoder_length=max_encoder_length,
    max_encoder_length=max_encoder_length,
    min_prediction_length=max_prediction_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=["is_holiday"],
    time_varying_known_reals=["stock_price", "volume"],  # 动态已知特征
    time_varying_unknown_reals=["stock_price"],  # 动态目标变量
    target_normalizer=GroupNormalizer(groups=["is_holiday"], transformation="softplus"),  # 归一化
)

# 创建数据加载器
train_dataloader = torch.utils.data.DataLoader(training, batch_size=64, shuffle=True)

# 初始化TFT模型
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,  # 隐藏层单元数
    attention_head_size=4,  # 自注意力头数
    dropout=0.1,  # Dropout概率
    hidden_continuous_size=8,  # 连续特征的隐藏层大小
    output_size=1,  # 输出预测的维度
    loss= torch.nn.MSELoss(),  # 使用均方误差损失
)

# 训练模型
import pytorch_lightning as pl
trainer = pl.Trainer(max_epochs=10, gpus=0)  # 设置epochs
trainer.fit(tft, train_dataloader)

# 进行预测
test_data = data[lambda x: x.time_idx > 80]  # 测试集使用剩余的数据
test_dataset = TimeSeriesDataSet.from_dataset(training, test_data)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

predictions = tft.predict(test_dataloader, mode="raw")

代码解析

1.数据生成

我们使用sin函数加上随机噪声来模拟股票价格,并且生成随机的交易量和节假日标记。最终生成的股票数据如下所示:

2.TimeSeriesDataSet

这是pytorch-forecasting库中的一个数据类,它帮助我们将数据集转化为模型能够理解的格式。我们指定了哪些是时间步特征、静态特征和目标变量,并且定义了时间步长度(60步历史数据用于预测未来10步)。

详细解析:

  1. data[lambda x: x.time_idx <= 80]

    • data 是一个包含时间序列数据的 DataFrame。
    • lambda x: x.time_idx <= 80 是一个条件筛选器,用于选择时间索引 (time_idx) 小于或等于 80 的数据。这样,只会使用前 80 天的数据来训练模型。
  2. time_idx="time_idx"

    • time_idx 指定了表示时间步的列名。在这个例子中,time_idx 表示时间索引(通常是从 1 到 n 的整数,表示每个时间步)。
  3. target="stock_price"

    • target 参数指定了模型要预测的目标变量。在这个例子中,目标变量是 stock_price(股票价格)。
  4. group_ids=["is_holiday"]

    • group_ids 用于指定一个或多个分组特征,这些特征将用于区分不同的时间序列或分组数据。在这里,is_holiday 表示节假日(通常是一个静态变量,指示每个时间步是否为假期)。它会告诉模型如何对待不同的节假日数据。
  5. min_encoder_length=max_encoder_lengthmax_encoder_length=max_encoder_length

    • min_encoder_lengthmax_encoder_length 指定了输入序列(编码器输入)的最小和最大长度。它们用于告诉模型每个输入序列的时间步数。
    • 这两个参数的值相等,表示使用固定长度的历史数据,假设为 max_encoder_length。例如,max_encoder_length=10 表示使用过去的 10 天数据进行预测。
    • max_encoder_lengthmin_encoder_length 应该是整数,表示时间序列的长度。
  6. min_prediction_length=max_prediction_lengthmax_prediction_length=max_prediction_length

    • min_prediction_lengthmax_prediction_length 是预测的时间步长,指定模型需要预测多少步。
    • 这两个参数的值也相等,表示预测的步长是固定的。例如,max_prediction_length=5 表示模型需要预测接下来 5 天的股票价格。
  7. static_categoricals=["is_holiday"]

    • static_categoricals 表示静态类别特征(即在整个时间序列中不变化的特征)。这里 is_holiday 是一个静态类别特征,指示每个时间步是否是节假日。
  8. time_varying_known_reals=["stock_price", "volume"]

    • time_varying_known_reals 指定了动态已知特征,这些特征在时间序列中会随时间变化,并且在训练时已知。在这个例子中,stock_price(股票价格)和 volume(交易量)是动态已知特征。
  9. time_varying_unknown_reals=["stock_price"]

    • time_varying_unknown_reals 列表指定了模型需要预测的动态目标变量(时间步变化的未知特征)。在本例中,stock_price 是需要预测的目标变量,它会随时间变化。
  10. target_normalizer=GroupNormalizer(groups=["is_holiday"], transformation="softplus")

  • target_normalizer 用于对目标变量进行归一化。在这个例子中,使用了 GroupNormalizer,它会基于 is_holiday 这一组特征对目标变量进行归一化。
  • transformation="softplus" 表示使用 Softplus 函数(log(1 + exp(x)))来平滑目标变量的分布,这有助于减小异常值的影响。
3.TFT模型

我们通过TemporalFusionTransformer.from_dataset构建了TFT模型,设置了学习率、隐藏层大小等超参数。使用均方误差(MSE)作为损失函数来训练模型。

详细解析:

1. TemporalFusionTransformer.from_dataset(training, ...):

  • TemporalFusionTransformer.from_dataset() 是一个类方法,它从给定的数据集(training)中自动配置模型的各个部分,如输入特征、目标变量、编码器长度等。
  • trainingTimeSeriesDataSet,包含了时间序列数据和相关的特征。这个方法会根据该数据集自动处理输入特征和目标,设置模型结构。

2. learning_rate=0.001:

  • learning_rate 是模型优化器的学习率,决定了模型参数在训练时的更新步幅。较小的学习率(例如 0.001)通常能帮助优化过程更稳定,但训练速度可能会变慢。

3. hidden_size=16:

  • hidden_size 是模型中每个隐藏层的单元数。这里设置为 16,表示每个隐藏层的神经元数量。较大的隐藏层大小有助于模型捕捉更多的复杂模式,但也可能增加计算复杂性和过拟合的风险。

4. attention_head_size=4:

  • attention_head_size 指定了自注意力机制中多头注意力(Multi-Head Attention)机制的头数。这里设置为 4,意味着模型将通过 4 个不同的"头"来计算注意力权重,从而捕捉不同方面的信息。这是 Transformer 模型的关键特性,可以帮助模型同时关注输入的多个部分。

5. dropout=0.1:

  • dropout 是防止过拟合的一种技术,表示在训练过程中会随机丢弃 10% 的神经元,以减少模型对特定神经元的依赖。这里设置为 0.1,表示在训练时有 10% 的概率会忽略某些神经元的输出。

6. hidden_continuous_size=8:

  • hidden_continuous_size 是连续特征的隐藏层大小。时间序列中的一些特征(例如股价、交易量等)可能是连续变量,这个参数指定了这些连续特征的隐藏表示大小。设置为 8,表示模型将使用 8 个单位来表示这些连续特征。

7. output_size=1:

  • output_size 指定了模型的输出维度。在时间序列预测任务中,通常输出一个预测值(如股票价格),所以这里设置为 1,表示模型输出一个数值。

8. loss= torch.nn.MSELoss():

  • loss 是指定优化目标的损失函数。在回归任务(如股票价格预测)中,常使用均方误差(MSE)作为损失函数。torch.nn.MSELoss() 会计算模型预测值与实际值之间的均方误差。
4.训练

使用pytorch_lightning中的Trainer进行模型训练,设置训练的epoch数为10。

详细解析

  1. import pytorch_lightning as pl

    • 导入 PyTorch Lightning 库,PyTorch Lightning 是一个用于简化 PyTorch 模型训练过程的高级框架。它封装了很多 PyTorch 中繁琐的训练步骤,使得训练过程更清晰、更易于管理。
  2. trainer = pl.Trainer(max_epochs=10, gpus=0)

    • 创建一个 Trainer 对象,这是 PyTorch Lightning 中用于训练模型的主要接口。这里使用了以下参数:
      • max_epochs=10:指定训练的最大轮数(epochs)。训练将在 10 个 epochs 后停止。
      • gpus=0:指定使用的 GPU 数量,0 表示不使用 GPU(即在 CPU 上训练)。如果你的机器上有可用的 GPU,可以将其设置为 gpus=1 或更多。
  3. trainer.fit(tft, train_dataloader)

    • fit() 方法用于训练模型。在此处,tft 是创建的 TemporalFusionTransformer 模型,train_dataloader 是训练数据的 DataLoader,包含了批次化的训练数据。
    • fit() 方法会根据训练数据自动进行模型的前向传播、反向传播和参数更新。
5. 预测

通过predict函数,模型基于测试集进行预测。

详细解析

  1. test_data = data[lambda x: x.time_idx > 80]

    • test_data 选择 datatime_idx 大于 80 的数据,作为测试集。这表示测试集包含从第 81 天开始的数据。
  2. test_dataset = TimeSeriesDataSet.from_dataset(training, test_data)

    • TimeSeriesDataSet.from_dataset() 方法根据训练集 training 和测试集 test_data 创建一个新的测试数据集。这个方法会从训练集的配置中继承必要的参数(如时间索引、特征列等),然后将其应用于测试数据集。
  3. test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

    • 使用 DataLoader 来批次化测试集。与训练集不同,测试集通常不需要进行打乱,因此 shuffle=False
    • batch_size=64 表示每个批次加载 64 个样本。
  4. predictions = tft.predict(test_dataloader, mode="raw")

    • predict() 方法用于对测试集进行预测。它将返回模型对测试集的预测结果。
    • mode="raw" 指定了返回原始的预测结果(而不是经过后处理或归一化的结果)。你可以选择不同的模式来控制返回的预测格式,例如可以选择返回预测的概率、标签等。
相关推荐
lu_rong_qq9 分钟前
【机器学习】Lesson 5 - K近邻(KNN)分类/回归
人工智能·机器学习·分类·回归
菩提祖师_1 小时前
基于大模型实现论文观点查重
java·深度学习
搬砖的小码农_Sky1 小时前
如何利用AI大模型提高软件开发效率?
人工智能
weixin_487058411 小时前
faiss 提供了多种索引类型
人工智能·机器学习·faiss
SEVEN-YEARS2 小时前
BERT模型中的嵌入后处理与注意力掩码
人工智能·bert·easyui
newxtc2 小时前
【天壤智能-注册安全分析报告-无验证纯IP限制存在误拦截隐患】
人工智能·tcp/ip·安全·网易易盾·ai写作·极验
爱技术的小伙子2 小时前
【ChatGPT】让ChatGPT生成批判性思维问题的回答
人工智能·chatgpt
Rolei_zl2 小时前
AIGC(生成式AI)试用 18 -- AI Prompt
人工智能·prompt
小众AI2 小时前
Vanna - 与你的 SQL 数据库聊天
数据库·人工智能·sql