TSTabFusionTransformer 深度学习学习笔记

TSTabFusionTransformer 深度学习学习笔记

目标读者:具备基础深度学习知识,希望理解如何融合时序数据与表格数据进行预测的学习者


1. 问题动机与直觉理解

1.1 现实世界的数据是多模态的

想象你是一位金融分析师,需要预测某只股票未来的走势。你手上有什么数据?

  • 时序数据:过去 500 天的股价、成交量、技术指标(MACD、RSI等)
  • 表格数据:公司的财务指标(市盈率、负债率)、行业分类、地理位置

或者你是医疗研究者,需要预测病人的健康风险:

  • 时序数据:过去一周的心率、血压、体温监测曲线
  • 表格数据:年龄、性别、既往病史、基因标记

关键观察 :真实世界的预测任务往往需要同时利用动态的时序信息静态的表格属性

1.2 传统方法的困境

方案1:只用时序模型(如纯 Transformer)
python 复制代码
# 只用时序数据
model = TimeSeriesTransformer(c_in=6, seq_len=500)
pred = model(x_timeseries)  # 忽略了表格特征

问题:丢失了宝贵的静态信息!比如股票所属行业、病人的年龄性别等关键信息。

方案2:简单拼接后送入模型
python 复制代码
# 将表格特征复制到每个时间步
x_combined = torch.cat([x_timeseries, x_tabular.unsqueeze(-1).repeat(1,1,seq_len)], dim=1)

问题

  • 破坏了时序数据的结构性
  • 表格特征被强行"时序化",不自然
  • 参数冗余(表格特征在每个时间步重复)
方案3:两个模型分别处理后融合
python 复制代码
# 分别处理
ts_features = ts_model(x_timeseries)
tab_features = tab_model(x_tabular)

# 后期融合
combined = torch.cat([ts_features, tab_features], dim=1)
pred = final_mlp(combined)

问题

  • 两个模型独立训练,无法"互相学习"
  • 融合发生在最后阶段,交互不充分
  • 时序特征无法根据表格信息调整注意力

1.3 TSTabFusionTransformer 的核心思想

"让时序 Transformer 在处理时序数据时,能够'看到'表格信息,并据此调整注意力模式"

直觉上:

  • 时序分支:用 Transformer 提取时序模式(趋势、周期、异常)
  • 表格分支:用 MLP 提取表格特征的嵌入
  • 深度融合:在 Transformer 的每一层,都将表格信息融入到时序特征中

类比:就像你在看股价走势图时,脑海中同时记着"这是一家科技公司"、"市盈率偏高"等信息,这些静态知识会影响你对时序模式的解读。


2. 核心思想与概念拆解

2.1 整体架构概览

复制代码
输入:
  ┌─────────────────┐        ┌──────────────────┐
  │ 时序数据 x_ts    │        │ 表格数据 x_tab    │
  │ (bs, c_in, L)   │        │ 分类 + 连续特征   │
  └────────┬────────┘        └────────┬─────────┘
           │                          │
           ▼                          ▼
  ┌─────────────────┐        ┌──────────────────┐
  │ Patching +      │        │ Embedding        │
  │ Position Encode │        │ (类别嵌入+线性层) │
  └────────┬────────┘        └────────┬─────────┘
           │                          │
           │         融合点 ↓         │
           │    ┌────────────────┐   │
           └────→ Transformer    ←───┘
                │ Encoder Layers │
                │ (多次融合)     │
                └────────┬───────┘
                         ▼
                ┌──────────────────┐
                │ Global Pooling   │
                └────────┬─────────┘
                         ▼
                ┌──────────────────┐
                │ MLP Head         │
                └────────┬─────────┘
                         ▼
                   输出 (bs, c_out)

2.2 关键组件拆解

组件1:时序数据的 Patching

为什么需要 Patching?

原始时序数据可能很长(如 seq_len=500),直接送入 Transformer 会导致:

  • 注意力矩阵 O(L²) 的复杂度爆炸
  • 过度关注局部细节,难以捕捉全局模式

Patching 的思想(借鉴 Vision Transformer):

将长序列切分成小块(patch),每个 patch 看作一个"token"。

复制代码
原始序列: [x1, x2, x3, ..., x500]  (500 个时间步)
           ↓ patching (patch_len=16, stride=16)
