【DL】Transformer算法应用

一、Transformer 工程化使用

1.1 技术选型建议

NLP

  • Hugging Face Transformers:预训练模型生态(必须)
  • PyTorch:训练控制(必须)
  • datasets:数据管道(推荐)
  • DeepSpeed / Accelerate:大模型训练

时序预测

  • PyTorch Forecasting:结构化强
  • GluonTS:概率预测强
  • Darts:工程体验最好(推荐)

二、NLP 实战

2.1 文本分类

python 复制代码
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments
)
import numpy as np
import evaluate

# 1. 数据
dataset = load_dataset("imdb")

# 2. tokenizer
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def preprocess(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length")

dataset = dataset.map(preprocess, batched=True)

# 3. 模型
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2
)

# 4. 评估指标
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    return metric.compute(predictions=preds, references=labels)

# 5. 训练参数
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=100
)

# 6. Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"].select(range(2000)),
    eval_dataset=dataset["test"].select(range(1000)),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

trainer.train()

三、时序预测

3.1 问题定义

输入:

复制代码
X: [batch, M, N]  

输出:

复制代码
Y: [batch, F, E]

含义:

  • N:输入特征维度
  • M:历史时间长度
  • E:预测变量数
  • F:预测步长(multi-step)

3.2 核心设计

必须解决3件事:

1️⃣ 时间编码

python 复制代码
class TimeEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)

        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(torch.log(torch.tensor(10000.0)) / d_model))

        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)

        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

2️⃣ Encoder-Decoder

python 复制代码
class Seq2SeqTransformer(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_layers, output_dim):
        super().__init__()

        self.input_proj = nn.Linear(input_dim, d_model)
        self.output_proj = nn.Linear(output_dim, d_model)

        self.pos_encoder = TimeEncoding(d_model)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True
        )

        self.fc = nn.Linear(d_model, output_dim)

    def forward(self, src, tgt):
        # src: [B, M, N]
        # tgt: [B, F, E]

        src = self.pos_encoder(self.input_proj(src))
        tgt = self.pos_encoder(self.output_proj(tgt))

        out = self.transformer(src, tgt)

        return self.fc(out)

3️⃣ 训练方式

Teacher Forcing
python 复制代码
tgt_input = y[:, :-1, :]
tgt_output = y[:, 1:, :]

3.3 滑动窗口

python 复制代码
def create_dataset(data, M, F):
    X, Y = [], []
    for i in range(len(data) - M - F):
        X.append(data[i:i+M])
        Y.append(data[i+M:i+M+F])
    return np.array(X), np.array(Y)

3.4 工业优化

多变量处理

  • 静态特征(设备ID)
  • 时间特征(hour/day/week)

loss设计

  • MSE(基础)
  • Quantile Loss(预测区间)

四、Transformer 变体

原始 Transformer(Vaswani 2017)核心瓶颈:

问题 本质
计算复杂度 Self-Attention = O(n²)
长序列退化 注意力稀释
无结构归纳偏置 不懂"时间/局部"
对噪声敏感 全局attention放大噪声

1. Longformer

✔ 解决问题

  • 长序列计算爆炸

✔ 核心改动

👉 Attention结构改变:

复制代码
全连接 → 局部窗口 + 全局token

✔ 改动层级

  • 改的是:Attention Pattern(最关键层)

✔ 效果

场景 表现
NLP 长文档(法律/论文)显著提升
时序 一般(不适合周期建模)

✔ 相比原始Transformer

指标 变化
复杂度 O(n²) → O(n)
精度 ≈ 或略降
稳定性

👉 结论:

工程最稳的长序列优化方案

2. Linformer

✔ 解决问题

  • Attention矩阵过大

✔ 核心改动

👉 低秩分解:

复制代码
K,V → 低维投影

✔ 改动层级

  • 改的是:Attention矩阵表示

✔ 效果

场景 表现
NLP 中等任务可用
时序 较差(信息压缩过头)

✔ 对比原始Transformer

指标 变化
复杂度 O(n²) → O(n)
精度
内存 ↓↓↓

👉 结论:

用空间换精度,不适合高精度场景


3. Performer

✔ 解决问题

  • Attention计算瓶颈

✔ 核心改动

👉 Kernel Trick:

复制代码
Softmax(QK^T) ≈ Φ(Q)Φ(K)^T

✔ 改动层级

  • 改的是:Attention计算方式

✔ 效果

场景 表现
NLP 不稳定
时序 很少用

