Transformer架构发展历史

Transformer架构详解:起源、原理与应用

目录

  1. Transformer的起源与历史背景
  2. Transformer核心架构详解
  3. 自注意力机制深入解析
  4. Transformer在结构化数据中的应用
  5. Transformer在图像数据中的应用
  6. 性能优化与变体
  7. 实际应用案例

1. Transformer的起源与历史背景

1.1 深度学习序列建模的演进

在Transformer出现之前,自然语言处理领域主要依赖以下架构:

循环神经网络时代(2010-2017)

  • RNN (Recurrent Neural Networks):基础循环架构,存在梯度消失问题
  • LSTM (Long Short-Term Memory):通过门控机制解决长期依赖问题
  • GRU (Gated Recurrent Units):LSTM的简化版本
  • Seq2Seq模型:编码器-解码器架构,用于机器翻译

RNN架构的局限性

  1. 顺序计算:无法并行化,训练效率低
  2. 长期依赖问题:即使使用LSTM,处理超长序列仍然困难
  3. 梯度传播路径长:容易出现梯度消失或爆炸
  4. 信息瓶颈:编码器需要将所有信息压缩到固定维度的向量中

1.2 注意力机制的诞生

Bahdanau Attention (2014)

  • 由Dzmitry Bahdanau等人提出
  • 允许解码器在生成每个词时关注输入序列的不同部分
  • 解决了Seq2Seq模型的信息瓶颈问题

核心思想

复制代码
对于输出序列的每个位置,计算其与输入序列所有位置的相关性
相关性越高的位置获得更大的权重

1.3 "Attention is All You Need" 论文

发表信息

  • 时间:2017年6月(arXiv)
  • 作者团队:Google Brain和Google Research
  • 核心作者:Ashish Vaswani, Noam Shazeer, Niki Parmar等
  • 发表会议:NIPS 2017(现NeurIPS)

革命性贡献

  1. 完全抛弃循环结构:首次提出纯注意力架构
  2. 并行化计算:大幅提升训练和推理速度
  3. 多头注意力机制:从多个子空间捕获信息
  4. 位置编码:通过三角函数编码位置信息
  5. 残差连接与层归一化:稳定深层网络训练

影响力

  • 截至2024年,论文引用次数超过10万次
  • 催生了BERT、GPT、T5等众多里程碑模型
  • 从NLP扩展到CV、语音、生物信息学等多个领域

2. Transformer核心架构详解

2.1 整体架构

Transformer采用编码器-解码器(Encoder-Decoder)架构:

复制代码
输入序列 → [嵌入层 + 位置编码] → 编码器栈 → 解码器栈 → 输出概率分布
                                        ↓
                                    上下文向量

核心组件

  • 编码器(Encoder):6层堆叠,每层包含多头自注意力和前馈网络
  • 解码器(Decoder):6层堆叠,包含掩码自注意力、交叉注意力和前馈网络
  • 注意力机制:核心计算单元
  • 位置编码:注入序列位置信息

2.2 编码器详细结构

单个编码器层包含

python 复制代码
# 伪代码表示
class EncoderLayer:
    def forward(x):
        # 1. 多头自注意力
        attn_output = MultiHeadAttention(Q=x, K=x, V=x)
        x = LayerNorm(x + attn_output)  # 残差连接 + 层归一化
        
        # 2. 位置前馈网络
        ffn_output = FeedForward(x)
        x = LayerNorm(x + ffn_output)   # 残差连接 + 层归一化
        
        return x

参数配置(原论文)

  • 模型维度 (d_model):512
  • 头数 (num_heads):8
  • 前馈网络维度 (d_ff):2048
  • Dropout率:0.1

2.3 解码器详细结构

单个解码器层包含

python 复制代码
# 伪代码表示
class DecoderLayer:
    def forward(x, encoder_output):
        # 1. 掩码多头自注意力(防止看到未来信息)
        masked_attn = MaskedMultiHeadAttention(Q=x, K=x, V=x)
        x = LayerNorm(x + masked_attn)
        
        # 2. 编码器-解码器注意力(交叉注意力)
        cross_attn = MultiHeadAttention(
            Q=x, 
            K=encoder_output, 
            V=encoder_output
        )
        x = LayerNorm(x + cross_attn)
        
        # 3. 位置前馈网络
        ffn_output = FeedForward(x)
        x = LayerNorm(x + ffn_output)
        
        return x

关键差异

  • 第一个注意力层使用掩码,确保位置i只能关注i之前的位置
  • 增加交叉注意力层,Query来自解码器,Key和Value来自编码器输出