Patches:   [patch1, patch2, ..., patch31]  (31 个 patch)

优势

  • 序列长度从 500 降到 31,注意力复杂度大幅降低
  • 每个 patch 代表一个局部时间窗口,Transformer 关注的是"窗口级别"的模式
组件2:表格数据的处理

表格数据包含两类特征:

  1. 分类特征(categorical):如"行业类别"、"教育程度"

    • 处理方式:Embedding 层
    python 复制代码
    # 例如 "教育程度" 有 17 种取值
    education_emb = nn.Embedding(num_classes=17, embedding_dim=d_model)
  2. 连续特征(continuous):如"年龄"、"收入"

    • 处理方式:Linear 层 + 激活函数
    python 复制代码
    cont_emb = nn.Linear(n_cont, d_model)
  3. 融合:将所有表格特征嵌入拼接后,再过一层 MLP

    python 复制代码
    # (bs, n_cat*d_model + n_cont*d_model) -> (bs, d_model)
    tab_features = MLP([cat_emb1, cat_emb2, ..., cont_emb])
组件3:特征融合机制

核心问题:如何在 Transformer 中融入表格信息?

方案:在每个 Transformer Block 中,将表格特征加到时序特征上

python 复制代码
# 伪代码
for layer in transformer_layers:
    # 1. 时序自注意力
    ts_features = MultiHeadAttention(ts_features)  # (bs, num_patches, d_model)
    
    # 2. 融入表格信息
    # 将表格特征广播到每个 patch
    tab_broadcast = tab_features.unsqueeze(1).expand(-1, num_patches, -1)
    ts_features = ts_features + tab_broadcast  # 残差连接
    
    # 3. FFN
    ts_features = FeedForward(ts_features)

直觉理解

  • 表格特征像"全局上下文",告诉 Transformer "这是什么类型的样本"
  • 时序特征在每一层都能"感知"到这个全局上下文
  • 类似于你在看图表时,脑中始终记着"这是科技股"这个背景知识

3. 数学建模与理论基础

3.1 问题形式化

给定:

  • 时序数据: X t s ∈ R B × C i n × L \mathbf{X}{ts} \in \mathbb{R}^{B \times C{in} \times L} Xts∈RB×Cin×L

    • B B B: batch size
    • C i n C_{in} Cin: 通道数(如 6 个技术指标)
    • L L L: 序列长度(如 500 个时间步)
  • 表格数据:

    • 分类特征: X c a t ∈ N B × N c a t \mathbf{X}{cat} \in \mathbb{N}^{B \times N{cat}} Xcat∈NB×Ncat(每个值是类别索引)
    • 连续特征: X c o n t ∈ R B × N c o n t \mathbf{X}{cont} \in \mathbb{R}^{B \times N{cont}} Xcont∈RB×Ncont

目标:

  • 预测输出: Y ∈ R B × C o u t \mathbf{Y} \in \mathbb{R}^{B \times C_{out}} Y∈RB×Cout(如分类任务 C o u t = C_{out}= Cout= 类别数,回归任务 C o u t = 1 C_{out}=1 Cout=1)

3.2 时序分支的数学表示

Step 1: Patching

将时序数据切分成 patches:

X t s ∈ R B × C i n × L → patch X p a t c h e s ∈ R B × N p a t c h e s × ( C i n × P ) \mathbf{X}{ts} \in \mathbb{R}^{B \times C{in} \times L} \xrightarrow{\text{patch}} \mathbf{X}{patches} \in \mathbb{R}^{B \times N{patches} \times (C_{in} \times P)} Xts∈RB×Cin×Lpatch Xpatches∈RB×Npatches×(Cin×P)

其中:

  • N p a t c h e s = ⌊ L − P S ⌋ + 1 N_{patches} = \lfloor \frac{L - P}{S} \rfloor + 1 Npatches=⌊SL−P⌋+1
  • P P P: patch length(如 16)
  • S S S: stride(如 16)
Step 2: Patch Embedding

