Vision Transformer (ViT) 解析
一、核心思想:图像即序列
ViT 的本质是将图像视为一系列 "图像单词"(image tokens),完全沿用 NLP 中 Transformer 的架构来处理视觉数据:
- 将输入图像分割为固定大小的非重叠图像块 (patches)
- 将每个图像块线性投影为向量 (embedding)
- 添加位置编码以保留空间信息
- 输入 Transformer 编码器进行全局特征学习
- 用特殊的 [CLS] token 进行分类预测
这与 CNN 的局部卷积 + 池化的分层特征提取范式形成鲜明对比。
二、ViT 完整架构详解
1. 输入处理模块
图像分块 (Patch Embedding)
- 输入:RGB 图像,尺寸为 H×W×3(高度 × 宽度 × 通道数)
- 分块:使用大小为 P×P 的非重叠窗口分割图像,得到 N = (H/P) × (W/P) 个图像块
- 标准设置:P=16,对于 224×224 图像,N=14×14=196 个块
- 展平:每个图像块展平为长度为 P²×3 的向量(例如 16×16×3=768)
- 线性投影:通过全连接层将展平向量映射到 D 维嵌入空间(通常 D=768)
数学表达
python
x_patch = flatten(patch) # 形状: (P²×3)
z_0^p = Linear(x_patch) # 形状: (D), p=1..N
特殊 Token 与位置编码
- [CLS] Token :在图像块序列前添加一个可学习的嵌入向量 z_0^0,用于最终分类,它会聚合整个图像的全局信息
- 位置编码 (Positional Embedding) :由于 Transformer 是顺序无关的,必须添加位置信息以保留空间结构
- 可学习位置编码:与图像块嵌入维度相同 (D),通过训练学习
- 固定位置编码:如正弦余弦编码(原 Transformer 使用)
- 最终输入序列:
z_0 = [z_0^0; z_0^1; z_0^2; ...; z_0^N] + E_pos,形状为 (N+1)×D
2. Transformer 编码器模块
ViT 使用标准 Transformer 编码器,由L 个相同层堆叠而成,每一层包含两个子层:
|--------------|-------------------|---------------------------|
| 子层 | 功能 | 结构 |
| 多头自注意力 (MSA) | 捕捉所有图像块间的全局依赖关系 | 多个并行的自注意力头,输出拼接后线性变换 |
| 多层感知机 (MLP) | 对每个位置的特征独立进行非线性变换 | 两层全连接 + GELU 激活 + Dropout |
残差连接与层归一化
- 每个子层都有残差连接:
LayerNorm(x + Sublayer(x)) - 这是 Transformer 训练稳定性的关键,允许模型深度堆叠
多头自注意力计算
python
# 线性投影得到Q, K, V
Q = Linear(z_l-1) # (N+1)×D
K = Linear(z_l-1) # (N+1)×D
V = Linear(z_l-1) # (N+1)×D
# 分割为M个头
Q = split_heads(Q, M) # M×(N+1)×(D/M)
K = split_heads(K, M) # M×(N+1)×(D/M)
V = split_heads(V, M) # M×(N+1)×(D/M)
# 自注意力计算
attn_scores = softmax((Q @ K^T) / sqrt(D/M)) # M×(N+1)×(N+1)
attn_output = attn_scores @ V # M×(N+1)×(D/M)
# 拼接并线性变换
attn_output = concat_heads(attn_output) # (N+1)×D
z_l^1 = Linear(attn_output) # (N+1)×D
MLP 计算
python
z_l^2 = MLP(LayerNorm(z_l^1 + z_l-1)) # 残差连接+层归一化
z_l = LayerNorm(z_l^2 + z_l^1) # 第二个残差连接+层归一化
3. 分类头模块
- 取最后一层编码器输出中 **[CLS] token** 的表示:
z_L^0 - 通过 MLP Head(通常是两层全连接)映射到类别空间:
python
y = MLP_Head(z_L^0) # 形状: (num_classes)
三、ViT 变体与尺度
Google 提供了多种尺度的 ViT 模型,参数如下:
|-----------|--------|-----------|-----------|-----------|------|
| 模型 | 层数 (L) | 隐藏层维度 (D) | 注意力头数 (M) | 图像块大小 (P) | 参数数量 |
| ViT-Base | 12 | 768 | 12 | 16 | 86M |
| ViT-Large | 24 | 1024 | 16 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 16 | 16 | 632M |
训练策略
- 大规模预训练:在 JFT-300M 等超大数据集上预训练,再微调至下游任务
- 混合精度训练:使用 FP16 加速训练并减少内存占用
- 数据增强:大量使用 RandAugment 等增强技术提升泛化能力