2.4 位置编码(Positional Encoding)

由于Transformer没有循环或卷积结构,需要显式注入位置信息。

公式

复制代码
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中

  • pos:序列中的位置(0到max_len-1)
  • i:维度索引(0到d_model/2-1)
  • d_model:模型维度

特点

  1. 确定性:不需要学习,节省参数
  2. 外推性:理论上可以处理训练时未见过的序列长度
  3. 相对位置关系:PE(pos+k)可以表示为PE(pos)的线性函数

可视化示例

复制代码
位置0:   [sin(0/10000^0), cos(0/10000^0), sin(0/10000^(2/512)), ...]
位置1:   [sin(1/10000^0), cos(1/10000^0), sin(1/10000^(2/512)), ...]
位置2:   [sin(2/10000^0), cos(2/10000^0), sin(2/10000^(2/512)), ...]
...

3. 自注意力机制深入解析

3.1 缩放点积注意力(Scaled Dot-Product Attention)

核心公式

复制代码
Attention(Q, K, V) = softmax(QK^T / √d_k) V

计算步骤

  1. 计算相似度得分

    复制代码
    Scores = QK^T
    维度:(seq_len_q, d_k) × (d_k, seq_len_k) = (seq_len_q, seq_len_k)
  2. 缩放

    复制代码
    Scaled_Scores = Scores / √d_k
    • 目的:防止点积结果过大导致softmax梯度过小
    • 当d_k较大时,点积方差为d_k,缩放使方差归一化
  3. 应用Softmax

    复制代码
    Attention_Weights = softmax(Scaled_Scores)
    • 每一行的权重和为1
    • 表示Query对所有Key的注意力分布
  4. 加权求和

    复制代码
    Output = Attention_Weights × V
    维度:(seq_len_q, seq_len_k) × (seq_len_k, d_v) = (seq_len_q, d_v)

示例计算(简化版):

python 复制代码
import numpy as np

# 假设我们有3个词,每个词的表示维度为4
Q = np.array([[1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 1, 0, 0]])

K = np.array([[1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 1, 0, 0]])

V = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8],
              [9, 10, 11, 12]])

# 1. 计算QK^T
scores = np.dot(Q, K.T)  # 结果:[[2, 0, 2], [0, 2, 1], [2, 1, 2]]

# 2. 缩放
d_k = Q.shape[-1]
scaled_scores = scores / np.sqrt(d_k)  # 除以2

# 3. Softmax
attention_weights = np.exp(scaled_scores) / np.exp(scaled_scores).sum(axis=-1, keepdims=True)

# 4. 加权求和
output = np.dot(attention_weights, V)

3.2 多头注意力(Multi-Head Attention)

核心思想

  • 不是只计算一次注意力,而是并行计算h次
  • 每个"头"学习不同的注意力模式
  • 类似于CNN中的多通道

公式

复制代码
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O

其中 head_i = Attention(QW^Q_i, KW^K_i, VW^V_i)

参数矩阵

  • W^Q_i ∈ ℝ^(d_model × d_k):Query投影矩阵
  • W^K_i ∈ ℝ^(d_model × d_k):Key投影矩阵
  • W^V_i ∈ ℝ^(d_model × d_v):Value投影矩阵
  • W^O ∈ ℝ^(h·d_v × d_model):输出投影矩阵

典型配置

复制代码
d_model = 512
h = 8
d_k = d_v = d_model / h = 64

计算流程图

复制代码
输入 (batch, seq_len, d_model=512)
    ↓
分割成8个头,每个头维度64
    ↓
并行计算8个注意力
    ↓
    Head 1: Attention(Q1, K1, V1) → (batch, seq_len, 64)
    Head 2: Attention(Q2, K2, V2) → (batch, seq_len, 64)
    ...
    Head 8: Attention(Q8, K8, V8) → (batch, seq_len, 64)
    ↓
拼接 (batch, seq_len, 512)
    ↓
线性投影 W^O
    ↓
输出 (batch, seq_len, d_model=512)

为什么使用多头?

  1. 不同表示子空间:每个头可以关注不同方面的信息

    • 某个头可能关注语法关系
    • 某个头可能关注语义相似性
    • 某个头可能关注位置关系
  2. 增加模型容量:在不增加参数量的情况下增强表达能力

  3. 并行计算:多个头可以同时计算

3.3 掩码机制(Masking)

填充掩码(Padding Mask)