Z ( 0 ) = Linear ( X p a t c h e s ) + E p o s \mathbf{Z}^{(0)} = \text{Linear}(\mathbf{X}{patches}) + \mathbf{E}{pos} Z(0)=Linear(Xpatches)+Epos

  • Linear : R C i n × P → R d m o d e l \text{Linear}: \mathbb{R}^{C_{in} \times P} \to \mathbb{R}^{d_{model}} Linear:RCin×P→Rdmodel(投影到嵌入空间)
  • E p o s \mathbf{E}_{pos} Epos: 位置编码(Learnable 或 Sinusoidal)
Step 3: Transformer Encoder

对于第 l l l 层:

Z ~ ( l ) = LayerNorm ( Z ( l − 1 ) ) Z ′ ( l ) = MultiHeadAttention ( Z ~ ( l ) ) + Z ( l − 1 ) (残差) Z ( l ) = FFN ( LayerNorm ( Z ′ ( l ) ) ) + Z ′ ( l ) \begin{aligned} \tilde{\mathbf{Z}}^{(l)} &= \text{LayerNorm}(\mathbf{Z}^{(l-1)}) \\ \mathbf{Z'}^{(l)} &= \text{MultiHeadAttention}(\tilde{\mathbf{Z}}^{(l)}) + \mathbf{Z}^{(l-1)} \quad \text{(残差)} \\ \mathbf{Z}^{(l)} &= \text{FFN}(\text{LayerNorm}(\mathbf{Z'}^{(l)})) + \mathbf{Z'}^{(l)} \end{aligned} Z~(l)Z′(l)Z(l)=LayerNorm(Z(l−1))=MultiHeadAttention(Z~(l))+Z(l−1)(残差)=FFN(LayerNorm(Z′(l)))+Z′(l)

3.3 表格分支的数学表示

分类特征嵌入

对于第 i i i 个分类特征:

E c a t ( i ) = Embedding ( X c a t ( i ) ) ∈ R B × d m o d e l \mathbf{E}{cat}^{(i)} = \text{Embedding}(\mathbf{X}{cat}^{(i)}) \in \mathbb{R}^{B \times d_{model}} Ecat(i)=Embedding(Xcat(i))∈RB×dmodel

连续特征嵌入

E c o n t = ReLU ( Linear ( X c o n t ) ) ∈ R B × d m o d e l \mathbf{E}{cont} = \text{ReLU}(\text{Linear}(\mathbf{X}{cont})) \in \mathbb{R}^{B \times d_{model}} Econt=ReLU(Linear(Xcont))∈RB×dmodel

表格特征融合

E t a b = MLP ( [ E c a t ( 1 ) , E c a t ( 2 ) , . . . , E c o n t ] ) ∈ R B × d m o d e l \mathbf{E}{tab} = \text{MLP}([\mathbf{E}{cat}^{(1)}, \mathbf{E}{cat}^{(2)}, ..., \mathbf{E}{cont}]) \in \mathbb{R}^{B \times d_{model}} Etab=MLP([Ecat(1),Ecat(2),...,Econt])∈RB×dmodel

3.4 多模态融合机制

关键修改:在 Transformer 的每一层中加入表格信息

修改后的第 l l l 层:

Z ~ ( l ) = LayerNorm ( Z ( l − 1 ) + E t a b ⊙ 1 N p a t c h e s ) ← 融合点 Z ′ ( l ) = MultiHeadAttention ( Z ~ ( l ) ) + Z ( l − 1 ) Z ( l ) = FFN ( LayerNorm ( Z ′ ( l ) ) ) + Z ′ ( l ) \begin{aligned} \tilde{\mathbf{Z}}^{(l)} &= \text{LayerNorm}(\mathbf{Z}^{(l-1)} + \mathbf{E}{tab} \odot \mathbf{1}{N_{patches}}) \quad \text{← 融合点} \\ \mathbf{Z'}^{(l)} &= \text{MultiHeadAttention}(\tilde{\mathbf{Z}}^{(l)}) + \mathbf{Z}^{(l-1)} \\ \mathbf{Z}^{(l)} &= \text{FFN}(\text{LayerNorm}(\mathbf{Z'}^{(l)})) + \mathbf{Z'}^{(l)} \end{aligned} Z~(l)Z′(l)Z(l)=LayerNorm(Z(l−1)+Etab⊙1Npatches)← 融合点=MultiHeadAttention(Z~(l))+Z(l−1)=FFN(LayerNorm(Z′(l)))+Z′(l)

其中:

  • E t a b ⊙ 1 N p a t c h e s \mathbf{E}{tab} \odot \mathbf{1}{N_{patches}} Etab⊙1Npatches: 将表格嵌入广播到每个 patch
  • 在每一层都进行融合,实现深度交互

3.5 输出层

Z g l o b a l = GlobalPooling ( Z ( L ) ) ∈ R B × d m o d e l Y ^ = MLPHead ( Z g l o b a l ) ∈ R B × C o u t \begin{aligned} \mathbf{Z}{global} &= \text{GlobalPooling}(\mathbf{Z}^{(L)}) \in \mathbb{R}^{B \times d{model}} \\ \hat{\mathbf{Y}} &= \text{MLPHead}(\mathbf{Z}{global}) \in \mathbb{R}^{B \times C{out}} \end{aligned} ZglobalY^=GlobalPooling(Z(L))∈RB×dmodel=MLPHead(Zglobal)∈RB×Cout

Global Pooling 可以是:

  • Mean Pooling : 1 N p a t c h e s ∑ i = 1 N p a t c h e s Z i ( L ) \frac{1}{N_{patches}} \sum_{i=1}^{N_{patches}} \mathbf{Z}_i^{(L)} Npatches1∑i=1NpatchesZi(L)
  • Max Pooling : max ⁡ i = 1 N p a t c h e s Z i ( L ) \max_{i=1}^{N_{patches}} \mathbf{Z}_i^{(L)} maxi=1NpatchesZi(L)
  • Attention Pooling: 学习权重进行加权平均

4. 优化方法与训练机制

4.1 损失函数

根据任务类型选择:

回归任务
L = MSE ( Y ^ , Y ) = 1 B ∑ i = 1 B ∥ y ^ i − y i ∥ 2 \mathcal{L} = \text{MSE}(\hat{\mathbf{Y}}, \mathbf{Y}) = \frac{1}{B} \sum_{i=1}^{B} \|\hat{\mathbf{y}}_i - \mathbf{y}_i\|^2 L=MSE(Y^,Y)=B1i=1∑B∥y^i−yi∥2

分类任务
L = CrossEntropy ( Y ^ , Y ) = − 1 B ∑ i = 1 B ∑ c = 1 C o u t y i , c log ⁡ ( y ^ i , c ) \mathcal{L} = \text{CrossEntropy}(\hat{\mathbf{Y}}, \mathbf{Y}) = -\frac{1}{B} \sum_{i=1}^{B} \sum_{c=1}^{C_{out}} y_{i,c} \log(\hat{y}_{i,c}) L=CrossEntropy(Y^,Y)=−B1i=1∑Bc=1∑Coutyi,clog(y^i,c)

4.2 梯度计算与反向传播

注意力层的梯度

多头注意力的梯度计算(以单头为例):

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right) \mathbf{V} Attention(Q,K,V)=softmax(dk QKT)V

