s
former (ViT) 详解:从原理到实践
一、引言
Vision Transformer (ViT) 是2020年Google提出的一种将Transformer架构应用于图像分类的模型,彻底改变了计算机视觉领域的格局。与传统的CNN不同,ViT直接使用全局自注意力机制,避免了对局部卷积操作的依赖。
在本文中,我们将深入探讨ViT的核心原理、架构设计,并展示在ImageNet数据集上的实验结果。

二、ViT核心原理
2.1 Patch Embedding
ViT的核心思想是将图像划分为固定大小的patch:
| 步骤 | 说明 |
|---|---|
| 图像尺寸 | 224×224×3 |
| Patch大小 | 16×16 |
| Patches数量 | (224/16)² = 196个 |
| 每个Patch维度 | 16×16×3 = 768维 |
2.2 位置编码
ViT使用可学习的位置编码,保留[CLS] token用于最终分类:
python
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# Conv2d同时实现patch分割和线性投影
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size,
stride=patch_size)
def forward(self, x):
# x: [B, 3, 224, 224] -> [B, 768, 14, 14] -> [B, 196, 768]
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
2.3 Transformer Encoder
每个Transformer Block包含:
- Multi-Head Self-Attention (MHSA)
- MLP Block
- Layer Norm
- 残差连接
三、实验结果
我们在ImageNet-1K数据集上进行了实验,结果如下:

| 指标 | ViT-Base | ViT-Large | DeiT-Small |
|---|---|---|---|
| ImageNet Top-1 | 84.2% | 86.3% | 79.8% |
| ImageNet Top-5 | 97.1% | 97.9% | 95.1% |
| 参数量 | 86M | 307M | 22M |
| FLOPs | 17.6G | 61.6G | 4.6G |

四、代码实践
4.1 完整ViT实现
python
import torch
import torch.nn as nn
from functools import partial
class ViT(nn.Module):
def __init__(self, img_size=224, patch_size=16,
in_chans=3, num_classes=1000,
embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4.,
qkv_bias=True, drop_rate=0.):
super().__init__()
self.num_classes = num_classes
self.embed_dim = embed_dim
# Patch Embedding
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size,
in_chans=in_chans, embed_dim=embed_dim
)
num_patches = self.patch_embed.num_patches
# [CLS] token和位置编码
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim)
)
self.pos_drop = nn.Dropout(p=drop_rate)
# Transformer Encoder
self.blocks = nn.ModuleList([
Block(dim=embed_dim, num_heads=num_heads,
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
# 分类头
self.head = nn.Linear(embed_dim, num_classes)
nn.init.trunc_normal_(self.pos_embed, std=.02)
nn.init.trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
4.2 训练配置
python
# 训练超参数
config = {
'epochs': 300,
'batch_size': 256,
'learning_rate': 1e-3,
'weight_decay': 0.05,
'warmup_epochs': 20,
'label_smoothing': 0.1,
}
五、总结与展望
ViT的优势
✅ 全局注意力机制,捕获长距离依赖
✅ 架构简洁,易于扩展
✅ 预训练-微调范式效果出色
局限性
❌ 需要大量数据训练(从零训练需要JFT-300M)
❌ 计算复杂度随图像尺寸平方增长
❌ 缺乏CNN的归纳偏置(局部性、翻译不变性)
未来方向
- DeiT: 数据高效ViT
- Swin Transformer: 层次化Transformer
- BEiT: ViT的BERT式预训练
参考论文:
- An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
- Training data-efficient image transformers & distillation through attention
💡 创作不易,如果本文对你有帮助,欢迎点赞、评论、转发!