python 复制代码
# 用于处理变长序列
# 假设序列长度为5,实际有效长度为3
sequence = [token1, token2, token3, <PAD>, <PAD>]
padding_mask = [0, 0, 0, 1, 1]  # 1表示需要掩码的位置

# 在计算注意力前应用
scores = scores.masked_fill(padding_mask == 1, -1e9)
# Softmax后,被掩码的位置权重接近0

前瞻掩码(Look-Ahead Mask)

python 复制代码
# 用于解码器自注意力,防止看到未来信息
# 序列长度为4的掩码矩阵
look_ahead_mask = [
    [0, 1, 1, 1],  # 位置0只能看到位置0
    [0, 0, 1, 1],  # 位置1只能看到位置0-1
    [0, 0, 0, 1],  # 位置2只能看到位置0-2
    [0, 0, 0, 0],  # 位置3可以看到所有位置
]
# 0表示允许注意,1表示掩码

3.4 前馈网络(Feed-Forward Network)

结构

复制代码
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2

特点

  • 两层全连接网络
  • 中间层使用ReLU激活
  • 每个位置独立应用(position-wise)
  • 参数在所有位置共享

维度变化

复制代码
(batch, seq_len, d_model=512) 
    ↓ 第一层
(batch, seq_len, d_ff=2048)
    ↓ ReLU + 第二层
(batch, seq_len, d_model=512)

作用

  1. 增加模型的非线性变换能力
  2. 每个位置可以进行独立的特征变换
  3. 类似于1×1卷积的作用

4. Transformer在结构化数据中的应用

4.1 结构化数据的挑战

表格数据特点

  • 异构特征:数值型、类别型混合
  • 特征交互:特征之间存在复杂的非线性关系
  • 数据量相对较小:通常几千到几十万样本
  • 可解释性需求:金融、医疗等领域需要理解模型决策

传统方法

  • 树模型:XGBoost, LightGBM, CatBoost(表格数据的黄金标准)
  • 深度学习:MLP、Wide & Deep、DeepFM
  • 挑战:深度学习模型在小规模表格数据上通常不如树模型

4.2 TabTransformer (2020)

论文:"TabTransformer: Tabular Data Modeling Using Contextual Embeddings"

核心创新

  1. 将类别特征转换为嵌入向量
  2. 使用Transformer学习类别特征之间的上下文关系
  3. 数值特征保持原样或简单归一化

架构设计

复制代码
类别特征 → 嵌入层 → Transformer编码器 → 拼接数值特征 → MLP → 输出

具体流程:
1. 类别特征处理:
   - 特征1: [A] → Embedding → [e1_1, e1_2, ..., e1_d]
   - 特征2: [B] → Embedding → [e2_1, e2_2, ..., e2_d]
   - ...
   
2. Transformer编码:
   - 输入: (num_categorical_features, embedding_dim)
   - 多层自注意力捕获特征间关系
   - 输出: (num_categorical_features, embedding_dim)
   
3. 特征融合:
   - Flatten编码后的类别特征
   - 拼接原始数值特征
   - 通过MLP进行最终预测

代码示例

python 复制代码
import torch
import torch.nn as nn

class TabTransformer(nn.Module):
    def __init__(self, 
                 num_categories,      # 每个类别特征的类别数
                 num_numerical,       # 数值特征数量
                 embedding_dim=32,    # 嵌入维度
                 num_layers=6,        # Transformer层数
                 num_heads=8,         # 注意力头数
                 ffn_dim=128,         # 前馈网络维度
                 output_dim=1):       # 输出维度
        super().__init__()
        
        # 类别特征嵌入
        self.embeddings = nn.ModuleList([
            nn.Embedding(num_cat, embedding_dim) 
            for num_cat in num_categories
        ])
        
        # Transformer编码器
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=ffn_dim,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers
        )
        
        # 输出MLP
        total_dim = len(num_categories) * embedding_dim + num_numerical
        self.mlp = nn.Sequential(
            nn.Linear(total_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, output_dim)
        )
    
    def forward(self, categorical_features, numerical_features):
        # 嵌入类别特征
        embeddings = [
            emb(categorical_features[:, i]) 
            for i, emb in enumerate(self.embeddings)
        ]
        # (batch, num_categorical, embedding_dim)
        cat_embedded = torch.stack(embeddings, dim=1)
        
        # Transformer编码
        # (batch, num_categorical, embedding_dim)
        encoded = self.transformer(cat_embedded)
        
        # Flatten类别特征
        encoded_flat = encoded.reshape(encoded.size(0), -1)
        
        # 拼接数值特征
        combined = torch.cat([encoded_flat, numerical_features], dim=1)
        
        # MLP预测
        output = self.mlp(combined)
        return output

