机器学习第四十四周周报 SAMformer

文章目录

  • [week44 SAMformer](#week44 SAMformer)
  • 摘要
  • Abstract
    • [1. 题目](#1. 题目)
    • [2. Abstract](#2. Abstract)
    • [3. 网络架构](#3. 网络架构)
      • [3.1 问题提出](#3.1 问题提出)
      • [3.2 微型示例](#3.2 微型示例)
      • [3.3 SAMformer](#3.3 SAMformer)
    • [4. 文献解读](#4. 文献解读)
      • [4.1 Introduction](#4.1 Introduction)
      • [4.2 创新点](#4.2 创新点)
      • [4.3 实验过程](#4.3 实验过程)
    • [5. 结论](#5. 结论)
    • 6.代码复现

week44 SAMformer

摘要

本周阅读了题为SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention的论文。研究发现,Transformer在小规模线性预测问题中表达能力虽强,但难以收敛至理想水平,其注意力机制导致泛化能力低。为此,该文提出一种轻量级Transformer模型,结合锐度感知优化,成功避免不良局部最小值。实验证明,该模型在多元时间序列数据集上表现优越,超越当前先进方法,且参数显著减少。

Abstract

This week's weekly newspaper decodes the paper entitled SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention. Research has found that while the Transformer model possesses strong expressive power in small-scale linear prediction problems, it fails to converge to the ideal level due to its attention mechanism, which leads to low generalization ability. Therefore, this article proposes a lightweight Transformer model that, combined with sharpness-aware optimization, successfully avoids undesirable local minima. Experiments have proven that this model performs superbly on multivariate time series datasets, surpassing current advanced methods while significantly reducing parameters.

1. 题目

标题:SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention

作者:Romain Ilbert, Ambroise Odonnat, Vasilii Feofanov, Aladin Virmaux, Giuseppe Paolo, Themis Palpanas, Ievgen Redko

发布:Accepted as an Oral at ICML 2024, Vienna

链接:https://arxiv.org/abs/2402.10198

2. Abstract

首先研究一个小规模线性预测问题,结果表明Transformer尽管具有很高的表达能力,但无法收敛到期望中水平。进一步确定Transformer的注意力是造成这种低泛化能力的原因。基于这一见解,提出了一种浅层轻量级Transformer模型,当通过锐度感知优化进行优化时,该模型成功地避免了不良的局部最小值。凭经验证明,这一结果可以扩展到所有常用的现实世界多元时间序列数据集。特别是,SAMformer 超越了当前最先进的方法,与最大的基础模型 MOIRAI 相当,但参数却少得多。

First, this article investigated a small-scale linear prediction problem and found that although the Transformer model possesses high expressive power, it fails to converge to the desired level. Further analysis revealed that the attention mechanism in the Transformer is the cause of this low generalization ability. Based on this insight, we propose a shallow and lightweight Transformer model that successfully avoids undesirable local minima when optimized using sharpness-aware minimization. Empirical evidence demonstrates that this result can be extended to all commonly used real-world multivariate time series datasets. Specifically, SAMformer outperforms current state-of-the-art methods, comparable to the largest baseline model MOIRAI, but with significantly fewer parameters.

3. 网络架构

3.1 问题提出

考虑多元长期预测框架:给定长度为 L(回溯窗口)的 D 维时间序列,排列在矩阵 X ∈ R D × L X ∈ R^{D×L} X∈RD×L 中以促进通道关注,目标是预测其下一个 H 值(预测范围),用 Y ∈ R D × H Y ∈ R^{D×H} Y∈RD×H 表示。假设我们可以访问由N个观测值 ( X , Y ) = ( { X ( i ) } i = 0 N , { Y ( i ) } i = 0 N ) (X, Y) = (\{X^{(i)}\}^N_{i=0}, \{Y^{(i)}\}^N_{i=0}) (X,Y)=({X(i)}i=0N,{Y(i)}i=0N) 组成的训练集,并表示为 X d ( i ) ∈ R 1 × L X^{(i)}d ∈ R^{1×L} Xd(i)∈R1×L(分别为 Y d ( i ) ∈ R 1 × H Y^{(i)}d ∈ R^{1×H} Yd(i)∈R1×H)第 i 个输入(分别为目标)时间序列的第 d 个特征。目标是训练一个由 ω 参数化的预测器 f ω : R D × L → R D × H f_ω:R^{D×L} \rightarrow R^{D×H} fω:RD×L→RD×H ,以最小化训练集上的均方误差 (MSE):
L t r a i n ( ω ) = 1 N D ∑ i = 0 N ∣ ∣ Y ( i ) − f ω ( X ( i ) ) ∣ ∣ F 2 (1) L
{train}(ω)=\frac1{ND}\sum^N
{i=0}||Y^{(i)}-f_ω(X^{(i)})||^2_F \tag{1} Ltrain(ω)=ND1i=0∑N∣∣Y(i)−fω(X(i))∣∣F2(1)

3.2 微型示例

Transformer 的性能与经过训练直接将输入投影到输出的简单线性神经网络相当或更差。考虑以下小型回归问题的生成模型,模仿稍后考虑的时间序列预测设置
Y = X W t o y + ϵ (2) Y=XW_{toy}+\epsilon\tag{2} Y=XWtoy+ϵ(2)

令L = 512,H = 96,D = 7 且 W t o y ε R L × H W^{toy} ε R^{L×H} WtoyεRL×H, $\epsilon \in R^{D×H} $具有随机正态条目,并生成 15000 个输入目标对 (X,Y)(10000 个用于训练,5000 个用于验证)。考虑到这个生成模型,希望开发一种Transformer 架构,可以有效地解决方程(1)中的问题。 (2)没有不必要的复杂性。为了实现这一目标,建议通过将注意力应用于 X 并结合将 X 添加到注意力输出的残差连接来简化常用的 Transformer 编码器。没有在此残差连接之上添加前馈块,而是直接采用线性层进行输出预测。正式地,模型定义如下:
f ( X ) = [ X + A ( X ) X W V W O ] W (3) f(X)=[X+A(X)XW_VW_O]W \tag{3} f(X)=[X+A(X)XWVWO]W(3)

A(X) 是输入序列 X ∈ R D × L X \in R^{D×L} X∈RD×L 的注意力矩阵,定义为
A ( X ) = softmax ( X W Q W W K T X T d m ) ∈ R D × D (4) A(X)=\text{softmax}(\frac{XW_QWW^T_KX^T}{\sqrt d_m})\in R^{D\times D} \tag{4} A(X)=softmax(d mXWQWWKTXT)∈RD×D(4)

首先,注意力渠道化,这简化了问题,降低了过度参数化的风险,因为矩阵W与Eq.(2)中的形状相同,并且由于L > d,注意矩阵变得更小。

根据方程生成的数据拟合Transformer 的优化问题。 式2理论上允许无限多个最优分类器W。

如上图,尽管 Transformer 很简单,但它却存在严重的过度拟合问题。随机Transformer中的注意力权重可以提高泛化能力,暗示注意力在防止收敛到最佳局部最小值方面的作用。随机Transformer仅优化W,自注意力权重固定。Transformer泛化能力差的主要原因是注意力模块的可训练性问题。

3.3 SAMformer

为了实现更好的泛化性能和训练稳定性,采用了锐度感知最小框架。将式1迭代为
L t r a i n S A M ( ω ) = max ⁡ ∣ ∣ ϵ ∣ ∣ < ρ L t r a i n ( ω + ϵ ) L^{SAM}{train}(ω)=\max{||\epsilon||<\rho}L_{train}(ω+\epsilon) LtrainSAM(ω)=∣∣ϵ∣∣<ρmaxLtrain(ω+ϵ)

提议的 SAMformer 基于式(3),有两个重要的修改。

首先,为其配备了应用于 X 的可逆实例归一化(RevIN),因为该技术被证明可以有效处理时间序列中训练数据和测试数据之间的转换。其次,使用 SAM 优化模型,使其收敛到更平坦的局部最小值。总的来说,这给出了图 4 中带有一个编码器的浅层变压器模型。

4. 文献解读

4.1 Introduction

当前方法的局限性:最近将 Transformer 应用于时间序列数据的工作主要集中在:

  1. 降低注意力的二次成本的有效实现
  2. 分解时间序列以更好地捕捉其中的潜在模式

上述研究没有很好的解决现有困境

Transformer 的可训练性

在时间序列预测的情况下,存在如何有效地训练 Transformer 架构而不出现过度拟合的问题。该研究目标是证明,通过消除训练的不稳定性,变压器可以在多元长期预测方面表现出色,这与之前对其局限性的看法相反。

4.2 创新点

该问题提出了SAMformer,主要贡献大致如下:

  1. 研究表明,即使Transformer 架构是为了解决简单的微型线性预测问题而定制的,它的泛化能力仍然很差并且收敛到尖锐的局部最小值。进一步确定注意力是造成这种现象的主要原因;
  2. 提出了一种浅层Transformer 模型,称为 SAMformer,它结合了研究界提出的最佳实践,包括可逆实例归一化和通道注意力最近在计算机视觉社区中引入。结果证明,通过锐度感知最小化(SAM)优化这样一个简单的Transformer 可以收敛到局部最小值,并具有更好的泛化能力;
  3. 凭经验证明了该方法在常见的多元长期预测数据集上的优越性。 SAMformer 超越了当前最先进的方法,与最大的基础模型 MOIRAI 相当,但参数却少得多。

4.3 实验过程

在该部分,提供了实证证明 SAMformer 在通用基准的多元长期时间序列预测中的定量和定性优势。具体来说,证明 SAMformer 比当前最先进的多元 TSMixer高出 14.33%,同时参数减少了约 4 倍。

数据集:在现实世界多元时间序列的 8 个公开数据集上进行了实验,四个电力变压器温度数据集 ETTh1、ETTh2、ETTm1 和 ETTm2、Electricity (UCI, 2015)、Exchange (Lai et al., 2018b)、Traffic (California Department of Transportation, 2021)以及Weather (Max Planck Institute, 2021)。

所有时间序列均以输入长度 L = 512、预测范围 H ∈ {96, 192, 336, 720} 和步幅为 1 进行分段,这意味着每个后续窗口都会移动一步。

基线模型:Transformer 和 TSMixer(Chen 等人,2023)。其中TSMixer 是完全基于 MLP 构建的最先进的多元基线。iTransformer、PatchTST、FEDformer、Informer (Zhou et al., 2022)、Informer (Zhou et al., 2021)和Autoformer。所有报告的结果都是使用 RevIN(Kim 等人,2021b)获得的,以便在 SAMformer 及其竞争对手之间进行更公平的比较。

评估方法:所有模型都经过训练,以式1最大限度地减少方程式中定义的 MSE 损失。 报告测试集上的平均 MSE,以及使用不同种子运行 5 次的标准差。其他详细信息和结果,包括平均绝对误差 (MAE)。除非另有说明,所有的结果都是使用不同种子进行 5 次运行而获得的。

实验结果

在SAMformer的训练中引入SAM,使其损耗比Transformer更平滑。我们在上图a中通过比较

在ETTh1和Exchange上训练后Transformer和SAMformer的值来说明这一点。我们的观察表明,Transformer表现出相当高的清晰度,而SAMformer有一个理想的行为,损失景观清晰度是一个数量级小。

SAMformer演示了针对随机初始化的反业务。图5b - 1给出了SAMformer和Transformer在ETTh1

和Exchange上5种不同种子的试验MSE分布,预测水平为H = 96。SAMformer在不同的种子选择中始终保持性能稳定性,而Transformer表现出显著的差异,因此高度依赖于权重初始化。

在表 1 中,报告了使用不同种子进行多次运行的性能,从而获得更可靠的评估。为了公平比较,还包括了经过训练的 TSMixer 的性能

SAMformer在8个数据集中的7个上明显优于其竞争对手。特别是,它比其最

佳竞争对手TSMixer+SAM提高了5.25%,比独立TSMixer提高了14.33%,比基于变压器的最佳模型FEDformer提高了12.36%。此外,它比Transformer提高了16.96%。对于每个数据集和视界,SAMformer被排名第一或第二。值得注意的是,SAM的集成提高了TSMixer的泛化能力,平均提高了9.58%。

5. 结论

SAMformer 通过锐度感知最小化进行了优化,与现有的预测基线(包括当前最大的基础模型 MOIRAI)相比,带来了显着的性能增益,并受益于跨数据集和预测范围的高通用性和鲁棒性。最后,我们还表明,时间序列预测中的通道注意力在计算和性能方面比以前常用的时间注意力更加有效。我们相信,这一令人惊讶的发现可能会刺激在我们简单的架构之上进行许多进一步的工作,以进一步改进它。

6.代码复现

该代码可从 https://github.com/romilbert/samformer 获取。

attention实现

这段代码定义了一个使用PyTorch库的函数,实现了缩放点积注意力机制,用于模型中的注意力计算。它接受查询(query)、键(key)和值(value)张量,并可选地接受注意力掩码和丢失概率参数。函数首先计算查询和键的点积,并根据查询的维度大小进行缩放。如果指定了因果关系或提供了注意力掩码,它会修改注意力权重以避免未来信息的泄露或应用额外的掩码。最后,它使用Softmax函数规范化注意力权重,并通过点积操作与值张量结合,输出最终的注意力加权结果。

python 复制代码
import torch
 
import numpy as np
 
 
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    """
    A copy-paste from https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / np.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
 
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

dataset处理

这段代码定义了一个Python类 LabeledDataset,它是一个继承自 torch.utils.data.Dataset 的自定义数据集类,用于处理带有标签的数据。该类的主要作用是将 NumPy 数组格式的数据和标签转换为 PyTorch 张量格式,并提供了一些基本的数据处理方法,使其可以用于 PyTorch 的数据加载和预处理流程中。

python 复制代码
import torch
 
from torch.utils.data import Dataset
 
 
class LabeledDataset(Dataset):
    def __init__(self, x, y):
        """
        Converts numpy data to a torch dataset
        Args:
            x (np.array): data matrix
            y (np.array): class labels
        """
        self.x = torch.FloatTensor(x)
        self.y = torch.FloatTensor(y)
 
    def transform(self, x):
        return torch.FloatTensor(x)
 
    def __len__(self):
        return self.y.shape[0]
 
    def __getitem__(self, idx):
        examples = self.x[idx]
        labels = self.y[idx]
        return examples, labels

RevIN

这段代码定义了一个名为 RevIN(Reversible Instance Normalization)的 Python 类,它继承自 PyTorch 的 nn.Module。这个类实现了可逆的实例归一化,主要用于神经网络中,可以在正向传播时进行标准化,并在需要时进行反向去标准化。

python 复制代码
import torch
import torch.nn as nn
 
 
class RevIN(nn.Module):
    """
    Reversible Instance Normalization (RevIN) https://openreview.net/pdf?id=cGDAkQo1C0p
    https://github.com/ts-kim/RevIN
    """
    def __init__(self, num_features: int, eps=1e-5, affine=True):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()
 
    def forward(self, x, mode:str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else: raise NotImplementedError
        return x
 
    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
 
    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim-1))
        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
 
    def _normalize(self, x):
        x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x
 
    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps*self.eps)
        x = x * self.stdev
        x = x + self.mean
        return x

SAM

这段代码定义了一个名为 SAM 的 Python 类,它是一个用于优化神经网络训练的自定义优化器,继承自 PyTorch 的 Optimizer。SAM 代表 Sharpness-Aware Minimization,这是一种用于改进模型泛化能力的优化技术,通过最小化损失函数的锐度来实现。

python 复制代码
import torch

from torch.optim import Optimizer


class SAM(Optimizer):
    """
    SAM: Sharpness-Aware Minimization for Efficiently Improving Generalization https://arxiv.org/abs/2010.01412
    https://github.com/davda54/sam
    """
def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
    assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
 
    defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
    super(SAM, self).__init__(params, defaults)
 
    self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
    self.param_groups = self.base_optimizer.param_groups
 
@torch.no_grad()
def first_step(self, zero_grad=False):
    grad_norm = self._grad_norm()
    for group in self.param_groups:
        scale = group["rho"] / (grad_norm + 1e-12)
 
        for p in group["params"]:
            if p.grad is None:
                continue
            e_w = (
                (torch.pow(p, 2) if group["adaptive"] else 1.0)
                * p.grad
                * scale.to(p)
            )
            p.add_(e_w)  # climb to the local maximum "w + e(w)"
            self.state[p]["e_w"] = e_w
 
    if zero_grad:
        self.zero_grad()
 
@torch.no_grad()
def second_step(self, zero_grad=False):
    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"
 
    self.base_optimizer.step()  # do the actual "sharpness-aware" update
 
    if zero_grad:
        self.zero_grad()
 
@torch.no_grad()
def step(self, closure=None):
    assert (
        closure is not None
    ), "Sharpness Aware Minimization requires closure, but it was not provided"
    closure = torch.enable_grad()(
        closure
    )  # the closure should do a full forward-backward pass
 
    self.first_step(zero_grad=True)
    closure()
    self.second_step()
 
def _grad_norm(self):
    shared_device = self.param_groups[0]["params"][
        0
    ].device  # put everything on the same device, in case of model parallelism
    norm = torch.norm(
        torch.stack(
            [
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad)
                .norm(p=2)
                .to(shared_device)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]
        ),
        p=2,
    )
    return norm

samformer

这段代码定义了两个主要的 Python 类,SAMFormerArchitecture 和 SAMFormer,它们基于 PyTorch 框架用于构建和训练一个深度学习模型,具体是用于时间序列预测任务。这些类利用了一些先进的技术如可逆实例归一化(RevIN)、注意力机制和锐度感知最小化(SAM)来改善模型的预测性能和泛化能力。

python 复制代码
import torch
import random
import numpy as np

from tqdm import tqdm
from torch import nn
from torch.utils.data import DataLoader

from .utils.attention import scaled_dot_product_attention
from .utils.dataset import LabeledDataset
from .utils.revin import RevIN
from .utils.sam import SAM


class SAMFormerArchitecture(nn.Module):
    def __init__(self, num_channels, seq_len, hid_dim, pred_horizon, use_revin=True):
        super().__init__()
        self.revin = RevIN(num_features=num_channels)
        self.compute_keys = nn.Linear(seq_len, hid_dim)
        self.compute_queries = nn.Linear(seq_len, hid_dim)
        self.compute_values = nn.Linear(seq_len, seq_len)
        self.linear_forecaster = nn.Linear(seq_len, pred_horizon)
        self.use_revin = use_revin
        
def forward(self, x):
    # RevIN Normalization
    if self.use_revin:
        x_norm = self.revin(x.transpose(1, 2), mode='norm').transpose(1, 2) # (n, D, L)
    else:
        x_norm = x
    # Channel-Wise Attention
    queries = self.compute_queries(x_norm) # (n, D, hid_dim)
    keys = self.compute_keys(x_norm) # (n, D, hid_dim)
    values = self.compute_values(x_norm) # (n, D, L)
    if hasattr(nn.functional, 'scaled_dot_product_attention'):
        att_score = nn.functional.scaled_dot_product_attention(queries, keys, values) # (n, D, L)
    else:
        att_score = scaled_dot_product_attention(queries, keys, values) # (n, D, L)
    out = x_norm + att_score # (n, D, L)
    # Linear Forecasting
    out = self.linear_forecaster(out) # (n, D, H)
    # RevIN Denormalization
    if self.use_revin:
        out = self.revin(out.transpose(1, 2), mode='denorm').transpose(1, 2) # (n, D, H)
    return out.reshape([out.shape[0], out.shape[1]*out.shape[2]])
class SAMFormer:
    """
    SAMFormer pytorch trainer implemented in the sklearn fashion
    """
    def __init__(self, device='cuda:0', num_epochs=100, batch_size=256, base_optimizer=torch.optim.Adam,
                 learning_rate=1e-3, weight_decay=1e-5, rho=0.5, use_revin=True, random_state=None):
        self.network = None
        self.criterion = nn.MSELoss()
        self.device = device
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.base_optimizer = base_optimizer
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.rho = rho
        self.use_revin = use_revin
        self.random_state = random_state
        
def fit(self, x, y):
    if self.random_state is not None:
        torch.manual_seed(self.random_state)
        random.seed(self.random_state)
        np.random.seed(self.random_state)
        torch.cuda.manual_seed_all(self.random_state)
 
    self.network = SAMFormerArchitecture(num_channels=x.shape[1], seq_len=x.shape[2], hid_dim=16,
                                         pred_horizon=y.shape[1] // x.shape[1], use_revin=self.use_revin)
    self.criterion = self.criterion.to(self.device)
    self.network = self.network.to(self.device)
    self.network.train()
 
    optimizer = SAM(self.network.parameters(), base_optimizer=self.base_optimizer, rho=self.rho,
                    lr=self.learning_rate, weight_decay=self.weight_decay)
 
    train_dataset = LabeledDataset(x, y)
    data_loader_train = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
 
    progress_bar = tqdm(range(self.num_epochs))
    for epoch in progress_bar:
        loss_list = []
        for (x_batch, y_batch) in data_loader_train:
            x_batch = x_batch.to(self.device)
            y_batch = y_batch.to(self.device)
            # =============== forward ===============
            out_batch = self.network(x_batch)
            loss = self.criterion(out_batch, y_batch)
            # =============== backward ===============
            if optimizer.__class__.__name__ == 'SAM':
                loss.backward()
                optimizer.first_step(zero_grad=True)
 
                out_batch = self.network(x_batch)
                loss = self.criterion(out_batch, y_batch)
 
                loss.backward()
                optimizer.second_step(zero_grad=True)
            else:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss_list.append(loss.item())
        # =============== save model / update log ===============
        train_loss = np.mean(loss_list)
        self.network.train()
        progress_bar.set_description("Epoch {:d}: Train Loss {:.4f}".format(epoch, train_loss), refresh=True)
    return
 
def forecast(self, x, batch_size=256):
    self.network.eval()
    dataset = torch.utils.data.TensorDataset(torch.tensor(x, dtype=torch.float))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    outs = []
    for _, batch in enumerate(dataloader):
        x = batch[0].to(self.device)
        with torch.no_grad():
            out = self.network(x)
        outs.append(out.cpu())
    outs = torch.cat(outs)
    return outs.cpu().numpy()
 
def predict(self, x, batch_size=256):
    return self.forecast(x, batch_size=batch_size)

小结

本文讨论了在时间序列预测中运用转换器模型的挑战与创新。传统的基于Transformer的模型虽然在自然语言处理和计算机视觉领域表现出色,但在多变量长期预测任务上,它们的性能却不及简单的线性模型。研究中指出,这些模型在基本的线性预测场景中难以实现最佳解决方案,主要问题在于其注意力机制的泛化能力较差。为应对这一问题,研究提出了一种新型的浅层、轻量级Transformer模型,即SAMformer。该模型采用锐度感知优化(SAM),有效克服了不良的局部最小值,显著提高了模型在多变量时间序列数据集上的性能,性能提升幅度达14.33%,且模型的参数数量大约减少了四倍。此外,SAMformer展示了优越的泛化能力和鲁棒性,其在多个数据集上的表现均优于当前先进的多变量模型TSMixer。研究结果表明,采用channel-wise注意力机制的SAMformer在计算和性能方面都比传统的temporal attention更为有效,为时间序列预测领域提供了新的视角和方法。

参考文献

[1] Romain Ilbert, Ambroise Odonnat, Vasilii Feofanov, Aladin Virmaux, Giuseppe Paolo, Themis Palpanas, Ievgen Redko "SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention" [C], ICML 2024

相关推荐
编程迪6 分钟前
自研PHP版本AI口播数字人系统源码适配支持公众号H5小程序
人工智能·数字人系统源码·口播数字人·数字人小程序·数字人开源
Anna_Tong10 分钟前
人工智能的视觉天赋:一文读懂卷积神经网络
人工智能·神经网络·cnn
ZHOU_WUYI31 分钟前
adb 安装教程
人工智能·adb
weixin_443042651 小时前
信息系统管理师试题-转型升级
人工智能·信息系统项目管理师
CV-King2 小时前
旋转框目标检测自定义数据集训练测试流程
人工智能·目标检测·计算机视觉
无问社区2 小时前
无问社区-无问AI模型
人工智能·web安全·网络安全
井底哇哇2 小时前
Apline linux 安装scikit-learn 过程记录
python·机器学习·scikit-learn
Jacen.L2 小时前
探究音频丢字位置和丢字时间对pesq分数的影响
人工智能·语音识别
DashVector2 小时前
如何通过HTTP API插入或更新Doc
大数据·数据库·数据仓库·人工智能·http·数据库架构·向量检索
海棠AI实验室2 小时前
机器学习基础算法 (二)-逻辑回归
人工智能·python·机器学习