✔ 对比原始Transformer

指标 变化
复杂度 O(n²) → O(n)
精度 波动
稳定性

👉 结论:

理论漂亮,工程不稳定

2️⃣ 结构表达增强类(解决"表达能力")

4. BERT

✔ 解决问题

  • 单向建模(GPT问题)

✔ 核心改动

👉 双向编码:

复制代码
Masked Language Model

✔ 改动层级

  • 输入mask机制 + 训练目标

✔ 效果

场景 表现
NLP 理解任务SOTA
时序 几乎不用

✔ 对比原始Transformer

指标 变化
表达能力 ↑↑
生成能力

👉 结论:

NLP理解任务标准答案

5. GPT

✔ 解决问题

  • 生成能力不足

✔ 核心改动

👉 单向mask:

复制代码
只能看左边

✔ 改动层级

  • Attention Mask

✔ 效果

场景 表现
NLP 生成任务最强
时序 可用于序列生成

✔ 对比原始Transformer

指标 变化
生成能力 ↑↑
并行性

👉 结论:

生成类任务唯一主流路径

6. DeBERTa

✔ 解决问题

  • Attention表达能力不足

✔ 核心改动

👉 解耦位置编码:

复制代码
content + position 分开建模

✔ 改动层级

  • Attention score计算

✔ 效果

场景 表现
NLP 比BERT更强
时序 有潜力但少用

✔ 对比原始Transformer

指标 变化
表达能力
复杂度

👉 结论:

Attention机制优化的代表

3️⃣ 时序专用类(解决"时间结构")

7. Informer

✔ 解决问题

  • 长序列 + 稀疏性

✔ 核心改动

👉 ProbSparse Attention:

复制代码
只关注重要query

✔ 改动层级

  • Attention采样机制

✔ 效果

场景 表现
NLP 几乎不用
时序 长序列有效

✔ 对比原始Transformer

指标 变化
复杂度
精度
长序列能力

👉 结论:

第一个真正可用的时序Transformer

8. Autoformer

✔ 解决问题

  • 时间序列"趋势+周期"

✔ 核心改动

👉 分解机制:

复制代码
trend + seasonal

👉 自相关替代attention

✔ 改动层级

  • Attention → Auto-Correlation

✔ 效果

场景 表现
NLP 不适用
时序 强周期数据表现极好

✔ 对比原始Transformer

指标 变化
可解释性 ↑↑
稳定性
泛化

👉 结论:

最"懂时间"的Transformer

9. FEDformer

✔ 解决问题

  • 噪声 & 长周期

✔ 核心改动

👉 频域Attention:

复制代码
time → frequency

✔ 改动层级

  • Attention空间改变(时域→频域)

✔ 效果

场景 表现
NLP 无意义
时序 抗噪最强

✔ 对比原始Transformer

指标 变化
抗噪能力 ↑↑
长周期建模
复杂度

👉 结论:

工业噪声数据首选

10. PatchTST

✔ 解决问题

  • Transformer不适合局部结构

✔ 核心改动

👉 Patch化(类似ViT):

复制代码
时间序列 → patch token

✔ 改动层级

  • 输入表示(非常关键)

✔ 效果

场景 表现
NLP 不适用
时序 当前SOTA之一

✔ 对比原始Transformer

指标 变化
精度 ↑↑
局部建模
泛化

👉 结论:

目前最强通用时序Transformer

相关推荐
2301_795741792 小时前
C++中的代理模式变体
开发语言·c++·算法
fof9202 小时前
Base LLM | 从 NLP 到 LLM 的算法全栈教程 第二天
人工智能·算法·自然语言处理
2301_789015622 小时前
封装RBTree(红黑树)实现myset和mymap
开发语言·数据结构·c++·算法·r-tree
云蝠呼叫大模型联络中心2 小时前
零售行业智能客服与客户数据分析:技术架构与实战案例
大数据·人工智能·架构·数据分析·零售·#智能外呼合规·#云蝠智能
2301_764441332 小时前
PocketPal AI版本与部署
人工智能
乱世刀疤2 小时前
openclaw常用指令
人工智能·openclaw
liliangcsdn2 小时前
LLM长文本场景-多文档分布式分析示例
人工智能·学习
大鹏说大话2 小时前
后端开发指南:同步与异步接口的选型策略与实战场景
算法
Book思议-2 小时前
【数据结构实战】双向链表:删除节点
c语言·数据结构·算法·链表