# 使用示例
model = TabTransformer(
    num_categories=[10, 5, 20, 8],  # 4个类别特征
    num_numerical=15,                # 15个数值特征
    embedding_dim=32,
    num_layers=4,
    num_heads=4
)

# 假设batch_size=64
cat_features = torch.randint(0, 10, (64, 4))  # 类别特征
num_features = torch.randn(64, 15)            # 数值特征
output = model(cat_features, num_features)

优势

  1. 特征交互:自动学习类别特征间的复杂关系
  2. 鲁棒性:对缺失值和噪声有一定容忍度
  3. 迁移学习:预训练的嵌入可以迁移到相关任务

实验结果(原论文):

  • 在多个UCI数据集上超越传统MLP
  • 在某些数据集上接近甚至超越树模型
  • 特别适合类别特征较多的场景

4.3 FT-Transformer (2021)

论文:"Revisiting Deep Learning Models for Tabular Data"

核心改进

  1. 所有特征都嵌入化:数值特征也转换为嵌入
  2. 特征级注意力:每个特征作为一个token
  3. 更简洁的架构:去除额外的MLP层

架构

复制代码
所有特征 → 特征嵌入 → [CLS] Token → Transformer → [CLS]输出 → 预测

特征嵌入方式:
- 类别特征: Embedding(category_value)
- 数值特征: Linear(numerical_value) + Feature_embedding

代码示例

python 复制代码
class FTTransformer(nn.Module):
    def __init__(self, 
                 num_features,
                 num_categories,
                 embedding_dim=64,
                 num_layers=3,
                 num_heads=8):
        super().__init__()
        
        # 类别特征嵌入
        self.cat_embeddings = nn.ModuleList([
            nn.Embedding(num_cat, embedding_dim) 
            for num_cat in num_categories
        ])
        
        # 数值特征线性投影
        self.num_projections = nn.ModuleList([
            nn.Linear(1, embedding_dim) 
            for _ in range(num_features)
        ])
        
        # 特征位置嵌入
        total_features = len(num_categories) + num_features
        self.feature_embeddings = nn.Parameter(
            torch.randn(1, total_features, embedding_dim)
        )
        
        # CLS token
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, embedding_dim)
        )
        
        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=embedding_dim * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers
        )
        
        # 输出层
        self.output = nn.Linear(embedding_dim, 1)
    
    def forward(self, categorical_features, numerical_features):
        batch_size = categorical_features.size(0)
        
        # 处理类别特征
        cat_embeds = [
            emb(categorical_features[:, i]) 
            for i, emb in enumerate(self.cat_embeddings)
        ]
        
        # 处理数值特征
        num_embeds = [
            proj(numerical_features[:, i:i+1]) 
            for i, proj in enumerate(self.num_projections)
        ]
        
        # 拼接所有特征嵌入
        all_embeds = cat_embeds + num_embeds
        features = torch.stack(all_embeds, dim=1)  # (batch, num_features, dim)
        
        # 添加特征位置嵌入
        features = features + self.feature_embeddings
        
        # 添加CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        features = torch.cat([cls_tokens, features], dim=1)
        
        # Transformer编码
        encoded = self.transformer(features)
        
        # 使用CLS token的输出进行预测
        cls_output = encoded[:, 0, :]
        output = self.output(cls_output)
        
        return output

优势

  • 统一处理所有类型的特征
  • 更好的特征交互学习
  • 在多个基准测试中达到SOTA

4.4 时间序列预测中的应用

Temporal Fusion Transformer (TFT, 2019)

场景:多变量时间序列预测

架构特点

  1. 变量选择网络:学习哪些特征重要
  2. 时序融合解码器:融合不同时间尺度的信息
  3. 多头注意力:捕获长期依赖关系

应用

  • 电力负荷预测
  • 股票价格预测
  • 零售需求预测

代码框架

python 复制代码
class TemporalFusionTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        # 变量选择网络
        self.variable_selection = VariableSelectionNetwork(
            input_dim=config.num_features,
            hidden_dim=config.hidden_dim
        )
        
        # LSTM编码器(处理历史序列)
        self.encoder_lstm = nn.LSTM(
            input_size=config.hidden_dim,
            hidden_size=config.hidden_dim,
            num_layers=1,
            batch_first=True
        )
        
        # LSTM解码器(生成未来序列)
        self.decoder_lstm = nn.LSTM(
            input_size=config.hidden_dim,
            hidden_size=config.hidden_dim,
            num_layers=1,
            batch_first=True
        )
        
        # 时序注意力层
        self.temporal_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            batch_first=True
        )
        
        # 门控残差网络
        self.grn = GatedResidualNetwork(config.hidden_dim)
        
        # 输出层
        self.output_layer = nn.Linear(
            config.hidden_dim, 
            config.forecast_horizon
        )