反向传播时:
∂ L ∂ Q = ∂ L ∂ A ⋅ ∂ A ∂ Q \frac{\partial \mathcal{L}}{\partial \mathbf{Q}} = \frac{\partial \mathcal{L}}{\partial \mathbf{A}} \cdot \frac{\partial \mathbf{A}}{\partial \mathbf{Q}} ∂Q∂L=∂A∂L⋅∂Q∂A

其中 A = softmax ( Q K T / d k ) \mathbf{A} = \text{softmax}(\mathbf{Q}\mathbf{K}^T / \sqrt{d_k}) A=softmax(QKT/dk ) 的梯度涉及 Jacobian 矩阵(PyTorch 自动处理)。

表格嵌入的梯度

分类特征的嵌入层:
∂ L ∂ E e m b [ i ] = ∑ b ∈ batch where x c a t ( b ) = i ∂ L ∂ E c a t ( b ) \frac{\partial \mathcal{L}}{\partial \mathbf{E}{emb}[i]} = \sum{b \in \text{batch where } x_{cat}^{(b)} = i} \frac{\partial \mathcal{L}}{\partial \mathbf{E}_{cat}^{(b)}} ∂Eemb[i]∂L=b∈batch where xcat(b)=i∑∂Ecat(b)∂L

关键点:同一个类别的样本共享嵌入向量,梯度会累积。

4.3 优化策略

学习率调度

