TSTabFusionTransformer 深度学习学习笔记
目标读者:具备基础深度学习知识,希望理解如何融合时序数据与表格数据进行预测的学习者
1. 问题动机与直觉理解
1.1 现实世界的数据是多模态的
想象你是一位金融分析师,需要预测某只股票未来的走势。你手上有什么数据?
- 时序数据:过去 500 天的股价、成交量、技术指标(MACD、RSI等)
- 表格数据:公司的财务指标(市盈率、负债率)、行业分类、地理位置
或者你是医疗研究者,需要预测病人的健康风险:
- 时序数据:过去一周的心率、血压、体温监测曲线
- 表格数据:年龄、性别、既往病史、基因标记
关键观察 :真实世界的预测任务往往需要同时利用动态的时序信息 和静态的表格属性。
1.2 传统方法的困境
方案1:只用时序模型(如纯 Transformer)
python
# 只用时序数据
model = TimeSeriesTransformer(c_in=6, seq_len=500)
pred = model(x_timeseries) # 忽略了表格特征
问题:丢失了宝贵的静态信息!比如股票所属行业、病人的年龄性别等关键信息。
方案2:简单拼接后送入模型
python
# 将表格特征复制到每个时间步
x_combined = torch.cat([x_timeseries, x_tabular.unsqueeze(-1).repeat(1,1,seq_len)], dim=1)
问题:
- 破坏了时序数据的结构性
- 表格特征被强行"时序化",不自然
- 参数冗余(表格特征在每个时间步重复)
方案3:两个模型分别处理后融合
python
# 分别处理
ts_features = ts_model(x_timeseries)
tab_features = tab_model(x_tabular)
# 后期融合
combined = torch.cat([ts_features, tab_features], dim=1)
pred = final_mlp(combined)
问题:
- 两个模型独立训练,无法"互相学习"
- 融合发生在最后阶段,交互不充分
- 时序特征无法根据表格信息调整注意力
1.3 TSTabFusionTransformer 的核心思想
"让时序 Transformer 在处理时序数据时,能够'看到'表格信息,并据此调整注意力模式"
直觉上:
- 时序分支:用 Transformer 提取时序模式(趋势、周期、异常)
- 表格分支:用 MLP 提取表格特征的嵌入
- 深度融合:在 Transformer 的每一层,都将表格信息融入到时序特征中
类比:就像你在看股价走势图时,脑海中同时记着"这是一家科技公司"、"市盈率偏高"等信息,这些静态知识会影响你对时序模式的解读。
2. 核心思想与概念拆解
2.1 整体架构概览
输入:
┌─────────────────┐ ┌──────────────────┐
│ 时序数据 x_ts │ │ 表格数据 x_tab │
│ (bs, c_in, L) │ │ 分类 + 连续特征 │
└────────┬────────┘ └────────┬─────────┘
│ │
▼ ▼
┌─────────────────┐ ┌──────────────────┐
│ Patching + │ │ Embedding │
│ Position Encode │ │ (类别嵌入+线性层) │
└────────┬────────┘ └────────┬─────────┘
│ │
│ 融合点 ↓ │
│ ┌────────────────┐ │
└────→ Transformer ←───┘
│ Encoder Layers │
│ (多次融合) │
└────────┬───────┘
▼
┌──────────────────┐
│ Global Pooling │
└────────┬─────────┘
▼
┌──────────────────┐
│ MLP Head │
└────────┬─────────┘
▼
输出 (bs, c_out)
2.2 关键组件拆解
组件1:时序数据的 Patching
为什么需要 Patching?
原始时序数据可能很长(如 seq_len=500),直接送入 Transformer 会导致:
- 注意力矩阵 O(L²) 的复杂度爆炸
- 过度关注局部细节,难以捕捉全局模式
Patching 的思想(借鉴 Vision Transformer):
将长序列切分成小块(patch),每个 patch 看作一个"token"。
原始序列: [x1, x2, x3, ..., x500] (500 个时间步)
↓ patching (patch_len=16, stride=16)
Patches: [patch1, patch2, ..., patch31] (31 个 patch)
优势:
- 序列长度从 500 降到 31,注意力复杂度大幅降低
- 每个 patch 代表一个局部时间窗口,Transformer 关注的是"窗口级别"的模式
组件2:表格数据的处理
表格数据包含两类特征:
-
分类特征(categorical):如"行业类别"、"教育程度"
- 处理方式:Embedding 层
python# 例如 "教育程度" 有 17 种取值 education_emb = nn.Embedding(num_classes=17, embedding_dim=d_model) -
连续特征(continuous):如"年龄"、"收入"
- 处理方式:Linear 层 + 激活函数
pythoncont_emb = nn.Linear(n_cont, d_model) -
融合:将所有表格特征嵌入拼接后,再过一层 MLP
python# (bs, n_cat*d_model + n_cont*d_model) -> (bs, d_model) tab_features = MLP([cat_emb1, cat_emb2, ..., cont_emb])
组件3:特征融合机制
核心问题:如何在 Transformer 中融入表格信息?
方案:在每个 Transformer Block 中,将表格特征加到时序特征上
python
# 伪代码
for layer in transformer_layers:
# 1. 时序自注意力
ts_features = MultiHeadAttention(ts_features) # (bs, num_patches, d_model)
# 2. 融入表格信息
# 将表格特征广播到每个 patch
tab_broadcast = tab_features.unsqueeze(1).expand(-1, num_patches, -1)
ts_features = ts_features + tab_broadcast # 残差连接
# 3. FFN
ts_features = FeedForward(ts_features)
直觉理解:
- 表格特征像"全局上下文",告诉 Transformer "这是什么类型的样本"
- 时序特征在每一层都能"感知"到这个全局上下文
- 类似于你在看图表时,脑中始终记着"这是科技股"这个背景知识
3. 数学建模与理论基础
3.1 问题形式化
给定:
-
时序数据: X t s ∈ R B × C i n × L \mathbf{X}{ts} \in \mathbb{R}^{B \times C{in} \times L} Xts∈RB×Cin×L
- B B B: batch size
- C i n C_{in} Cin: 通道数(如 6 个技术指标)
- L L L: 序列长度(如 500 个时间步)
-
表格数据:
- 分类特征: X c a t ∈ N B × N c a t \mathbf{X}{cat} \in \mathbb{N}^{B \times N{cat}} Xcat∈NB×Ncat(每个值是类别索引)
- 连续特征: X c o n t ∈ R B × N c o n t \mathbf{X}{cont} \in \mathbb{R}^{B \times N{cont}} Xcont∈RB×Ncont
目标:
- 预测输出: Y ∈ R B × C o u t \mathbf{Y} \in \mathbb{R}^{B \times C_{out}} Y∈RB×Cout(如分类任务 C o u t = C_{out}= Cout= 类别数,回归任务 C o u t = 1 C_{out}=1 Cout=1)
3.2 时序分支的数学表示
Step 1: Patching
将时序数据切分成 patches:
X t s ∈ R B × C i n × L → patch X p a t c h e s ∈ R B × N p a t c h e s × ( C i n × P ) \mathbf{X}{ts} \in \mathbb{R}^{B \times C{in} \times L} \xrightarrow{\text{patch}} \mathbf{X}{patches} \in \mathbb{R}^{B \times N{patches} \times (C_{in} \times P)} Xts∈RB×Cin×Lpatch Xpatches∈RB×Npatches×(Cin×P)
其中:
- N p a t c h e s = ⌊ L − P S ⌋ + 1 N_{patches} = \lfloor \frac{L - P}{S} \rfloor + 1 Npatches=⌊SL−P⌋+1
- P P P: patch length(如 16)
- S S S: stride(如 16)
Step 2: Patch Embedding
Z ( 0 ) = Linear ( X p a t c h e s ) + E p o s \mathbf{Z}^{(0)} = \text{Linear}(\mathbf{X}{patches}) + \mathbf{E}{pos} Z(0)=Linear(Xpatches)+Epos
- Linear : R C i n × P → R d m o d e l \text{Linear}: \mathbb{R}^{C_{in} \times P} \to \mathbb{R}^{d_{model}} Linear:RCin×P→Rdmodel(投影到嵌入空间)
- E p o s \mathbf{E}_{pos} Epos: 位置编码(Learnable 或 Sinusoidal)
Step 3: Transformer Encoder
对于第 l l l 层:
Z ~ ( l ) = LayerNorm ( Z ( l − 1 ) ) Z ′ ( l ) = MultiHeadAttention ( Z ~ ( l ) ) + Z ( l − 1 ) (残差) Z ( l ) = FFN ( LayerNorm ( Z ′ ( l ) ) ) + Z ′ ( l ) \begin{aligned} \tilde{\mathbf{Z}}^{(l)} &= \text{LayerNorm}(\mathbf{Z}^{(l-1)}) \\ \mathbf{Z'}^{(l)} &= \text{MultiHeadAttention}(\tilde{\mathbf{Z}}^{(l)}) + \mathbf{Z}^{(l-1)} \quad \text{(残差)} \\ \mathbf{Z}^{(l)} &= \text{FFN}(\text{LayerNorm}(\mathbf{Z'}^{(l)})) + \mathbf{Z'}^{(l)} \end{aligned} Z~(l)Z′(l)Z(l)=LayerNorm(Z(l−1))=MultiHeadAttention(Z~(l))+Z(l−1)(残差)=FFN(LayerNorm(Z′(l)))+Z′(l)
3.3 表格分支的数学表示
分类特征嵌入
对于第 i i i 个分类特征:
E c a t ( i ) = Embedding ( X c a t ( i ) ) ∈ R B × d m o d e l \mathbf{E}{cat}^{(i)} = \text{Embedding}(\mathbf{X}{cat}^{(i)}) \in \mathbb{R}^{B \times d_{model}} Ecat(i)=Embedding(Xcat(i))∈RB×dmodel
连续特征嵌入
E c o n t = ReLU ( Linear ( X c o n t ) ) ∈ R B × d m o d e l \mathbf{E}{cont} = \text{ReLU}(\text{Linear}(\mathbf{X}{cont})) \in \mathbb{R}^{B \times d_{model}} Econt=ReLU(Linear(Xcont))∈RB×dmodel
表格特征融合
E t a b = MLP ( [ E c a t ( 1 ) , E c a t ( 2 ) , . . . , E c o n t ] ) ∈ R B × d m o d e l \mathbf{E}{tab} = \text{MLP}([\mathbf{E}{cat}^{(1)}, \mathbf{E}{cat}^{(2)}, ..., \mathbf{E}{cont}]) \in \mathbb{R}^{B \times d_{model}} Etab=MLP([Ecat(1),Ecat(2),...,Econt])∈RB×dmodel
3.4 多模态融合机制
关键修改:在 Transformer 的每一层中加入表格信息
修改后的第 l l l 层:
Z ~ ( l ) = LayerNorm ( Z ( l − 1 ) + E t a b ⊙ 1 N p a t c h e s ) ← 融合点 Z ′ ( l ) = MultiHeadAttention ( Z ~ ( l ) ) + Z ( l − 1 ) Z ( l ) = FFN ( LayerNorm ( Z ′ ( l ) ) ) + Z ′ ( l ) \begin{aligned} \tilde{\mathbf{Z}}^{(l)} &= \text{LayerNorm}(\mathbf{Z}^{(l-1)} + \mathbf{E}{tab} \odot \mathbf{1}{N_{patches}}) \quad \text{← 融合点} \\ \mathbf{Z'}^{(l)} &= \text{MultiHeadAttention}(\tilde{\mathbf{Z}}^{(l)}) + \mathbf{Z}^{(l-1)} \\ \mathbf{Z}^{(l)} &= \text{FFN}(\text{LayerNorm}(\mathbf{Z'}^{(l)})) + \mathbf{Z'}^{(l)} \end{aligned} Z~(l)Z′(l)Z(l)=LayerNorm(Z(l−1)+Etab⊙1Npatches)← 融合点=MultiHeadAttention(Z~(l))+Z(l−1)=FFN(LayerNorm(Z′(l)))+Z′(l)
其中:
- E t a b ⊙ 1 N p a t c h e s \mathbf{E}{tab} \odot \mathbf{1}{N_{patches}} Etab⊙1Npatches: 将表格嵌入广播到每个 patch
- 在每一层都进行融合,实现深度交互
3.5 输出层
Z g l o b a l = GlobalPooling ( Z ( L ) ) ∈ R B × d m o d e l Y ^ = MLPHead ( Z g l o b a l ) ∈ R B × C o u t \begin{aligned} \mathbf{Z}{global} &= \text{GlobalPooling}(\mathbf{Z}^{(L)}) \in \mathbb{R}^{B \times d{model}} \\ \hat{\mathbf{Y}} &= \text{MLPHead}(\mathbf{Z}{global}) \in \mathbb{R}^{B \times C{out}} \end{aligned} ZglobalY^=GlobalPooling(Z(L))∈RB×dmodel=MLPHead(Zglobal)∈RB×Cout
Global Pooling 可以是:
- Mean Pooling : 1 N p a t c h e s ∑ i = 1 N p a t c h e s Z i ( L ) \frac{1}{N_{patches}} \sum_{i=1}^{N_{patches}} \mathbf{Z}_i^{(L)} Npatches1∑i=1NpatchesZi(L)
- Max Pooling : max i = 1 N p a t c h e s Z i ( L ) \max_{i=1}^{N_{patches}} \mathbf{Z}_i^{(L)} maxi=1NpatchesZi(L)
- Attention Pooling: 学习权重进行加权平均
4. 优化方法与训练机制
4.1 损失函数
根据任务类型选择:
回归任务 :
L = MSE ( Y ^ , Y ) = 1 B ∑ i = 1 B ∥ y ^ i − y i ∥ 2 \mathcal{L} = \text{MSE}(\hat{\mathbf{Y}}, \mathbf{Y}) = \frac{1}{B} \sum_{i=1}^{B} \|\hat{\mathbf{y}}_i - \mathbf{y}_i\|^2 L=MSE(Y^,Y)=B1i=1∑B∥y^i−yi∥2
分类任务 :
L = CrossEntropy ( Y ^ , Y ) = − 1 B ∑ i = 1 B ∑ c = 1 C o u t y i , c log ( y ^ i , c ) \mathcal{L} = \text{CrossEntropy}(\hat{\mathbf{Y}}, \mathbf{Y}) = -\frac{1}{B} \sum_{i=1}^{B} \sum_{c=1}^{C_{out}} y_{i,c} \log(\hat{y}_{i,c}) L=CrossEntropy(Y^,Y)=−B1i=1∑Bc=1∑Coutyi,clog(y^i,c)
4.2 梯度计算与反向传播
注意力层的梯度
多头注意力的梯度计算(以单头为例):
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right) \mathbf{V} Attention(Q,K,V)=softmax(dk QKT)V
反向传播时:
∂ L ∂ Q = ∂ L ∂ A ⋅ ∂ A ∂ Q \frac{\partial \mathcal{L}}{\partial \mathbf{Q}} = \frac{\partial \mathcal{L}}{\partial \mathbf{A}} \cdot \frac{\partial \mathbf{A}}{\partial \mathbf{Q}} ∂Q∂L=∂A∂L⋅∂Q∂A
其中 A = softmax ( Q K T / d k ) \mathbf{A} = \text{softmax}(\mathbf{Q}\mathbf{K}^T / \sqrt{d_k}) A=softmax(QKT/dk ) 的梯度涉及 Jacobian 矩阵(PyTorch 自动处理)。
表格嵌入的梯度
分类特征的嵌入层:
∂ L ∂ E e m b [ i ] = ∑ b ∈ batch where x c a t ( b ) = i ∂ L ∂ E c a t ( b ) \frac{\partial \mathcal{L}}{\partial \mathbf{E}{emb}[i]} = \sum{b \in \text{batch where } x_{cat}^{(b)} = i} \frac{\partial \mathcal{L}}{\partial \mathbf{E}_{cat}^{(b)}} ∂Eemb[i]∂L=b∈batch where xcat(b)=i∑∂Ecat(b)∂L
关键点:同一个类别的样本共享嵌入向量,梯度会累积。
4.3 优化策略
学习率调度
One-Cycle 策略(fastai 推荐):
学习率
↑
max_lr ┤ /\
│ / \
│ / \
│ / \___
min_lr ┤ / \
└────────────────→ epoch
- 前半段:学习率从 min_lr 升到 max_lr(探索阶段)
- 后半段:学习率从 max_lr 降到 min_lr(精细调整阶段)
优势:
- 避免过早陷入局部最优
- 后期低学习率保证收敛
正则化技术
-
Dropout:在注意力层和 FFN 中应用
pythonres_dropout=0.1 # 残差连接的 dropout fc_dropout=0.2 # 全连接层的 dropout -
Weight Decay (L2 正则):
\\mathcal{L}*{total} = \\mathcal{L} + \\lambda \\sum*{w \\in \\mathbf{W}} w\^2 ``` ```
4.4 数值稳定性
Layer Normalization
在每个子层前进行归一化:
LayerNorm ( x ) = γ ⋅ x − μ σ 2 + ϵ + β \text{LayerNorm}(\mathbf{x}) = \gamma \cdot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta LayerNorm(x)=γ⋅σ2+ϵ x−μ+β
为什么重要:
- 缓解梯度消失/爆炸
- 加速收敛
- 减少对初始化的敏感性
Attention 计算的稳定性
softmax ( x i ) = exp ( x i ) ∑ j exp ( x j ) \text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} softmax(xi)=∑jexp(xj)exp(xi)
数值问题 : exp ( x ) \exp(x) exp(x) 对大数值不稳定
解决方案 :减去最大值
softmax ( x i ) = exp ( x i − max ( x ) ) ∑ j exp ( x j − max ( x ) ) \text{softmax}(x_i) = \frac{\exp(x_i - \max(x))}{\sum_j \exp(x_j - \max(x))} softmax(xi)=∑jexp(xj−max(x))exp(xi−max(x))
(PyTorch 已自动处理)
5. 代码实现与工程视角
5.1 完整的使用示例
python
import torch
from tsai.all import *
# ============ 1. 数据准备 ============
# 时序数据
x_ts = torch.randn(64, 6, 500) # (batch, channels, seq_len)
# 表格数据 - 分类特征
x_cat = torch.randint(0, 3, (64, 7)) # (batch, n_categorical)
# 表格数据 - 连续特征
x_cont = torch.randn(64, 3) # (batch, n_continuous)
# ============ 2. 定义表格特征信息 ============
# 分类特征的类别映射
classes = {
'education': ['#na#', '10th', '11th', 'Bachelors', ...],
'occupation': ['#na#', 'Tech', 'Sales', 'Manager', ...],
# ... 每个分类特征的所有可能取值
}
# 连续特征的名称
cont_names = ['age', 'income', 'hours_per_week']
# ============ 3. 创建模型 ============
model = TSTabFusionTransformer(
c_in=6, # 时序通道数
c_out=2, # 输出类别数(二分类)
seq_len=500, # 序列长度
classes=classes, # 分类特征信息
cont_names=cont_names,# 连续特征名称
# Transformer 配置
d_model=64, # 嵌入维度
n_layers=3, # Transformer 层数
n_heads=8, # 注意力头数
# 正则化
res_dropout=0.1, # 残差 dropout
fc_dropout=0.2, # 全连接 dropout
)
# ============ 4. 前向传播 ============
# 输入格式:(时序数据, (分类数据, 连续数据))
x = (x_ts, (x_cat, x_cont))
output = model(x) # (batch_size, c_out)
print(f"输出形状: {output.shape}") # torch.Size([64, 2])
5.2 关键代码设计解析
设计1: 灵活的输入格式
python
# 模型接受的输入是一个 tuple
x = (x_timeseries, (x_categorical, x_continuous))
为什么这样设计?
- 清晰区分三种不同类型的数据
- 方便 DataLoader 的组织(tsai 的
TSDataLoaders支持这种格式) - 可以轻松处理"只有时序"或"只有表格"的情况
设计2: classes 字典的作用
python
classes = {
'education': ['#na#', '10th', '11th', ...], # 17 种取值
'occupation': ['#na#', 'Tech', 'Sales', ...], # 16 种取值
}
内部实现:
python
# 为每个分类特征创建独立的 Embedding
for feat_name, feat_classes in classes.items():
n_classes = len(feat_classes)
self.cat_embeddings[feat_name] = nn.Embedding(n_classes, d_model)
为什么不用一个统一的 Embedding?
- 不同特征的类别数可能差异很大(如性别2类 vs 职业16类)
- 独立 Embedding 更灵活,避免参数浪费
设计3: 混合精度训练支持
python
# 在 tsai 中启用 FP16
learn = Learner(dls, model, ...).to_fp16()
内部原理:
- 模型参数保持 FP32 精度
- 前向传播时转为 FP16 计算(节省显存)
- 梯度累积时转回 FP32(保证精度)
5.3 工程实践中的常见坑
坑1: 表格数据的缺失值处理
问题:原始数据中分类特征可能有缺失值
解决方案:
python
# 在 classes 中添加 '#na#' 作为第一个类别
classes = {
'education': ['#na#', '10th', '11th', ...], # '#na#' 对应索引 0
}
# 数据预处理时,将缺失值映射到 0
df['education'] = df['education'].fillna('#na#')
坑2: 序列长度不一致
问题:不同样本的时序长度可能不同
解决方案1 - Padding:
python
# 补齐到统一长度
x_ts_padded = torch.nn.functional.pad(x_ts, (0, max_len - x_ts.shape[-1]))
解决方案2 - Truncation:
python
# 截断到统一长度
x_ts_truncated = x_ts[:, :, :max_len]
解决方案3 - Masking(更优雅):
python
# 在注意力计算时使用 mask
mask = create_padding_mask(seq_lengths)
attention_output = attention(x, mask=mask)
坑3: 显存爆炸
根本原因:
- 模型参数量过大(参见您之前的问题!)
- Attention 矩阵 O(L²) 复杂度
优化策略:
python
# 1. 减小模型
model = TSTabFusionTransformer(
...,
d_model=32, # 从 64 降到 32
n_layers=2, # 从 6 降到 2
n_heads=4, # 从 8 降到 4
)
# 2. 增大 patch_len(内部自动处理)
# patch_len 越大,有效序列长度越短
# seq_len=500, patch_len=25 -> 有效长度=20
# 3. 启用梯度检查点(如果模型支持)
from torch.utils.checkpoint import checkpoint
坑4: 过拟合
表现:训练集表现好,验证集差
诊断:
python
# 观察训练/验证曲线
learn.recorder.plot_loss()
解决方案:
python
# 1. 增大 Dropout
model = TSTabFusionTransformer(..., res_dropout=0.2, fc_dropout=0.3)
# 2. 数据增强(时序数据)
tfms = [TSMagScale(), TSTimeWarp()] # tsai 内置
dls = get_ts_dls(..., batch_tfms=tfms)
# 3. Early Stopping
learn.fit_one_cycle(50, cbs=EarlyStoppingCallback(patience=10))
# 4. 减小模型容量
5.4 性能优化技巧
技巧1: DataLoader 优化
python
# 多进程加载
dls = TSDataLoaders.from_dsets(
train_ds, valid_ds,
bs=32,
num_workers=4, # 多进程
pin_memory=True, # 加速 GPU 传输
)
技巧2: 混合精度训练
python
# 减少 50% 显存,加速 2-3 倍
learn = Learner(dls, model, ...).to_fp16()
技巧3: 梯度累积(小显存时)
python
# 有效 batch_size = bs * accumulation_steps
learn.fit_one_cycle(
epochs,
lr_max=1e-3,
# 在自定义 Callback 中实现梯度累积
)
技巧4: 模型剪枝(部署时)
python
# 使用 PyTorch 的剪枝工具
import torch.nn.utils.prune as prune
# 剪枝注意力层
prune.l1_unstructured(model.encoder.layers[0].attn, 'weight', amount=0.3)
6. 批判性思考与扩展
6.1 模型的局限性
局限1: 计算复杂度仍然高
分析:
- 即使用了 patching,Attention 的 O(N²) 复杂度仍然存在
- 对于超长序列(如 seq_len > 10000),仍然吃力
何时不适用:
- 高频传感器数据(如 100kHz 采样率)
- 长期时序预测(如预测未来 1000 步)
替代方案:
- Informer:稀疏注意力机制
- Linformer:线性复杂度的注意力
- Nyströmformer:使用 Nyström 方法近似
局限2: 表格特征的融合方式单一
当前方式:简单的加法或拼接
问题:
- 无法建模复杂的交互(如"年龄 × 收入"的二阶交互)
- 表格特征只是"静态上下文",未充分利用
改进方向:
-
Cross-Attention:让表格特征作为 Key/Value,时序特征作为 Query
-
FiLM 层 (Feature-wise Linear Modulation):用表格特征调制时序特征
Z f u s i o n = γ ( E t a b ) ⊙ Z t s + β ( E t a b ) \mathbf{Z}{fusion} = \gamma(\mathbf{E}{tab}) \odot \mathbf{Z}{ts} + \beta(\mathbf{E}{tab}) Zfusion=γ(Etab)⊙Zts+β(Etab)
局限3: 多任务学习支持不足
问题:只能预测单一目标
扩展:Multi-Task Learning
python
# 同时预测多个目标
model = TSTabFusionTransformer(
...,
c_out_1=2, # 分类任务
c_out_2=1, # 回归任务
)
6.2 模型变体与扩展
扩展1: 时序 + 图像 + 表格
场景:医疗诊断(心电图 + CT 影像 + 病历信息)
架构:
┌─────────┐ ┌─────────┐ ┌─────────┐
│TS Trans │ │ViT │ │Tab MLP │
└────┬────┘ └────┬────┘ └────┬────┘
│ │ │
└─────────────┴──────────────┘
Cross-Attention
▼
Final MLP
扩展2: 自监督预训练
动机:时序 + 表格的标注数据稀缺
方法:
- Masked Patch Prediction:遮盖部分 patch,预测被遮盖内容
- Contrastive Learning:拉近同一样本的时序和表格嵌入
python
# 伪代码
# 1. 遮盖 15% 的 patches
masked_x_ts = mask_random_patches(x_ts)
# 2. 预训练任务
pred_patches = model(masked_x_ts, x_tab)
loss = MSE(pred_patches, original_patches)
扩展3: 可解释性增强
问题:Transformer 是"黑盒",难以解释
方法1 - Attention Visualization:
python
# 可视化注意力权重
attention_weights = model.get_attention_weights(x)
plot_attention_heatmap(attention_weights)
方法2 - Feature Importance:
python
# SHAP / Integrated Gradients
import shap
explainer = shap.DeepExplainer(model, background_data)
shap_values = explainer.shap_values(x)
6.3 适用场景与选型建议
推荐使用 TSTabFusionTransformer 的场景
✅ 金融预测 :股价预测 + 公司基本面
✅ 医疗诊断 :生理信号监测 + 患者信息
✅ 工业预测 :传感器数据 + 设备属性
✅ 推荐系统:用户行为序列 + 用户画像
不推荐的场景
❌ 纯时序任务 :如果没有表格数据,用普通 TST 即可
❌ 超长序列 :seq_len > 10000,考虑其他架构
❌ 实时推理:对延迟极敏感的场景(如高频交易)
与其他模型的对比
| 模型 | 时序能力 | 表格能力 | 融合方式 | 计算复杂度 |
|---|---|---|---|---|
| TSTabFusionTransformer | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | 深度融合 | O(N²) |
| LSTM + MLP 拼接 | ⭐⭐⭐ | ⭐⭐ | 浅层融合 | O(N) |
| TabTransformer | ⭐ | ⭐⭐⭐⭐⭐ | 无时序 | O(M²) |
| XGBoost(时序特征工程) | ⭐⭐ | ⭐⭐⭐⭐⭐ | 手工融合 | O(NK) |
7. 小结与学习路径建议
7.1 核心要点回顾
-
问题本质:真实世界的预测任务需要同时利用动态时序信息和静态表格属性
-
核心创新:
- 在 Transformer 的每一层深度融入表格信息
- 用 Patching 降低长序列的计算复杂度
- 为分类和连续特征设计了独立的嵌入机制
-
数学原理:
- Attention 机制捕捉时序依赖
- Embedding 技术处理离散特征
- 残差连接 + LayerNorm 保证训练稳定性
-
工程实践:
- 显存优化:减小模型、FP16、梯度累积
- 数据处理:缺失值、padding、归一化
- 防止过拟合:Dropout、Early Stopping、数据增强
-
适用场景:金融、医疗、工业等需要多模态融合的领域
7.2 进阶学习路径
路径1: 深化 Transformer 理论
-
必读论文:
- Attention Is All You Need(原始 Transformer)
- BERT(预训练思想)
- Vision Transformer(Patching 的来源)
-
动手实践:
- 从零实现 Multi-Head Attention
- 理解 Positional Encoding 的作用
- 可视化 Attention 权重
路径2: 多模态学习
-
相关模型:
- CLIP(图像 + 文本)
- Flamingo(视觉 + 语言)
- Perceiver(通用多模态架构)
-
核心问题:
- 如何对齐不同模态的特征空间?
- Cross-Modal Attention 的设计
- 多模态预训练策略
路径3: 时序预测前沿
-
新型架构:
- Informer(长序列预测)
- N-BEATS(可解释的时序模型)
- TimesNet(2D 卷积建模时序)
-
实际项目:
- Kaggle 时序竞赛
- 复现 SOTA 模型
- 在真实数据集上调参
路径4: 工程优化
-
模型压缩:
- 知识蒸馏(Teacher-Student)
- 量化(INT8/INT4)
- 剪枝(Pruning)
-
部署实践:
- ONNX 导出
- TensorRT 优化
- 边缘设备部署
7.3 推荐资源
代码库:
论文:
- A Survey on Deep Learning for Time Series Forecasting
- Multimodal Learning with Transformers: A Survey
实战项目:
附录:快速参考
A. 模型参数速查表
| 参数 | 含义 | 典型值 | 影响 |
|---|---|---|---|
d_model |
嵌入维度 | 32, 64, 128 | ↑参数量,↑容量 |
n_layers |
Transformer 层数 | 2, 4, 6 | ↑深度,↑表达能力 |
n_heads |
注意力头数 | 4, 8, 16 | ↑多视角,↑参数 |
res_dropout |
残差 Dropout | 0.1, 0.2 | ↑正则化 |
patch_len |
Patch 长度 | 16, 32, 64 | ↑效率,↓细节 |
B. 常见错误诊断
| 错误信息 | 可能原因 | 解决方案 |
|---|---|---|
CUDA out of memory |
模型/batch 太大 | 减小 d_model/batch_size,启用 FP16 |
RuntimeError: size mismatch |
输入维度错误 | 检查 c_in, seq_len, classes |
KeyError in classes |
分类特征缺失 | 确保所有分类特征都在 classes 中 |
| 验证集 loss 不下降 | 过拟合 | 增大 Dropout,Early Stopping |
| 训练很慢 | CPU 计算/IO 瓶颈 | 启用 pin_memory, num_workers |
C. 调参经验法则
- 首次尝试 :
d_model=32, n_layers=2, n_heads=4, bs=32 - 如果欠拟合:逐步增大模型(翻倍增长)
- 如果过拟合:增大 Dropout(0.1 → 0.2 → 0.3)
- 如果显存不足:减半 d_model,增大 patch_len
- 学习率 :用
lr_find()自动搜索,一般在 1e-4 到 1e-2 之间