4.5 推荐系统中的应用

BERT4Rec (2019)

思想:将用户行为序列看作"句子",物品看作"词"

架构

复制代码
用户行为序列: [item1, item2, [MASK], item4, item5]
    ↓
Transformer编码器
    ↓
预测被mask的item

训练方式

  • 随机mask序列中的物品
  • 模型预测被mask的物品
  • 类似BERT的预训练方式

优势

  • 双向建模:考虑前后文信息
  • 捕获长期兴趣
  • 自监督学习:无需额外标注

5. Transformer在图像数据中的应用

5.1 从CNN到Transformer的转变

传统CNN的局限

  1. 局部感受野:卷积核只能看到局部区域
  2. 归纳偏置强:平移不变性和局部性是硬编码的
  3. 长距离依赖:需要堆叠很多层才能建立全局关系

Transformer的优势

  1. 全局感受野:每个位置都能关注整个图像
  2. 灵活的归纳偏置:通过数据学习而非硬编码
  3. 可扩展性:性能随数据量和模型大小持续提升

5.2 Vision Transformer (ViT, 2020)

论文:"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"(Google Research)

核心思想:将图像分割成固定大小的patch,每个patch作为一个token

架构流程

复制代码
1. 图像分割
   输入图像 (224×224×3)
   ↓
   分割成patches (14×14个patches,每个16×16×3)

2. Patch嵌入
   每个patch → flatten → 线性投影
   (16×16×3 = 768维) → 嵌入空间 (D维,如768)

3. 添加位置嵌入
   + 可学习的位置编码

4. 添加[CLS] token
   [CLS] + patch_1 + patch_2 + ... + patch_196

5. Transformer编码器
   多层自注意力 + FFN

6. 分类头
   [CLS] token的输出 → MLP → 类别概率

详细代码实现

python 复制代码
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """将图像分割成patches并嵌入"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # 使用卷积实现patch提取和线性投影
        self.projection = nn.Conv2d(
            in_channels, 
            embed_dim, 
            kernel_size=patch_size, 
            stride=patch_size
        )
    
    def forward(self, x):
        # x: (batch, 3, 224, 224)
        x = self.projection(x)  # (batch, embed_dim, 14, 14)
        x = x.flatten(2)         # (batch, embed_dim, 196)
        x = x.transpose(1, 2)    # (batch, 196, embed_dim)
        return x

class VisionTransformer(nn.Module):
    def __init__(self, 
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,           # Transformer层数
                 num_heads=12,
                 mlp_ratio=4.0,      # FFN隐藏层维度 = embed_dim * mlp_ratio
                 dropout=0.1):
        super().__init__()
        
        # Patch嵌入
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        num_patches = self.patch_embed.num_patches
        
        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # 位置嵌入(可学习)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )
        self.pos_drop = nn.Dropout(p=dropout)
        
        # Transformer编码器
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=depth
        )
        
        # 分类头
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        # 初始化权重
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
    
    def forward(self, x):
        batch_size = x.shape[0]
        
        # Patch嵌入
        x = self.patch_embed(x)  # (batch, 196, 768)
        
        # 添加CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (batch, 197, 768)
        
        # 添加位置嵌入
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        # Transformer编码
        x = self.transformer(x)
        
        # 分类(使用CLS token)
        x = self.norm(x[:, 0])
        x = self.head(x)
        
        return x

# 模型变体
def vit_base_patch16_224():
    """ViT-Base: 86M参数"""
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12
    )

def vit_large_patch16_224():
    """ViT-Large: 307M参数"""
    return VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16
    )

def vit_huge_patch14_224():
    """ViT-Huge: 632M参数"""
    return VisionTransformer(
        img_size=224,
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16
    )

关键发现(原论文):

  1. 数据量是关键

    • 在ImageNet-1K(130万图像)上:ViT < ResNet
    • 在ImageNet-21K(1400万图像)上:ViT ≈ ResNet
    • 在JFT-300M(3亿图像)上:ViT > ResNet
  2. 归纳偏置

    • ViT几乎没有图像特定的归纳偏置
    • 完全依赖数据学习,因此需要大规模数据
  3. 计算效率

    • 相同精度下,ViT训练成本更低
    • 更容易扩展到更大模型

5.3 Swin Transformer (2021)

论文:"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"(Microsoft Research Asia)

动机

  • ViT计算复杂度是O(N²),其中N是patch数量
  • 对于高分辨率图像,计算量巨大
  • 缺少层次化结构,难以用于密集预测任务

核心创新

  1. 局部窗口注意力:限制注意力在固定大小的窗口内
  2. 滑动窗口机制:通过移位窗口建立跨窗口连接
  3. 层次化结构:类似CNN的金字塔结构

架构设计

复制代码
Stage 1: 56×56 patches
  ↓ Window Attention
  ↓ Shifted Window Attention

Stage 2: 28×28 patches (Patch Merging)
  ↓ Window Attention
  ↓ Shifted Window Attention

Stage 3: 14×14 patches (Patch Merging)
  ↓ Window Attention
  ↓ Shifted Window Attention

Stage 4: 7×7 patches (Patch Merging)
  ↓ Window Attention
  ↓ Shifted Window Attention

窗口注意力示例

python 复制代码
class WindowAttention(nn.Module):
    """在固定大小窗口内计算注意力"""
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (window_height, window_width)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        
        # 相对位置偏置
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), 
                        num_heads)
        )
        
        # QKV投影
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
    
    def forward(self, x, mask=None):
        """
        x: (num_windows*batch, window_size*window_size, C)
        mask: (num_windows, window_size*window_size, window_size*window_size)
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # 计算注意力
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        
        # 添加相对位置偏置
        # ... (省略细节)
        
        # 应用掩码(用于shifted window)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        
        attn = attn.softmax(dim=-1)
        
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x