One-Cycle 策略(fastai 推荐):

复制代码
学习率
    ↑
max_lr ┤     /\
       │    /  \
       │   /    \
       │  /      \___
min_lr ┤ /           \
       └────────────────→ epoch
  • 前半段:学习率从 min_lr 升到 max_lr(探索阶段)
  • 后半段:学习率从 max_lr 降到 min_lr(精细调整阶段)

优势:

  • 避免过早陷入局部最优
  • 后期低学习率保证收敛
正则化技术
  1. Dropout:在注意力层和 FFN 中应用

    python 复制代码
    res_dropout=0.1  # 残差连接的 dropout
    fc_dropout=0.2   # 全连接层的 dropout
  2. Weight Decay (L2 正则):

    \\mathcal{L}*{total} = \\mathcal{L} + \\lambda \\sum*{w \\in \\mathbf{W}} w\^2 ``` ```

4.4 数值稳定性

Layer Normalization

在每个子层前进行归一化:

LayerNorm ( x ) = γ ⋅ x − μ σ 2 + ϵ + β \text{LayerNorm}(\mathbf{x}) = \gamma \cdot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta LayerNorm(x)=γ⋅σ2+ϵ x−μ+β

为什么重要

  • 缓解梯度消失/爆炸
  • 加速收敛
  • 减少对初始化的敏感性
Attention 计算的稳定性

softmax ( x i ) = exp ⁡ ( x i ) ∑ j exp ⁡ ( x j ) \text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} softmax(xi)=∑jexp(xj)exp(xi)

数值问题 : exp ⁡ ( x ) \exp(x) exp(x) 对大数值不稳定

解决方案 :减去最大值
softmax ( x i ) = exp ⁡ ( x i − max ⁡ ( x ) ) ∑ j exp ⁡ ( x j − max ⁡ ( x ) ) \text{softmax}(x_i) = \frac{\exp(x_i - \max(x))}{\sum_j \exp(x_j - \max(x))} softmax(xi)=∑jexp(xj−max(x))exp(xi−max(x))

(PyTorch 已自动处理)


5. 代码实现与工程视角

5.1 完整的使用示例

python 复制代码
import torch
from tsai.all import *

# ============ 1. 数据准备 ============
# 时序数据
x_ts = torch.randn(64, 6, 500)  # (batch, channels, seq_len)

# 表格数据 - 分类特征
x_cat = torch.randint(0, 3, (64, 7))  # (batch, n_categorical)

# 表格数据 - 连续特征  
x_cont = torch.randn(64, 3)  # (batch, n_continuous)

# ============ 2. 定义表格特征信息 ============
# 分类特征的类别映射
classes = {
    'education': ['#na#', '10th', '11th', 'Bachelors', ...],
    'occupation': ['#na#', 'Tech', 'Sales', 'Manager', ...],
    # ... 每个分类特征的所有可能取值
}

# 连续特征的名称
cont_names = ['age', 'income', 'hours_per_week']

# ============ 3. 创建模型 ============
model = TSTabFusionTransformer(
    c_in=6,                # 时序通道数
    c_out=2,              # 输出类别数(二分类)
    seq_len=500,          # 序列长度
    classes=classes,      # 分类特征信息
    cont_names=cont_names,# 连续特征名称
    
    # Transformer 配置
    d_model=64,           # 嵌入维度
    n_layers=3,           # Transformer 层数
    n_heads=8,            # 注意力头数
    
    # 正则化
    res_dropout=0.1,      # 残差 dropout
    fc_dropout=0.2,       # 全连接 dropout
)

# ============ 4. 前向传播 ============
# 输入格式:(时序数据, (分类数据, 连续数据))
x = (x_ts, (x_cat, x_cont))
output = model(x)  # (batch_size, c_out)

print(f"输出形状: {output.shape}")  # torch.Size([64, 2])

5.2 关键代码设计解析

设计1: 灵活的输入格式
python 复制代码
# 模型接受的输入是一个 tuple
x = (x_timeseries, (x_categorical, x_continuous))

为什么这样设计

  • 清晰区分三种不同类型的数据
  • 方便 DataLoader 的组织(tsai 的 TSDataLoaders 支持这种格式)
  • 可以轻松处理"只有时序"或"只有表格"的情况
