时间序列预测大模型-TimeGPT

时间序列预测领域正在经历一个非常激动人心的时期。仅在过去的三年里,我们就看到了许多重要的贡献,例如N-BEATSN-HiTSPatchTSTTimesNet

与此同时,大型语言模型 (LLM)最近在 ChatGPT 等应用程序中广受欢迎,因为它们无需进一步训练即可适应各种任务。

这就引出了一个问题:时间序列的基础模型是否可以像自然语言处理一样存在?在大量时间序列数据上预先训练的大型模型是否有可能对未见过的数据产生准确的预测?

由 Azul Garza 和 Max Mergenthaler-Canseco 提出,作者将大模型背后的技术和架构应用于预测领域,成功构建了第一个能够进行零样本推理的时间序列基础模型。

在本文中,我们首先探讨 TimeGPT 背后的架构以及模型的训练方式。然后,我们将其应用于预测项目,以根据其他最先进的方法(如 N-BEATS、N-HiTS 和 PatchTST)评估其性能。

欲了解更多详细信息,请务必阅读原始论文

探索 TimeGPT

如前所述,TimeGPT 是创建时间序列预测基础模型的首次尝试。

说明如何训练 TimeGPT 以对看不见的数据进行推理。图片由 Azul Garza 和 Max Mergenthaler-Canseco 拍摄,来自TimeGPT-1

从上图我们可以看出,TimeGPT 背后的总体思想是在来自不同领域的大量数据上训练模型,然后对未见过的数据进行零样本推理。

当然,这种方法依赖于迁移学习,这是模型利用训练期间获得的知识解决新任务的能力。

现在,只有当模型足够大并且经过大量数据训练时,这才有效。

训练TimeGPT

为此,作者使用超过 1000 亿个数据点对 TimeGPT 进行了训练,这些数据点全部来自开源时间序列数据。该数据集涵盖广泛的领域,从金融、经济和天气,到网络流量、能源和销售。

请注意,作者没有透露用于管理 1000 亿个数据点的公共数据来源。

这种多样性对于基础模型的成功至关重要,因为它可以学习不同的时间模式,从而更好地泛化。

例如,我们可以预期天气数据具有每日季节性(白天比晚上更热)和每年季节性,而汽车交通数据可以具有每日季节性(白天道路上的汽车多于夜间)和每周季节性。季节性(一周内道路上的汽车多于周末)。

为了确保模型的稳健性和泛化能力,预处理被保持在最低限度。事实上,只填充了缺失的值,其余的保持原始形式。虽然作者没有指定数据插补的方法,但我怀疑使用了某种插值技术,例如线性插值、样条插值或移动平均插值。

然后对模型进行多天的训练,在此期间优化超参数和学习率。虽然作者没有透露训练需要多少天和 GPU,但我们确实知道该模型是在 PyTorch 中实现的,并且它使用 Adam 优化器和学习率衰减策略。

TimeGPT的架构

TimeGPT 利用 Transformer 架构和基于 Google 和多伦多大学 2017 年开创性工作的自注意力机制。

TimeGPT 的架构。输入序列与外生变量一起被馈送到 Transfomer 的编码器,然后解码器生成预测。图片由 Azul Garza 和 Max Mergenthaler-Canseco 拍摄,来自TimeGPT-1

从上图我们可以看到TimeGPT采用了完整的编码器-解码器Transformer架构。

输入可以包含历史数据窗口以及外源数据,例如准时事件或其他系列。

输入被馈送到模型的编码器部分。然后,编码器内部的注意力机制从输入中学习不同的属性。然后将其馈送到解码器,解码器使用学到的信息来生成预测。当然,当预测序列达到用户设置的预测范围的长度时,预测序列就会结束。

值得注意的是,作者在 TimeGPT 中实现了保形预测,允许模型根据历史误差估计预测区间。

TimeGPT 的功能

考虑到 TimeGPT 是构建时间序列基础模型的首次尝试,因此它具有广泛的功能。

首先,TimeGPT 是一个预先训练的模型,这意味着我们可以生成预测,而无需专门针对我们的数据进行训练。尽管如此,仍然可以根据我们的数据对模型进行微调。