class SwinTransformerBlock(nn.Module):
    """Swin Transformer基本块"""
    def __init__(self, dim, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, (window_size, window_size), num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim)
        )
    
    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        
        # 滑动窗口(如果shift_size > 0)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), 
                                  dims=(1, 2))
        else:
            shifted_x = x
        
        # 分割成窗口
        x_windows = window_partition(shifted_x, self.window_size)
        
        # 窗口注意力
        attn_windows = self.attn(x_windows)
        
        # 合并窗口
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        
        # 反向滑动
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), 
                          dims=(1, 2))
        else:
            x = shifted_x
        
        x = x.view(B, H * W, C)
        
        # FFN
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        
        return x

优势

  1. 计算效率:O(N)复杂度,可处理高分辨率图像
  2. 层次化表示:适合检测、分割等任务
  3. 性能优异:ImageNet分类、COCO检测都达到SOTA

实验结果

  • ImageNet-1K: Top-1 87.3% (Swin-L)
  • COCO目标检测: 58.7 box AP
  • ADE20K语义分割: 53.5 mIoU

5.4 Detection Transformer (DETR, 2020)

论文:"End-to-End Object Detection with Transformers"(Facebook AI)

革命性创新

  • 第一个端到端的目标检测器
  • 不需要NMS(非极大值抑制)
  • 不需要anchor boxes

架构

复制代码
输入图像
  ↓
CNN Backbone (ResNet) → 特征图
  ↓
Flatten + 位置编码
  ↓
Transformer编码器
  ↓
Object Queries(可学习)
  ↓
Transformer解码器
  ↓
FFN → N个预测 (类别 + 边界框)
  ↓
匈牙利匹配 + Loss

关键组件

  1. Object Queries
python 复制代码
# N个可学习的查询向量(N=100)
self.query_embed = nn.Embedding(num_queries, hidden_dim)

# 在解码器中使用
query_pos = self.query_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)
  1. 二分匹配
python 复制代码
def hungarian_matching(pred_boxes, pred_logits, target_boxes, target_labels):
    """使用匈牙利算法进行预测和GT的最优匹配"""
    # 计算分类成本
    cost_class = -pred_logits[:, target_labels]
    
    # 计算L1距离成本
    cost_bbox = torch.cdist(pred_boxes, target_boxes, p=1)
    
    # 计算GIoU成本
    cost_giou = -generalized_box_iou(pred_boxes, target_boxes)
    
    # 总成本
    C = cost_class + 5 * cost_bbox + 2 * cost_giou
    
    # 匈牙利算法
    indices = linear_sum_assignment(C.cpu())
    return indices

完整实现框架