设计2: classes 字典的作用
python 复制代码
classes = {
    'education': ['#na#', '10th', '11th', ...],  # 17 种取值
    'occupation': ['#na#', 'Tech', 'Sales', ...], # 16 种取值
}

内部实现

python 复制代码
# 为每个分类特征创建独立的 Embedding
for feat_name, feat_classes in classes.items():
    n_classes = len(feat_classes)
    self.cat_embeddings[feat_name] = nn.Embedding(n_classes, d_model)

为什么不用一个统一的 Embedding

  • 不同特征的类别数可能差异很大(如性别2类 vs 职业16类)
  • 独立 Embedding 更灵活,避免参数浪费
设计3: 混合精度训练支持
python 复制代码
# 在 tsai 中启用 FP16
learn = Learner(dls, model, ...).to_fp16()

内部原理

  • 模型参数保持 FP32 精度
  • 前向传播时转为 FP16 计算(节省显存)
  • 梯度累积时转回 FP32(保证精度)

5.3 工程实践中的常见坑

坑1: 表格数据的缺失值处理

问题:原始数据中分类特征可能有缺失值

解决方案

python 复制代码
# 在 classes 中添加 '#na#' 作为第一个类别
classes = {
    'education': ['#na#', '10th', '11th', ...],  # '#na#' 对应索引 0
}

# 数据预处理时,将缺失值映射到 0
df['education'] = df['education'].fillna('#na#')
坑2: 序列长度不一致

问题:不同样本的时序长度可能不同

解决方案1 - Padding

python 复制代码
# 补齐到统一长度
x_ts_padded = torch.nn.functional.pad(x_ts, (0, max_len - x_ts.shape[-1]))

解决方案2 - Truncation

python 复制代码
# 截断到统一长度
x_ts_truncated = x_ts[:, :, :max_len]

解决方案3 - Masking(更优雅):

python 复制代码
# 在注意力计算时使用 mask
mask = create_padding_mask(seq_lengths)
attention_output = attention(x, mask=mask)
坑3: 显存爆炸

根本原因

  • 模型参数量过大(参见您之前的问题!)
  • Attention 矩阵 O(L²) 复杂度

优化策略

python 复制代码
# 1. 减小模型
model = TSTabFusionTransformer(
    ...,
    d_model=32,      # 从 64 降到 32
    n_layers=2,      # 从 6 降到 2
    n_heads=4,       # 从 8 降到 4
)

# 2. 增大 patch_len(内部自动处理)
# patch_len 越大,有效序列长度越短
# seq_len=500, patch_len=25 -> 有效长度=20

# 3. 启用梯度检查点(如果模型支持)
from torch.utils.checkpoint import checkpoint
坑4: 过拟合

表现:训练集表现好,验证集差

诊断

python 复制代码
# 观察训练/验证曲线
learn.recorder.plot_loss()

解决方案

python 复制代码
# 1. 增大 Dropout
model = TSTabFusionTransformer(..., res_dropout=0.2, fc_dropout=0.3)

# 2. 数据增强(时序数据)
tfms = [TSMagScale(), TSTimeWarp()]  # tsai 内置
dls = get_ts_dls(..., batch_tfms=tfms)

# 3. Early Stopping
learn.fit_one_cycle(50, cbs=EarlyStoppingCallback(patience=10))

# 4. 减小模型容量

5.4 性能优化技巧

技巧1: DataLoader 优化
python 复制代码
# 多进程加载
dls = TSDataLoaders.from_dsets(
    train_ds, valid_ds,
    bs=32,
    num_workers=4,      # 多进程
    pin_memory=True,    # 加速 GPU 传输
)
技巧2: 混合精度训练
python 复制代码
# 减少 50% 显存,加速 2-3 倍
learn = Learner(dls, model, ...).to_fp16()
技巧3: 梯度累积(小显存时)
python 复制代码
# 有效 batch_size = bs * accumulation_steps
learn.fit_one_cycle(
    epochs,
    lr_max=1e-3,
    # 在自定义 Callback 中实现梯度累积
)
技巧4: 模型剪枝(部署时)
python 复制代码
# 使用 PyTorch 的剪枝工具
import torch.nn.utils.prune as prune

# 剪枝注意力层
prune.l1_unstructured(model.encoder.layers[0].attn, 'weight', amount=0.3)