其次,该模型支持外生变量来预测我们的目标,并且它可以处理多元预测任务。

最后,通过使用保形预测,TimeGPT 可以估计预测间隔。这反过来又允许模型执行异常检测。基本上,如果数据点超出 99% 置信区间,则模型会将其标记为异常。

请记住,所有这些任务都可以通过零样本推理或一些微调来实现,这是时间序列预测领域范式的根本转变。

现在我们对 TimeGPT、它的工作原理以及训练方式有了更深入的了解,让我们看看该模型的实际应用。

使用 TimeGPT 进行预测

现在让我们将 TimeGPT 应用于预测任务,并将其性能与其他模型进行比较。

请注意,在撰写本文时,TimeGPT 只能通过 API 访问,并且处于封闭测试阶段。我提交了请求并获得了两周免费访问该模型的许可。要获取令牌并访问模型,您必须访问他们的网站

如前所述,该模型是根据来自公开数据的 1000 亿个数据点进行训练的。由于作者没有指定实际使用的数据集,我认为在已知的基准数据集(例如ETTWeather)上测试模型是不合理的,因为模型可能在训练期间看到了这些数据。

因此,我为本文编译并开源了自己的数据集。

具体来说,我策划了从 2020 年 1 月 1 日到 2023 年 10 月 12 日期间博客上的每日浏览量。我还添加了两个外生变量:一个表示发布新文章的日期,另一个表示发布新文章的日期。我在美国度假,因为我的大多数观众都住在那里。

该数据集现已在GitHub上公开提供,最重要的是,我们确信 TimeGPT 并未使用该数据进行训练

导入库并读取数据

自然的第一步是导入此实验的库。

ba 复制代码
import pandas as pd
import numpy as np
import datetime
import matplotlib.pyplot as plt

from neuralforecast.core import NeuralForecast
from neuralforecast.models import NHITS, NBEATS, PatchTST

from neuralforecast.losses.numpy import mae, mse

from nixtlats import TimeGPT

%matplotlib inline

然后,为了访问 TimeGPT 模型,我们从文件中读取 API 密钥。请注意,我没有将 API 密钥分配给环境变量,因为访问仅限两周。

ba 复制代码
with open("data/timegpt_api_key.txt", 'r') as file:
        API_KEY = file.read()

然后,我们就可以读取数据了。

ba 复制代码
df = pd.read_csv('data/medium_views_published_holidays.csv')
df['ds'] = pd.to_datetime(df['ds'])

df.head()

我们数据集的前五行。

从上图可以看出,数据集的格式与我们使用 Nixtla 的其他开源库时的格式相同。

我们有一个unique_id列来标记不同的时间序列,但在我们的例子中,我们只有一个序列。

y 列代表我博客上的每日浏览量,published是一个简单的标志,用于标记发布新文章 (1) 或未发布文章 (0) 的日期。直观上我们知道,当新内容发布时,浏览量通常会在一段时间内增加。

最后, is_holiday列表示美国是否有假期。直觉是,在假期,访问我博客的人会更少。

现在,让我们可视化我们的数据并寻找可辨别的模式。

ba 复制代码
published_dates = df[df['published'] == 1]

fig, ax = plt.subplots(figsize=(12,8))

ax.plot(df['ds'], df['y'])
ax.scatter(published_dates['ds'], published_dates['y'], marker='o', color='red', label='New article')
ax.set_xlabel('Day')
ax.set_ylabel('Total views')
ax.legend(loc='best')

fig.autofmt_xdate()


plt.tight_layout()

我的博客的每日浏览量。

从上图中,我们已经可以看到一些有趣的行为。首先,请注意,红点表示一篇新发表的文章,它们几乎立即出现访问高峰。

我们还注意到 2021 年的活动减少,这反映在我博客的每日浏览量减少上。最后,在 2023 年,我们注意到文章发表后访问量出现了一些异常峰值。

放大数据,我们还发现了明显的每周季节性。