python 复制代码
class DETR(nn.Module):
    def __init__(self, num_classes, num_queries=100, hidden_dim=256):
        super().__init__()
        
        # CNN骨干网络
        self.backbone = resnet50(pretrained=True)
        
        # 降维
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        
        # 位置编码
        self.pos_encoder = PositionEmbeddingSine(hidden_dim // 2)
        
        # Transformer
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=8,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dim_feedforward=2048
        )
        
        # Object queries
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        
        # 预测头
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)  # +1 for no-object
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)  # (x, y, w, h)
    
    def forward(self, images):
        # 特征提取
        features = self.backbone(images)  # (batch, 2048, H, W)
        features = self.conv(features)     # (batch, 256, H, W)
        
        # 位置编码
        pos = self.pos_encoder(features)
        
        # Flatten
        batch_size = features.shape[0]
        features = features.flatten(2).permute(2, 0, 1)  # (HW, batch, 256)
        pos = pos.flatten(2).permute(2, 0, 1)
        
        # Object queries
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
        
        # Transformer
        memory = self.transformer.encoder(features, pos=pos)
        hs = self.transformer.decoder(
            query_embed, 
            memory, 
            memory_key_padding_mask=None,
            pos=pos,
            query_pos=query_embed
        )
        
        # 预测
        outputs_class = self.class_embed(hs)  # (num_queries, batch, num_classes+1)
        outputs_coord = self.bbox_embed(hs).sigmoid()  # (num_queries, batch, 4)
        
        return {
            'pred_logits': outputs_class[-1],
            'pred_boxes': outputs_coord[-1]
        }

优势

  • 真正的端到端训练
  • 全局推理能力
  • 代码简洁,易于扩展

局限性

  • 训练收敛较慢(需要500 epochs)
  • 小物体检测性能较弱
  • 计算量较大

5.5 图像分割中的应用

Segmenter (2021)

将ViT应用于语义分割:

复制代码
图像 → Patch嵌入 → Transformer编码器 → 逐patch分类 → 上采样 → 分割图

Mask2Former (2022)

结合DETR思想进行实例分割和全景分割:

复制代码
图像 → Backbone → Pixel Decoder → Transformer Decoder → Mask预测

6. 性能优化与变体

6.1 高效注意力机制

Linformer (2020)

  • 将注意力复杂度从O(N²)降到O(N)
  • 方法:低秩矩阵近似K和V
python 复制代码
# 核心思想
K_projected = K @ E  # E: (N, k) 其中 k << N
V_projected = V @ F  # F: (N, k)
Attention = softmax(Q @ K_projected.T) @ V_projected

Performer (2020)

  • 使用快速注意力算法(FAVOR+)
  • 线性复杂度,无近似误差

Reformer (2020)

  • 局部敏感哈希(LSH)注意力
  • 只计算相似query和key的注意力

6.2 位置编码改进

相对位置编码

python 复制代码
# 不是编码绝对位置,而是编码相对距离
relative_position = position_i - position_j
bias = learned_bias[relative_position]
attention_score = (Q @ K.T) + bias

旋转位置编码(RoPE)

  • 通过旋转矩阵注入位置信息
  • 用于LLaMA、PaLM等大模型

ALiBi(Attention with Linear Biases)

  • 在注意力分数上添加线性偏置
  • 外推性能优异

6.3 稀疏注意力

局部注意力

python 复制代码
# 只关注周围k个位置
for i in range(seq_len):
    attend_to = range(max(0, i-k), min(seq_len, i+k+1))

Strided注意力

python 复制代码
# 每隔s个位置关注一次
attend_to = [0, s, 2s, 3s, ...]

Longformer

  • 结合局部注意力和全局注意力
  • 可处理长达16k的序列

7. 实际应用案例

7.1 案例1:电商推荐系统

场景:用户商品点击序列预测

数据

  • 用户历史点击:[item_1, item_2, ..., item_n]
  • 用户特征:年龄、性别、地域
  • 商品特征:类别、价格、品牌

方案

python 复制代码
class RecommendationTransformer(nn.Module):
    def __init__(self, num_items, embedding_dim=128):
        super().__init__()
        
        # 商品嵌入
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        
        # Transformer编码器
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embedding_dim,
                nhead=8,
                dim_feedforward=512,
                batch_first=True
            ),
            num_layers=4
        )
        
        # 预测头
        self.output = nn.Linear(embedding_dim, num_items)
    
    def forward(self, item_sequence):
        # 嵌入
        x = self.item_embedding(item_sequence)
        
        # 位置编码
        x = x + self.positional_encoding(x)
        
        # Transformer编码
        encoded = self.transformer(x)
        
        # 预测下一个商品
        logits = self.output(encoded[:, -1, :])
        return logits