6. 批判性思考与扩展

6.1 模型的局限性

局限1: 计算复杂度仍然高

分析

  • 即使用了 patching,Attention 的 O(N²) 复杂度仍然存在
  • 对于超长序列(如 seq_len > 10000),仍然吃力

何时不适用

  • 高频传感器数据(如 100kHz 采样率)
  • 长期时序预测(如预测未来 1000 步)

替代方案

  • Informer:稀疏注意力机制
  • Linformer:线性复杂度的注意力
  • Nyströmformer:使用 Nyström 方法近似
局限2: 表格特征的融合方式单一

当前方式:简单的加法或拼接

问题

  • 无法建模复杂的交互(如"年龄 × 收入"的二阶交互)
  • 表格特征只是"静态上下文",未充分利用

改进方向

  • Cross-Attention:让表格特征作为 Key/Value,时序特征作为 Query

  • FiLM 层 (Feature-wise Linear Modulation):用表格特征调制时序特征
    Z f u s i o n = γ ( E t a b ) ⊙ Z t s + β ( E t a b ) \mathbf{Z}{fusion} = \gamma(\mathbf{E}{tab}) \odot \mathbf{Z}{ts} + \beta(\mathbf{E}{tab}) Zfusion=γ(Etab)⊙Zts+β(Etab)

局限3: 多任务学习支持不足

问题:只能预测单一目标

扩展:Multi-Task Learning

python 复制代码
# 同时预测多个目标
model = TSTabFusionTransformer(
    ...,
    c_out_1=2,   # 分类任务
    c_out_2=1,   # 回归任务
)

6.2 模型变体与扩展

扩展1: 时序 + 图像 + 表格

场景:医疗诊断(心电图 + CT 影像 + 病历信息)

架构

复制代码
┌─────────┐   ┌─────────┐   ┌─────────┐
│TS Trans │   │ViT      │   │Tab MLP  │
└────┬────┘   └────┬────┘   └────┬────┘
     │             │              │
     └─────────────┴──────────────┘
               Cross-Attention
                     ▼
                 Final MLP
扩展2: 自监督预训练

动机:时序 + 表格的标注数据稀缺

方法

  • Masked Patch Prediction:遮盖部分 patch,预测被遮盖内容
  • Contrastive Learning:拉近同一样本的时序和表格嵌入
python 复制代码
# 伪代码
# 1. 遮盖 15% 的 patches
masked_x_ts = mask_random_patches(x_ts)

# 2. 预训练任务
pred_patches = model(masked_x_ts, x_tab)
loss = MSE(pred_patches, original_patches)
扩展3: 可解释性增强

问题:Transformer 是"黑盒",难以解释

方法1 - Attention Visualization

python 复制代码
# 可视化注意力权重
attention_weights = model.get_attention_weights(x)
plot_attention_heatmap(attention_weights)

方法2 - Feature Importance

python 复制代码
# SHAP / Integrated Gradients
import shap
explainer = shap.DeepExplainer(model, background_data)
shap_values = explainer.shap_values(x)

6.3 适用场景与选型建议

推荐使用 TSTabFusionTransformer 的场景

金融预测 :股价预测 + 公司基本面

医疗诊断 :生理信号监测 + 患者信息

工业预测 :传感器数据 + 设备属性

推荐系统:用户行为序列 + 用户画像

不推荐的场景

纯时序任务 :如果没有表格数据,用普通 TST 即可

超长序列 :seq_len > 10000,考虑其他架构

实时推理:对延迟极敏感的场景(如高频交易)

与其他模型的对比
模型 时序能力 表格能力 融合方式 计算复杂度
TSTabFusionTransformer ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐ 深度融合 O(N²)
LSTM + MLP 拼接 ⭐⭐⭐ ⭐⭐ 浅层融合 O(N)
TabTransformer ⭐⭐⭐⭐⭐ 无时序 O(M²)
XGBoost(时序特征工程) ⭐⭐ ⭐⭐⭐⭐⭐ 手工融合 O(NK)

7. 小结与学习路径建议