我的博客的每日浏览量。在这里,我们看到明显的每周季节性,周末参观的人较少。

从上图中,我们现在可以看到周末访问博客的访问者比工作日少。

考虑到所有这些,让我们看看如何使用 TimeGPT 来进行预测。

使用 TimeGPT 进行预测

首先,我们将数据集分为训练集和测试集。在这里,我将为测试集保留 168 个时间步,对应于 24 周的每日数据。

ba 复制代码
train = df[:-168]
test = df[-168:]

然后,我们的预测范围为 7 天,因为我有兴趣预测一整周的每日观看次数。

现在,该 API 不附带交叉验证的实现。因此,我们创建自己的循环来一次生成七个预测,直到我们对整个测试集进行预测。

ba 复制代码
future_exog = test[['unique_id', 'ds', 'published', 'is_holiday']]

timegpt = TimeGPT(token=API_KEY)

timegpt_preds = []

for i in range(0, 162, 7):

    timegpt_preds_df = timegpt.forecast(
        df=df.iloc[:1213+i],
        X_df = future_exog[i:i+7],
        h=7,
        finetune_steps=10,
        id_col='unique_id',
        time_col='ds',
        target_col='y'
    )
    
    preds = timegpt_preds_df['TimeGPT']
    
    timegpt_preds.extend(preds)

在上面的代码块中,请注意,我们必须传递外生变量的未来值。这很好,因为它们是静态变量。我们知道未来的假期日期,博客作者也知道他计划何时发表文章。

另请注意,我们使用finetune_steps参数根据数据微调 TimeGPT 。

循环完成后,我们可以将预测添加到测试集中。同样,TimeGPT 一次生成 7 个预测,直到获得 168 个预测,以便我们可以评估其预测下周每日观看次数的能力。

ba 复制代码
test['TimeGPT'] = timegpt_preds

test.head()

来自 TimeGPT 的预测。

使用 N-BEATS、N-HiTS 和 PatchTST 进行预测

现在,让我们应用其他方法来看看专门在我们的数据集上训练这些模型是否可以产生更好的预测。

对于这个实验,如前所述,我们使用 N-BEATS、N-HiTS 和 PatchTST。

ba 复制代码
horizon = 7

models = [NHITS(h=horizon,
               input_size=5*horizon,
               max_steps=50),
         NBEATS(h=horizon,
               input_size=5*horizon,
               max_steps=50),
         PatchTST(h=horizon,
                 input_size=5*horizon,
                 max_steps=50)]

然后,我们初始化NeuralForecast对象并指定数据的频率,在本例中为每天。

ba 复制代码
nf = NeuralForecast(models=models, freq='D')

然后,我们在 7 个时间步长的 24 个窗口上运行交叉验证,以获得与 TimeGPT 使用的测试集一致的预测。

ba 复制代码
preds_df = nf.cross_validation(
    df=df, 
    static_df=future_exog , 
    step_size=7, 
    n_windows=24
)

然后,我们可以简单地将 TimeGPT 的预测添加到这个新的preds_df DataFrame 中,以获得包含所有模型预测的单个 DataFrame。

ba 复制代码
preds_df['TimeGPT'] = test['TimeGPT']

包含所有模型预测的数据框。

伟大的!我们现在准备评估每个模型的性能。

评估

在测量性能指标之前,让我们可视化测试集上每个模型的预测。

可视化每个模型的预测。

首先,我们看到每个模型之间有很多重叠。然而,我们确实注意到 N-HiTS 预测了两个在现实生活中未实现的峰值。此外,PatchTST 似乎经常被低估。然而,TimeGPT 似乎通常与实际数据重叠得很好。

当然,评估每个模型性能的唯一方法是测量性能指标。在这里,我们使用平均绝对误差(MAE)和均方误差(MSE)。此外,我们将预测四舍五入为整数,因为小数对于博客的日常访问者来说没有意义。

ba 复制代码
preds_df = preds_df.round({
    'NHITS': 0,
    'NBEATS': 0,
    'PatchTST': 0,
    'TimeGPT': 0
})