效果

  • 相比LSTM提升5-10%的点击率
  • 能够捕获长期用户兴趣

7.2 案例2:医学图像分割

场景:CT扫描肿瘤分割

挑战

  • 医学图像分辨率高(512×512甚至更大)
  • 标注数据有限
  • 需要高精度

方案:使用Swin-UNETR

复制代码
输入CT切片
  ↓
Swin Transformer编码器(多尺度特征)
  ↓
UNet风格解码器
  ↓
分割结果

优势

  • 全局上下文建模
  • 层次化特征提取
  • 在BTCV数据集上达到SOTA

7.3 案例3:金融风控

场景:信用卡欺诈检测

数据特点

  • 高维特征:100+维
  • 混合类型:交易金额(数值)、商户类别(类别)
  • 类别不平衡:欺诈样本<1%

方案:使用TabTransformer

python 复制代码
model = TabTransformer(
    num_categories=[1000, 50, 20, ...],  # 商户ID、地区、卡类型
    num_numerical=45,                     # 交易金额、时间等
    embedding_dim=64,
    num_layers=6,
    num_heads=8
)

# 训练时使用Focal Loss处理类别不平衡
criterion = FocalLoss(alpha=0.25, gamma=2.0)

效果

  • F1-score提升3-5%
  • 特征交互自动学习,减少特征工程

7.4 案例4:自动驾驶

场景:3D目标检测

数据:激光雷达点云 + 相机图像

方案:TransFusion

复制代码
点云 → PointNet提取特征
  ↓
相机图像 → CNN提取特征
  ↓
Transformer融合多模态特征
  ↓
检测头 → 3D边界框

创新点

  • 跨模态注意力机制
  • 查询向量同时关注点云和图像特征

总结

Transformer的核心优势

  1. 并行化:不同于RNN的顺序计算,可以充分利用GPU
  2. 长距离依赖:通过注意力机制直接建立任意位置间的连接
  3. 可解释性:注意力权重可视化,了解模型关注点
  4. 可扩展性:性能随模型大小和数据量持续提升
  5. 通用性:从NLP到CV,从表格数据到多模态

未来发展方向

  1. 效率优化

    • 更高效的注意力机制(Flash Attention)
    • 模型压缩与量化
    • 稀疏模型(MoE)
  2. 架构创新

    • 更好的位置编码
    • 动态深度网络
    • 神经架构搜索
  3. 应用拓展

    • 多模态融合(CLIP, DALL-E)
    • 科学计算(AlphaFold)
    • 强化学习(Decision Transformer)
  4. 理论理解

    • 为什么Transformer有效?
    • 如何设计更好的归纳偏置?
    • 泛化性理论

学习资源

论文必读

  1. Attention is All You Need (2017)
  2. BERT: Pre-training of Deep Bidirectional Transformers (2018)
  3. An Image is Worth 16x16 Words (ViT, 2020)
  4. Swin Transformer (2021)

代码实践

  • Hugging Face Transformers库
  • PyTorch官方教程
  • Annotated Transformer

课程推荐

  • Stanford CS224N (NLP with Deep Learning)
  • Stanford CS231N (Computer Vision)
  • Deep Learning Specialization (Andrew Ng)

文档版本 :v1.0
最后更新 :2024年
作者:Claude (Anthropic)

相关推荐
番茄寿司4 小时前
具身智能六大前沿创新思路深度解析
论文阅读·人工智能·深度学习·计算机网络·机器学习
LinXunFeng4 小时前
Flutter 多仓库本地 Monorepo 方案与体验优化
前端·flutter·架构
IT小番茄5 小时前
kubernetes云平台管理实战:deployment通过标签管理pod(十)
架构
吃饺子不吃馅6 小时前
Canvas实现协同电影选座
前端·架构·canvas
递归不收敛6 小时前
四、高效注意力机制与模型架构
人工智能·笔记·自然语言处理·架构
碧海银沙音频科技研究院7 小时前
DiVE长尾识别的虚拟实例蒸馏方法
arm开发·人工智能·深度学习·算法·音视频
AI浩7 小时前
基于多焦点高斯邻域注意力机制与大规模基准的视频人群定位
人工智能·深度学习·音视频
8Qi87 小时前
A Survey of Camouflaged Object Detection and Beyond论文阅读笔记
人工智能·深度学习·目标检测·计算机视觉·伪装目标检测
AI规划师-南木7 小时前
低代码开发医疗AI工具:5分钟搭建用药推荐系统,零基础也能落地
人工智能·深度学习·低代码·计算机视觉·推荐系统·rxjava·医疗ai