7.1 核心要点回顾

  1. 问题本质:真实世界的预测任务需要同时利用动态时序信息和静态表格属性

  2. 核心创新

    • 在 Transformer 的每一层深度融入表格信息
    • 用 Patching 降低长序列的计算复杂度
    • 为分类和连续特征设计了独立的嵌入机制
  3. 数学原理

    • Attention 机制捕捉时序依赖
    • Embedding 技术处理离散特征
    • 残差连接 + LayerNorm 保证训练稳定性
  4. 工程实践

    • 显存优化:减小模型、FP16、梯度累积
    • 数据处理:缺失值、padding、归一化
    • 防止过拟合:Dropout、Early Stopping、数据增强
  5. 适用场景:金融、医疗、工业等需要多模态融合的领域

7.2 进阶学习路径

路径1: 深化 Transformer 理论
  1. 必读论文

  2. 动手实践

    • 从零实现 Multi-Head Attention
    • 理解 Positional Encoding 的作用
    • 可视化 Attention 权重
路径2: 多模态学习
  1. 相关模型

    • CLIP(图像 + 文本)
    • Flamingo(视觉 + 语言)
    • Perceiver(通用多模态架构)
  2. 核心问题

    • 如何对齐不同模态的特征空间?
    • Cross-Modal Attention 的设计
    • 多模态预训练策略
路径3: 时序预测前沿
  1. 新型架构

    • Informer(长序列预测)
    • N-BEATS(可解释的时序模型)
    • TimesNet(2D 卷积建模时序)
  2. 实际项目

    • Kaggle 时序竞赛
    • 复现 SOTA 模型
    • 在真实数据集上调参
路径4: 工程优化
  1. 模型压缩

    • 知识蒸馏(Teacher-Student)
    • 量化(INT8/INT4)
    • 剪枝(Pruning)
  2. 部署实践

    • ONNX 导出
    • TensorRT 优化
    • 边缘设备部署

7.3 推荐资源

代码库

论文

实战项目


附录:快速参考

A. 模型参数速查表

参数 含义 典型值 影响
d_model 嵌入维度 32, 64, 128 ↑参数量,↑容量
n_layers Transformer 层数 2, 4, 6 ↑深度,↑表达能力
n_heads 注意力头数 4, 8, 16 ↑多视角,↑参数
res_dropout 残差 Dropout 0.1, 0.2 ↑正则化
patch_len Patch 长度 16, 32, 64 ↑效率,↓细节

B. 常见错误诊断

错误信息 可能原因 解决方案
CUDA out of memory 模型/batch 太大 减小 d_model/batch_size,启用 FP16
RuntimeError: size mismatch 输入维度错误 检查 c_in, seq_len, classes
KeyError in classes 分类特征缺失 确保所有分类特征都在 classes 中
验证集 loss 不下降 过拟合 增大 Dropout,Early Stopping
训练很慢 CPU 计算/IO 瓶颈 启用 pin_memory, num_workers

C. 调参经验法则

  1. 首次尝试d_model=32, n_layers=2, n_heads=4, bs=32
  2. 如果欠拟合:逐步增大模型(翻倍增长)
  3. 如果过拟合:增大 Dropout(0.1 → 0.2 → 0.3)
  4. 如果显存不足:减半 d_model,增大 patch_len
  5. 学习率 :用 lr_find() 自动搜索,一般在 1e-4 到 1e-2 之间
相关推荐
冬夜戏雪2 小时前
【学习日记】【12.18】【整理了下论文相关的计划】
学习
speop2 小时前
【datawhale组队学习】|TASK02|结构化输入
网络·人工智能·学习
炽烈小老头2 小时前
【每天学习一点算法 2025/12/18】对称二叉树
学习·算法
蒙奇D索大2 小时前
【数据结构】考研408 | 开放定址法精讲:连续探测的艺术与代价
数据结构·笔记·考研·改行学it
子夜江寒2 小时前
pandas基础操作
学习·pandas
EveryPossible2 小时前
宽度撑开容器
学习
OpenBayes2 小时前
教程上新丨微软开源VibeVoice,可实现90分钟4角色自然对话
人工智能·深度学习·机器学习·大语言模型·tts·对话生成·语音生成
深蓝海拓2 小时前
PySide6从0开始学习的笔记(八) 控件(Widget)之QSlider(滑动条)
笔记·python·qt·学习·pyqt
TL滕2 小时前
从0开始学算法——第十九天(并查集练习)
笔记·学习·算法