data = {'N-HiTS': [mae(preds_df['NHITS'], preds_df['y']), mse(preds_df['NHITS'], preds_df['y'])],
       'N-BEATS': [mae(preds_df['NBEATS'], preds_df['y']), mse(preds_df['NBEATS'], preds_df['y'])],
       'PatchTST': [mae(preds_df['PatchTST'], preds_df['y']), mse(preds_df['PatchTST'], preds_df['y'])],
       'TimeGPT': [mae(preds_df['TimeGPT'], preds_df['y']), mse(preds_df['TimeGPT'], preds_df['y'])]}

metrics_df = pd.DataFrame(data=data)
metrics_df.index = ['mae', 'mse']

metrics_df.style.highlight_min(color='lightgreen', axis=1)

每个模型的性能指标。在这里,TimeGPT 是冠军模型,因为它实现了最低的 MAE 和 MSE。

从上图可以看出,TimeGPT 是冠军模型,因为它实现了最低的 MAE 和 MSE,其次是 N-BEATS、PatchTST 和 N-HiTS。

这是一个令人兴奋的结果,因为 TimeGPT 从未见过这个数据集,并且只进行了几个步骤的微调。虽然这不是一个详尽的实验,但我相信它确实展示了预测领域潜在的基础模型的一瞥。

我对TimeGPT的个人看法

虽然我对 TimeGPT 的简短实验被证明是令人兴奋的,但我必须指出,原始论文在许多重要领域仍然含糊不清。

同样,我们不知道使用哪些数据集来训练和测试模型,因此我们无法真正验证 TimeGPT 的性能结果,如下所示。

Azul Garza 和 Max Mergenthaler-Canseco 的原始论文中报告的 TimeGPT 性能结果

从上表中我们可以看到,TimeGPT 在每月和每周的频率上表现最好,N-HiTS 和 Temporal Fusion Transformer (TFT) 通常排名第二或第三。话又说回来,因为我们不知道使用了哪些数据,所以我们无法验证这些指标。

在如何训练模型以及如何调整模型来处理时间序列数据方面也缺乏透明度。

我相信该模型是用于商业用途,这解释了为什么该论文缺乏重现 TimeGPT 的细节。这并没有什么问题,但论文缺乏可重复性是科学界担心的问题。

尽管如此,我还是希望这能够激发时间序列基础模型的新工作和研究,并且我们最终会看到这些模型的开源版本,就像我们看到法学硕士发生的情况一样。

结论

TimeGPT 是时间序列预测的第一个基础模型。

它利用 Transformer 架构,并在 1000 亿个数据点上进行了预训练,以对新的未见数据进行零样本推理。

结合保形预测技术,该模型无需在特定数据集上进行训练即可生成预测区间并执行异常检测。

我仍然相信每个预测问题都需要独特的方法,因此请务必测试 TimeGPT 以及其他模型。

谢谢阅读!我希望您喜欢它并学到新东西!

相关推荐
belldeep2 小时前
python:reportlab 将多个图片合并成一个PDF文件
python·pdf·reportlab
FreakStudio5 小时前
全网最适合入门的面向对象编程教程:56 Python字符串与序列化-正则表达式和re模块应用
python·单片机·嵌入式·面向对象·电子diy
丶21365 小时前
【CUDA】【PyTorch】安装 PyTorch 与 CUDA 11.7 的详细步骤
人工智能·pytorch·python
_.Switch6 小时前
Python Web 应用中的 API 网关集成与优化
开发语言·前端·后端·python·架构·log4j
一个闪现必杀技6 小时前
Python入门--函数
开发语言·python·青少年编程·pycharm
小鹿( ﹡ˆoˆ﹡ )6 小时前
探索IP协议的神秘面纱:Python中的网络通信
python·tcp/ip·php
卷心菜小温7 小时前
【BUG】P-tuningv2微调ChatGLM2-6B时所踩的坑
python·深度学习·语言模型·nlp·bug
陈苏同学7 小时前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
唐家小妹7 小时前
介绍一款开源的 Modern GUI PySide6 / PyQt6的使用
python·pyqt
羊小猪~~8 小时前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm