一、视觉Transformer(ViT)技术解读
1.1 背景
视觉Transformer(Vision Transformer,ViT)是2020年由Google Research提出的一种革命性图像识别架构,其核心思想是将自然语言处理中成功的Transformer模型直接应用于计算机视觉任务。与传统卷积神经网络(CNN)不同,ViT摒弃了卷积操作的归纳偏置(如局部性和平移等变性),完全依赖自注意力机制处理图像数据。
论文《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》系统性地证明了在大规模数据预训练下,纯Transformer架构在图像分类任务上可超越最先进的CNN模型。
ViT的意义在于:
- 架构创新:首次实现纯Transformer在图像识别中的有效应用,打破CNN的长期垄断。
- 可扩展性:受益于Transformer的并行计算优势,易于扩展到超大规模参数(如ViT-Huge达6.32亿参数)。
- 跨领域统一:为NLP和CV任务提供了统一的架构框架,促进多模态学习(如CLIP、FLAVA)发展。
- 算力效率提升:同等性能下,ViT预训练所需算力仅为CNN的1/4~1/14,大幅降低大规模模型训练成本。
1.2 ViT方法详解
ViT的核心流程是将图像视为一系列图像块的序列,并通过标准Transformer编码器处理这些序列。
1. 图像分块(Image Patching)
- 输入处理 :给定图像 x ∈ R H × W × C x \in \mathbb{R}^{H \times W \times C} x∈RH×W×C(高度×宽度×通道数),将其分割成大小为 P × P P \times P P×P的非重叠块(patches)。块数量为 N = H W / P 2 N = HW / P^2 N=HW/P2。
- 示例 :若输入图像为224×224×3,块大小 P = 16 P=16 P=16,则生成 N = 224 2 / 16 2 = 196 N = 224^2 / 16^2 = 196 N=2242/162=196个块;若 P = 32 P=32 P=32,则 N = 49 N=49 N=49,序列长度更短但细节损失更多。
- 作用:将2D图像转换为1D序列,模拟NLP中词序列的处理方式,解决Transformer无法直接处理2D结构的问题。
- 块大小的设计考量 :
- 小尺寸块(如16×16):保留更多图像细节,序列长度更长(计算成本高),适合高精度需求;
- 大尺寸块(如32×32):序列长度短(计算高效),但细节丢失多,适合算力有限或低分辨率场景;
- 论文验证:ViT-L/16(16×16块)性能显著优于ViT-L/32,是精度与算力的最优平衡。

2. 块嵌入(Patch Embedding)
- 每个图像块被展平为 R P 2 ⋅ C \mathbb{R}^{P^2 \cdot C} RP2⋅C 的一维向量(如16×16×3的块展平为768维),并通过可训练的线性投影层(嵌入矩阵 E E E)映射到固定维度 D D D(Transformer隐藏层维度):
z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E pos , E ∈ R ( P 2 ⋅ C ) × D , E pos ∈ R ( N + 1 ) × D z_0 = [x_{\text{class}}; x_p^1 E; x_p^2 E; \cdots; x_p^N E] + E_{\text{pos}}, \quad E \in \mathbb{R}^{(P^2 \cdot C) \times D}, \quad E_{\text{pos}} \in \mathbb{R}^{(N+1) \times D} z0=[xclass;xp1E;xp2E;⋯;xpNE]+Epos,E∈R(P2⋅C)×D,Epos∈R(N+1)×D - 公式拆解 :
- x class x_{\text{class}} xclass:额外添加的可学习分类词元(类比BERT的[CLS]),唯一用于最终分类的特征向量,避免对所有块向量平均导致的信息损失;
- x p i E x_p^i E xpiE:第 i i i个图像块的线性嵌入结果;
- E pos E_{\text{pos}} Epos:位置嵌入,为每个块补充空间位置信息;
- 维度匹配:无论块大小如何,最终所有块都被映射到 D D D维(如ViT-Base的 D = 768 D=768 D=768),保证Transformer输入维度统一。
3. 位置编码(Positional Encoding)
- ViT使用可学习的1D位置嵌入,而非NLP中常用的正弦位置编码,为每个块(含分类词元)分配一个唯一的可训练向量。尽管图像是2D结构,但论文实验证明:简单的1D编码已能让模型学习到2D空间关系(如同一行/列的块嵌入相似度更高),更复杂的2D位置编码未带来显著性能提升。
- 核心特性 :
- 位置嵌入会自适应学习图像拓扑:相近的块具有更相似的嵌入向量,甚至出现行列结构、正弦周期模式;
- 高分辨率微调适配:当微调图像分辨率高于预训练时(如预训练224×224→微调512×512),通过2D插值扩展位置嵌入,保证空间关系的连续性。

4. Transformer编码器
-
ViT采用标准Transformer编码器,由 L L L个完全相同的"多头自注意力(MSA)+ MLP"模块堆叠而成,核心公式:
z ℓ ′ = MSA ( LN ( z ℓ − 1 ) ) + z ℓ − 1 , z ℓ = MLP ( LN ( z ℓ ′ ) ) + z ℓ ′ z'\ell = \text{MSA}(\text{LN}(z{\ell-1})) + z_{\ell-1}, \quad z_\ell = \text{MLP}(\text{LN}(z'\ell)) + z'\ell zℓ′=MSA(LN(zℓ−1))+zℓ−1,zℓ=MLP(LN(zℓ′))+zℓ′ -
模块细节补充 :
组件 作用与关键设计 层归一化(LN) 应用于每个子模块(MSA/MLP)输入前,相比CNN常用的批归一化(BN),LN不依赖批量大小,更适合Transformer的并行训练 残差连接 每个子模块的输出与输入相加,解决深度网络的梯度消失问题,ViT最多可堆叠32层(ViT-Huge) 多头自注意力(MSA) 拆分隐藏层维度为多个"注意力头"(如ViT-Base为12头),每个头独立计算注意力权重,捕捉不同类型的空间依赖(部分头关注局部细节,部分头关注全局结构) MLP块 由两层全连接层组成,中间层维度为 4 D 4D 4D(如ViT-Base的 4 × 768 = 3072 4×768=3072 4×768=3072),使用GELU激活(比ReLU更平滑,降低训练震荡),输出层还原为 D D D维 -
计算效率对比 :ViT的自注意力复杂度为 O ( N 2 D ) O(N^2D) O(N2D),其中 N N N是块数量(通常<1000),远低于直接处理像素的 O ( H 2 W 2 C ) O(H^2W^2C) O(H2W2C),因此比像素级注意力更高效。
5. 分类头与输出
- 编码器输出序列中,仅提取分类词元对应的向量 z L 0 z_L^0 zL0(第 L L L层编码器输出的第一个位置向量),经过层归一化后作为最终图像特征:
y = LN ( z L 0 ) y = \text{LN}(z_L^0) y=LN(zL0) - 分类头的设计随阶段调整:
- 预训练阶段:使用"单隐藏层MLP"作为分类头,适配大规模数据集(如JFT-300M的18k类别)的复杂分布;
- 微调阶段:替换为"单层线性层",零初始化权重,减少参数冗余,提升小数据集泛化能力。
6. 混合架构(Hybrid Architecture)
- 作为ViT的变体,混合架构解决"小数据集下ViT性能不足"的问题:将CNN(如ResNet)的中间特征图作为输入,替代原始图像块。
- 具体流程 :
- 用ResNet提取图像特征(如ResNet50的第3/4阶段特征图);
- 将特征图分割为1×1的"伪块"(避免再次分块损失信息);
- 对伪块执行线性嵌入、添加位置嵌入后输入Transformer;
- 优势:结合CNN的局部特征提取能力(归纳偏置)和Transformer的全局建模能力,在ImageNet-1k(130万张图)等中小数据集上性能接近CNN,无需依赖超大规模预训练。
7. 微调与高分辨率适配
- 核心微调策略 :
- 冻结预训练的ViT骨干网络,仅训练任务特定的线性分类头(快速适配);
- 解冻部分编码器层(如最后6层),联合微调(更高精度);
- 高分辨率适配关键 :
- 保持块大小 P P P不变,高分辨率图像会生成更多块(如224×224→512×512,块数量从196→1024);
- 对预训练的位置嵌入执行双线性插值,扩展为新长度的位置嵌入,保证空间位置的连续性;
- 论文验证:ViT-L/16在512×512分辨率下微调,ImageNet准确率比224×224提升1.5%以上。
1.3 代码示例
结合代码与可视化,ViT的核心流程可拆解为以下步骤:
-
图片序列化(Patch嵌入)
- 操作:将224×224×3的输入图切分为16×16的Patch,通过线性层将每个Patch(16×16×3=768像素)投影为768维特征向量
- 对应代码:
model.patch_embed - 形状变化:输入
[1,3,224,224]→ 输出[1,196,768](196为Patch数量:14×14)
-
补全序列信息
- 操作:在Patch序列前加
[1,768]的Class Token(用于分类),并添加Position Embedding注入空间信息 - 形状变化:序列长度从196→197,对应
Block_1输入形状[1,197,768]
- 操作:在Patch序列前加
-
Transformer Block特征抽象
- 操作:通过12个堆叠的Block(每块含"多头自注意力+多层感知机"),捕捉全局依赖并增强特征表达
- 形状特点:所有Block的输入/输出形状均保持
[1,197,768],仅特征内容逐步抽象
-
分类头输出结果
- 操作:提取Class Token的特征,经
model.norm(层归一化)+model.head(线性层)映射到1000类 - 形状变化:输出
[1,1000](对应ImageNet分类结果)
- 操作:提取Class Token的特征,经
-
特征可视化变化
- PatchEmbed:保留猫的轮廓,是初始Patch级特征
- Block_1→Block_12:从"像素细节"逐步抽象为"猫的核心语义信息"
- Class Token:作为分类载体,持续整合所有Patch的语义信息
- 分类结果:Top10预测成功识别出"猫"相关类别

python
import os
import matplotlib.pyplot as plt
import numpy as np
import requests
import timm
import torch
from PIL import Image
# 基础配置
model_name = 'vit_base_patch16_224'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
TEMP_IMAGE_PATH = "./temp_coco_image.jpg" # 下载到当前路径
vis_channels = 8
np.random.seed(42)
torch.manual_seed(42)
# ImageNet简化标签(兼容所有timm版本)
IMAGENET_CLASSES_SIMPLE = {
281: 'tabby, tabby cat',
282: 'tiger cat',
283: 'Persian cat',
284: 'Siamese cat, Siamese',
285: 'Egyptian cat',
151: 'Chihuahua',
152: 'Japanese spaniel',
153: 'Maltese dog, Maltese terrier, Maltese',
**{i: f'Class_{i}' for i in range(1000)}
}
# 下载图片到当前路径
def download_image_from_url(url, save_path):
try:
print(f"正在从URL下载图片: {url}")
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
}
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
with open(save_path, 'wb') as f:
f.write(response.content)
print(f"图片下载成功,保存到: {save_path}")
return True
except Exception as e:
print(f"图片下载失败!错误信息: {e}")
return False
# 反归一化张量为可显示图片
def denormalize_tensor(tensor, mean, std):
tensor = tensor.clone()
for t, m, s in zip(tensor, mean, std):
t.mul_(s).add_(m)
tensor = torch.clamp(tensor, 0, 1)
return tensor.numpy().transpose(1, 2, 0)
# 获取分类头Top-K预测结果
def get_topk_predictions(logits, k=10):
probs = torch.softmax(logits, dim=1)
topk_probs, topk_indices = torch.topk(probs, k=k)
topk_probs = topk_probs.cpu().numpy()[0]
topk_indices = topk_indices.cpu().numpy()[0]
topk_names = [IMAGENET_CLASSES_SIMPLE.get(idx, f'Class_{idx}') for idx in topk_indices]
return topk_names, topk_probs
# 加载模型并注册钩子提取特征
model = timm.create_model(model_name, pretrained=True).to(device).eval()
core_layer_info = []
def core_hook_fn(module, input, output):
if isinstance(output, torch.Tensor):
feature = output.detach().cpu().numpy()
core_layer_info.append({
'module': module.__class__.__name__,
'input_shape': input[0].shape if (input and len(input) > 0) else None,
'output_shape': output.shape,
'feature': feature
})
# 注册各层钩子
model.patch_embed.register_forward_hook(core_hook_fn)
for block in model.blocks:
block.register_forward_hook(core_hook_fn)
model.norm.register_forward_hook(core_hook_fn)
model.head.register_forward_hook(core_hook_fn)
# 下载并预处理图片
download_success = download_image_from_url(IMAGE_URL, TEMP_IMAGE_PATH)
original_img = None
denorm_img = None
if download_success and os.path.exists(TEMP_IMAGE_PATH):
data_config = timm.data.resolve_data_config(model.pretrained_cfg)
transform = timm.data.create_transform(**data_config)
original_img = Image.open(TEMP_IMAGE_PATH).convert('RGB')
print(f"\n下载的图片原始尺寸: {original_img.size} (宽×高)")
x = transform(original_img).unsqueeze(0).to(device)
print(f"预处理后输入张量形状: {x.shape}")
mean = data_config['mean']
std = data_config['std']
denorm_img = denormalize_tensor(x[0].cpu(), mean, std)
else:
print("\n图片下载失败,自动使用模拟图片替代...")
x = torch.randn(1, 3, 224, 224).to(device)
original_img = Image.fromarray(np.uint8(np.random.rand(224, 224, 3) * 255))
denorm_img = np.random.rand(224, 224, 3)
# 前向传播提取特征
with torch.no_grad():
output = model(x)
# 重命名Block层
for i in range(12):
core_layer_info[1 + i]['module'] = f'Block_{i + 1}'
# 打印核心层信息
print("\n" + "=" * 70)
print("ViT-Base 核心层级信息(COCO图片输入):")
print("=" * 70)
print(f"输入图片最终形状: {x.shape}")
print("=" * 70)
for idx, layer in enumerate(core_layer_info):
print(f"[{idx + 1}] 核心层: {layer['module']}")
print(f" 输入形状: {layer['input_shape']}")
print(f" 输出形状: {layer['output_shape']}")
print("-" * 50)
print(f"最终分类头输出形状: {output.shape}")
print(f"预测类别索引: {torch.argmax(output, dim=1).item()}")
# 完整特征可视化(一张大图)
def visualize_all_vit_features(layer_info, original_img, denorm_img, output_logits):
plt.rcParams['font.sans-serif'] = ['Arial']
fig, axes = plt.subplots(5, 3, figsize=(20, 25))
fig.suptitle('ViT-Base Full Feature Visualization (COCO Image)', fontsize=20, fontweight='bold', y=0.98)
# 第1行:原始图 + 预处理图 + PatchEmbed
ax1 = axes[0, 0]
ax1.imshow(original_img)
ax1.set_title(f'Original Image\nSize: {original_img.size[0]}×{original_img.size[1]}', fontsize=12)
ax1.axis('off')
ax2 = axes[0, 1]
ax2.imshow(denorm_img)
ax2.set_title(f'Preprocessed Image\nSize: {denorm_img.shape[0]}×{denorm_img.shape[1]}', fontsize=12)
ax2.axis('off')
ax3 = axes[0, 2]
pe_feature = layer_info[0]['feature'][0]
pe_patch_feature = pe_feature if pe_feature.shape[0] != 197 else pe_feature[1:, :]
pe_patch_2d = pe_patch_feature.T.reshape(-1, 14, 14)
pe_vis_feature = pe_patch_2d[:vis_channels]
pe_vis_feature = (pe_vis_feature - pe_vis_feature.min()) / (pe_vis_feature.max() - pe_vis_feature.min())
pe_grid = np.concatenate([np.concatenate(pe_vis_feature[:4], axis=1), np.concatenate(pe_vis_feature[4:], axis=1)],
axis=0)
im_pe = ax3.imshow(pe_grid, cmap='viridis')
ax3.set_title(f'PatchEmbed\nTop {vis_channels} Channels (14×14)', fontsize=10)
ax3.axis('off')
plt.colorbar(im_pe, ax=ax3, shrink=0.5)
# 第2行:Block_1 + Block_3 + Block_6
block_layers_1 = [('Block_1', layer_info[1]), ('Block_3', layer_info[3]), ('Block_6', layer_info[6])]
for idx, (layer_name, info) in enumerate(block_layers_1):
ax = axes[1, idx]
feature = info['feature'][0]
patch_feature = feature[1:, :] if feature.shape[0] == 197 else feature
patch_2d = patch_feature.T.reshape(-1, 14, 14)
vis_feature = patch_2d[:vis_channels]
vis_feature = (vis_feature - vis_feature.min()) / (vis_feature.max() - vis_feature.min())
grid = np.concatenate([np.concatenate(vis_feature[:4], axis=1), np.concatenate(vis_feature[4:], axis=1)],
axis=0)
im = ax.imshow(grid, cmap='viridis')
ax.set_title(f'{layer_name}\nTop {vis_channels} Channels (14×14)', fontsize=10)
ax.axis('off')
plt.colorbar(im, ax=ax, shrink=0.5)
# 第3行:Block_9 + Block_12 + LayerNorm
block_layers_2 = [('Block_9', layer_info[9]), ('Block_12', layer_info[12]), ('LayerNorm', layer_info[13])]
for idx, (layer_name, info) in enumerate(block_layers_2):
ax = axes[2, idx]
feature = info['feature'][0]
patch_feature = feature[1:, :] if feature.shape[0] == 197 else feature
patch_2d = patch_feature.T.reshape(-1, 14, 14)
vis_feature = patch_2d[:vis_channels]
vis_feature = (vis_feature - vis_feature.min()) / (vis_feature.max() - vis_feature.min())
grid = np.concatenate([np.concatenate(vis_feature[:4], axis=1), np.concatenate(vis_feature[4:], axis=1)],
axis=0)
im = ax.imshow(grid, cmap='viridis')
ax.set_title(f'{layer_name}\nTop {vis_channels} Channels (14×14)', fontsize=10)
ax.axis('off')
plt.colorbar(im, ax=ax, shrink=0.5)
# 第4行:Class Token + 空白占位
ax_ct = axes[3, 0]
class_token = layer_info[13]['feature'][0, 0, :]
top_dim = 20
dim_indices = np.arange(top_dim)
dim_values = class_token[:top_dim]
ax_ct.bar(dim_indices, dim_values, color='darkred', alpha=0.7)
ax_ct.set_title(f'Class Token\nFirst 20 Dimensions', fontsize=10)
ax_ct.set_xlabel('Dimension Index', fontsize=8)
ax_ct.set_ylabel('Feature Value', fontsize=8)
ax_ct.grid(alpha=0.3, axis='y')
ax_ct.set_xticks(dim_indices[::2])
axes[3, 1].axis('off')
axes[3, 2].axis('off')
# 第5行:分类头Top10预测 + 空白占位
ax_head = axes[4, 0]
topk_names, topk_probs = get_topk_predictions(output_logits, k=10)
topk_names = topk_names[::-1]
topk_probs = topk_probs[::-1]
ax_head.barh(np.arange(len(topk_names)), topk_probs, color='darkblue', alpha=0.7)
ax_head.set_yticks(np.arange(len(topk_names)))
ax_head.set_yticklabels(topk_names, fontsize=7)
ax_head.set_xlabel('Prediction Probability', fontsize=8)
ax_head.set_title('Head Layer\nTop 10 Predictions', fontsize=10)
ax_head.grid(alpha=0.3, axis='x')
axes[4, 1].axis('off')
axes[4, 2].axis('off')
# 保存到当前路径
save_vis_path = "./vit_full_feature_visualization.png"
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(save_vis_path, dpi=150, bbox_inches='tight')
print(f"\n完整特征可视化大图已保存到: {save_vis_path}")
plt.show()
# 执行可视化
print("\n开始可视化所有关键层特征(一张大图)...")
visualize_all_vit_features(core_layer_info, original_img, denorm_img, output)
# 删除临时下载的图片
if os.path.exists(TEMP_IMAGE_PATH):
os.remove(TEMP_IMAGE_PATH)
print(f"\n临时图片已删除: {TEMP_IMAGE_PATH}")
python
下载的图片原始尺寸: (640, 480) (宽×高)
预处理后输入张量形状: torch.Size([1, 3, 224, 224])
======================================================================
ViT-Base 核心层级信息(COCO图片输入):
======================================================================
输入图片最终形状: torch.Size([1, 3, 224, 224])
======================================================================
[1] 核心层: PatchEmbed
输入形状: torch.Size([1, 3, 224, 224])
输出形状: torch.Size([1, 196, 768])
--------------------------------------------------
[2] 核心层: Block_1
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[3] 核心层: Block_2
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[4] 核心层: Block_3
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[5] 核心层: Block_4
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[6] 核心层: Block_5
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[7] 核心层: Block_6
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[8] 核心层: Block_7
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[9] 核心层: Block_8
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[10] 核心层: Block_9
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[11] 核心层: Block_10
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[12] 核心层: Block_11
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[13] 核心层: Block_12
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[14] 核心层: LayerNorm
输入形状: torch.Size([1, 197, 768])
输出形状: torch.Size([1, 197, 768])
--------------------------------------------------
[15] 核心层: Linear
输入形状: torch.Size([1, 768])
输出形状: torch.Size([1, 1000])
--------------------------------------------------
最终分类头输出形状: torch.Size([1, 1000])
预测类别索引: 761
二、论文翻译:一幅图像值16x16个词:大规模图像识别的Transformer
AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
摘要
虽然Transformer架构已成为自然语言处理任务的事实标准,但其在计算机视觉中的应用仍然有限。在视觉领域,注意力机制要么与卷积网络结合使用,要么用于替换卷积网络的某些组件,同时保持其整体结构不变。我们证明这种对CNN的依赖并非必要,直接应用于图像块序列的纯Transformer可以在图像分类任务上表现优异。当在大量数据上进行预训练并迁移到多个中型或小型图像识别基准(ImageNet, CIFAR-100, VTAB等)时,视觉Transformer(ViT)与最先进的卷积网络相比获得了优异的结果,同时训练所需的计算资源显著减少。 1 {}^{1} 1
1 引言
基于自注意力的架构,特别是Transformer(Vaswani et al., 2017),已成为自然语言处理(NLP)的首选模型。主流方法是在大型文本语料库上进行预训练,然后在较小的特定任务数据集上进行微调(Devlin et al., 2019)。得益于Transformer的计算效率和可扩展性,训练具有超过1000亿参数的模型已成为可能(Brown et al., 2020; Lepikhin et al., 2020)。随着模型和数据集的增长,性能仍未出现饱和迹象。
然而,在计算机视觉中,卷积架构仍然占主导地位(LeCun et al., 1989; Krizhevsky et al., 2012; He et al., 2016)。受NLP成功的启发,许多工作尝试将类似CNN的架构与自注意力结合(Wang et al., 2018; Carion et al., 2020),一些工作完全替换了卷积(Ramachandran et al., 2019; Wang et al., 2020a)。后一类模型虽然在理论上高效,但由于使用了专门的注意力模式,尚未在现代硬件加速器上有效扩展。因此,在大规模图像识别中,经典的类ResNet架构仍然是最先进的(Mahajan et al., 2018; Xie et al., 2020; Kolesnikov et al., 2020)。
受Transformer在NLP中缩放成功的启发,我们尝试将标准Transformer直接应用于图像,并尽可能减少修改。为此,我们将图像分割成块,并将这些块的线性嵌入序列作为Transformer的输入。图像块的处理方式与NLP应用中的词元(单词)相同。我们以监督方式在图像分类上训练模型。
当在没有强正则化的中型数据集(如ImageNet)上训练时,这些模型的准确率比同等大小的ResNet低几个百分点。这一看似令人沮丧的结果可能在意料之中:Transformer缺乏CNN固有的一些归纳偏置,例如平移等变性和局部性,因此在数据量不足时泛化能力不佳。
然而,如果在更大数据集(14M-300M图像)上训练模型,情况会发生变化。我们发现大规模训练战胜了归纳偏置。当在足够规模上预训练并迁移到数据点较少的任务时,我们的视觉Transformer(ViT)取得了优异的结果。当在公共ImageNet-21k数据集或内部JFT-300M数据集上预训练时,ViT在多个图像识别基准上接近或击败了最先进技术。特别是,最佳模型在ImageNet上达到88.55%的准确率,在ImageNet-ReaL上达到90.72%,在CIFAR-100上达到94.55%,在包含19个任务的VTAB套件上达到77.63%。
2 相关工作
Transformer由Vaswani等人(2017)提出用于机器翻译,此后已成为许多NLP任务的最先进方法。基于Transformer的大型模型通常在大规模语料库上预训练,然后为手头任务进行微调:BERT(Devlin et al., 2019)使用去噪自监督预训练任务,而GPT系列工作使用语言建模作为预训练任务(Radford et al., 2018; 2019; Brown et al., 2020)。
将自注意力直接应用于图像需要每个像素关注所有其他像素。由于像素数量的二次方成本,这无法扩展到实际的输入大小。因此,为了在图像处理背景下应用Transformer,过去已经尝试了几种近似方法。Parmar等人(2018)仅在每个查询像素的局部邻域内应用自注意力,而不是全局应用。这种局部多头点积自注意力块可以完全替代卷积(Hu et al., 2019; Ramachandran et al., 2019; Zhao et al., 2020)。在另一项工作中,稀疏Transformer(Child et al., 201)采用可扩展的全局自注意力近似以适用于图像。另一种缩放注意力的方法是将其应用于不同大小的块中(Weissenborn et al., 2019),在极端情况下仅沿单个轴应用(Ho et al., 2019; Wang et al., 2020a)。许多这些专门的注意力架构在计算机视觉任务上展示了有希望的结果,但需要复杂的工程才能在硬件加速器上高效实现。
与我们的工作最相关的是Cordonnier等人(2020)的模型,该模型从输入图像中提取2x2大小的块并在其上应用完整的自注意力。该模型与ViT非常相似,但我们的工作进一步证明大规模预训练使普通Transformer能够与最先进的CNN竞争(甚至更好)。此外,Cordonnier等人(2020)使用2x2像素的小块大小,这使得模型仅适用于小分辨率图像,而我们也处理中等分辨率图像。
结合卷积神经网络(CNN)与各种形式的自注意力也引起了广泛兴趣,例如通过增强图像分类的特征图(Bello et al., 2019)或通过自注意力进一步处理CNN的输出,例如用于目标检测(Hu et al., 2018; Carion et al., 2020)、视频处理(Wang et al., 2018; Sun et al., 2019)、图像分类(Wu et al., 2020)、无监督目标发现(Locatello et al., 2020)或统一文本-视觉任务(Chen et al., 2020c; Lu et al., 2019; Li et al., 2019)。
另一个最近的相关模型是图像GPT(iGPT)(Chen et al., 2020a),该模型在降低图像分辨率和颜色空间后将Transformer应用于图像像素。该模型以无监督方式作为生成模型进行训练,生成的表示可以随后进行微调或线性探测以评估分类性能,在ImageNet上达到最大72%的准确率。
我们的工作增加了越来越多探索比标准ImageNet数据集更大规模图像识别的论文集合。使用额外数据源可以在标准基准上实现最先进的结果(Mahajan et al., 2018; Touvron et al., 2019; Xie et al., 2020)。此外,Sun等人(2017)研究了CNN性能如何随数据集大小缩放,Kolesnikov等人(2020);Djolonga等人(2020)对从大规模数据集(如ImageNet-21k和JFT-300M)进行CNN迁移学习进行了实证探索。我们也关注后两个数据集,但训练的是Transformer而不是先前工作中使用的基于ResNet的模型。
3 方法
在模型设计中,我们尽可能遵循原始Transformer(Vaswani et al., 2017)。这种有意简单设置的一个优点是,可扩展的NLP Transformer架构及其高效实现几乎可以即插即用。

3.1 视觉Transformer(VIT)
模型概述如图1所示。标准Transformer接收1D词元嵌入序列作为输入。为了处理2D图像,我们将图像 x ∈ R H × W × C x\in R^{H\times W\times C} x∈RH×W×C 重塑为展平的2D块序列 x p ∈ R N × ( P 2 ⋅ C ) x_{p}\in R^{N\times(P^{2}\cdot C)} xp∈RN×(P2⋅C) ,其中 ( H , W ) (H,W) (H,W) 是原始图像的分辨率, C C C是通道数, ( P , P ) (P,P) (P,P)是每个图像块的分辨率, N = H W / P 2 N=HW/P^{2} N=HW/P2 是生成的块数,也作为Transformer的有效输入序列长度。Transformer在所有层中使用恒定的潜在向量大小 D D D,因此我们将块展平并使用可训练的线性投影映射到 D D D维(公式1)。我们将此投影的输出称为块嵌入。
与BERT的[class]词元类似,我们在嵌入块序列前添加一个可学习的嵌入( z 0 0 = x class z_0^0 = x_{\text{class}} z00=xclass ),其在Transformer编码器输出处的状态( z L 0 z_L^0 zL0 )作为图像表示y(公式4)。在预训练和微调期间,分类头都附加到 z L 0 z_L^0 zL0 上。分类头在预训练时由具有一个隐藏层的MLP实现,在微调时由单个线性层实现。
位置嵌入被添加到块嵌入中以保留位置信息。我们使用标准的可学习1D位置嵌入,因为我们没有观察到使用更先进的2D感知位置嵌入带来显著的性能提升(附录D.4)。生成的嵌入向量序列作为编码器的输入。
Transformer编码器(Vaswani et al., 2017)由多头自注意力(MSA,见附录A)和MLP块交替层组成(公式2,3)。层归一化(LN)应用于每个块之前,残差连接应用于每个块之后(Wang et al., 2019; Baevski& Auli, 2019)。MLP包含两层,使用GELU非线性激活函数。
z 0 = [ x c l a s s ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E p o s , E ∈ R ( P 2 ⋅ C ) × D , E p o s ∈ R ( N + 1 ) × D ( 1 ) z_{0}=[x_{class};\,x_{p}^{1}E;\,x_{p}^{2}E;\cdots;\,x_{p}^{N}E]+E_{pos},\qquad E\in R^{(P^{2}\cdot C)\times D},\,E_{pos}\in R^{(N+1)\times D}\qquad(1) z0=[xclass;xp1E;xp2E;⋯;xpNE]+Epos,E∈R(P2⋅C)×D,Epos∈R(N+1)×D(1)
z ′ ℓ = M S A ( L N ( z ℓ − 1 ) ) + z ℓ − 1 , ℓ = 1 ... L ( 2 ) z^{\prime}{}{\ell}=MSA(LN(z{\ell-1}))+z_{\ell-1},\qquad\ell=1\ldots L\qquad(2) z′ℓ=MSA(LN(zℓ−1))+zℓ−1,ℓ=1...L(2)
z ℓ = M L P ( L N ( z ′ ℓ ) ) + z ′ ℓ , ℓ = 1 ... L ( 3 ) z_{\ell}=MLP(LN(z^{\prime}{}{\ell}))+z^{\prime}{}{\ell},\qquad\ell=1\dots L\qquad(3) zℓ=MLP(LN(z′ℓ))+z′ℓ,ℓ=1...L(3)
y = L N ( z L 0 ) ( 4 ) y=LN(z_{L}^{0})\qquad(4) y=LN(zL0)(4)
归纳偏置。我们注意到视觉Transformer比CNN具有更少的图像特定归纳偏置。在CNN中,局部性、二维邻域结构和平移等变性被烘焙到整个模型的每一层中。在ViT中,只有MLP层是局部和平移等变的,而自注意力层是全局的。二维邻域结构的使用非常节俭:在模型开始时通过将图像切割成块,以及在微调时为了调整不同分辨率图像的位置嵌入(如下所述)。除此之外,初始化时的位置嵌入不携带关于块的2D位置的信息,所有块之间的空间关系都必须从头开始学习。
混合架构。作为原始图像块的替代方案,输入序列可以由CNN的特征图形成(LeCun et al., 1989)。在这种混合模型中,块嵌入投影E(公式1)应用于从CNN特征图提取的块。作为一种特殊情况,块可以具有1x1的空间大小,这意味着输入序列是通过简单展平特征图的空间维度并投影到Transformer维度而获得的。分类输入嵌入和位置嵌入按上述方式添加。
3.2 微调和更高分辨率
通常,我们在大型数据集上预训练ViT,并微调到(较小的)下游任务。为此,我们移除预训练的预测头,并附加一个零初始化的D x K前馈层,其中K是下游类别数。以比预训练更高的分辨率进行微调通常是有益的(Touvron et al., 2019; Kolesnikov et al., 2020)。当输入更高分辨率的图像时,我们保持块大小不变,这会导致更大的有效序列长度。视觉Transformer可以处理任意序列长度(受内存限制),然而,预训练的位置嵌入可能不再有意义。因此,我们根据预训练位置嵌入在原始图像中的位置进行2D插值。请注意,这种分辨率调整和块提取是手动将关于图像2D结构的归纳偏置注入视觉Transformer的唯一时刻。
4 实验
我们评估了ResNet、视觉Transformer(ViT)和混合模型的表示学习能力。为了理解每个模型的数据需求,我们在不同大小的数据集上预训练并评估许多基准任务。当考虑预训练模型的计算成本时,ViT表现非常有利,以较低的预训练成本在大多数识别基准上达到最先进水平。最后,我们使用自监督进行了一个小实验,并表明自监督ViT对未来充满希望。
4.1 设置
数据集 。为了探索模型的可扩展性,我们使用具有1k类别和130万张图像的ILSVRC-2012 ImageNet数据集(下文称为ImageNet),其超集ImageNet-21k具有21k类别和1400万张图像(Deng et al., 2009),以及JFT(Sun et al., 2017)具有18k类别和3.03亿张高分辨率图像。我们按照Kolesnikov等人(2020)的方法对预训练数据集进行去重,去除下游任务测试集中的重复数据。我们将这些数据集上训练的模型迁移到几个基准任务:原始验证标签和清理后的ReaL标签(Beyer et al., 2020)的ImageNet、CIFAR-10/100(Krizhevsky, 2009)、Oxford-IIIT Pets(Parkhi et al., 2012)和Oxford Flowers-102(Nilsback& Zisserman, 2008)。对于这些数据集,预处理遵循Kolesnikov等人(2020)。

模型变体。我们基于BERT(Devlin et al., 2019)使用的配置制定ViT配置,如表1所示。"Base"和"Large"模型直接采用自BERT,我们添加了更大的"Huge"模型。在下文中,我们使用简写表示模型大小和输入块大小:例如,ViT-L/16表示具有16x16输入块大小的"Large"变体。请注意,Transformer的序列长度与块大小的平方成反比,因此具有较小块大小的模型在计算上更昂贵。
对于基线CNN,我们使用ResNet(He et al., 2016),但将批归一化层(Ioffe& Szegedy, 2015)替换为组归一化(Wu& He, 2018),并使用标准化卷积(Qiao et al., 2019)。这些修改改善了迁移(Kolesnikov et al., 2020),我们将修改后的模型称为"ResNet(BiT)"。对于混合模型,我们将中间特征图输入到ViT,块大小为1个"像素"。为了实验不同的序列长度,我们要么(i)取常规ResNet50第4阶段的输出,要么(ii)移除第4阶段,在第3阶段放置相同数量的层(保持总层数),并取此扩展的第3阶段的输出。选项(ii)导致序列长度增加4倍,以及更昂贵的ViT模型。
训练和微调 。我们使用Adam(Kingma& Ba, 2015)训练所有模型,包括ResNet,其中 β 1 = 0.9 , β 2 = 0.999 \beta_{1}=0.9,\beta_{2}=0.999 β1=0.9,β2=0.999 ,批大小为4096,并应用0.1的高权重衰减,我们发现这对所有模型的迁移都有用(附录D.1显示,与常见做法相反,在我们的设置中,Adam略优于SGD用于ResNet)。我们使用线性学习率预热和衰减,详见附录B.1。对于微调,我们对所有模型使用带动量的SGD,批大小为512,见附录B.1.1。对于表2中的ImageNet结果,我们以更高分辨率进行微调:ViT-L/16为512,ViT-H/14为518,并且还使用了Polyak& Juditsky(1992)平均,因子为0.9999(Ramachandran et al., 2019; Wang et al., 2020b)。
指标。我们通过少样本或微调准确率报告下游数据集的结果。微调准确率捕获每个模型在相应数据集上微调后的性能。少样本准确率是通过求解正则化最小二乘回归问题获得的,该问题将(冻结的)训练图像子集的表示映射到{-1,1}目标向量。这种表述允许我们以封闭形式恢复精确解。尽管我们主要关注微调性能,但有时在微调成本过高的情况下,使用线性少样本准确率进行快速即时评估。
4.2 与最先进技术的比较
我们首先将我们最大的模型------ViT-H/14和ViT-L/16------与文献中最先进的CNN进行比较。第一个比较点是Big Transfer(BiT)(Kolesnikov et al., 2020),它使用大型ResNet进行监督迁移学习。第二个是Noisy Student(Xie et al., 2020),这是一个大型EfficientNet,在ImageNet和JFT-300M上使用半监督学习进行训练(标签已移除)。目前,Noisy Student是ImageNet上的最先进技术,BiT-L是此处报告的其他数据集上的最先进技术。所有模型都在TPUv3硬件上训练,我们报告预训练每个模型所花费的TPUv3核心天数,即用于训练的TPU v3核心数(每个芯片2个)乘以训练天数。

表2显示了结果。在JFT-300M上预训练的较小ViT-L/16模型在所有任务上都优于BiT-L(在同一数据集上预训练),同时训练所需的计算资源显著减少。更大的模型ViT-H/14进一步提高了性能,尤其是在更具挑战性的数据集上------ImageNet、CIFAR-100和VTAB套件。有趣的是,该模型的预训练计算量仍然远低于先前的最先进技术。然而,我们注意到预训练效率不仅受架构选择的影响,还受其他参数的影响,例如训练计划、优化器、权重衰减等。我们在第4.4节中对不同架构的性能与计算量进行了对照研究。最后,在公共ImageNet-21k数据集上预训练的ViT-L/16模型在大多数数据集上也表现良好,同时预训练所需的资源更少:它可以使用标准的8核云TPUv3在大约30天内训练完成。

图2将VTAB任务分解为各自的组,并与该基准上的先前SOTA方法进行比较:BiT、VIVI------一个在ImageNet和YouTube上共同训练的ResNet(Tschannen et al., 2020),以及S4L------在ImageNet上的监督加半监督学习(Zhai et al., 2019a)。ViT-H/14在自然和结构化任务上优于BiT-R152x4和其他方法。在专业任务上,前两个模型的性能相似。
4.3 预训练数据需求
视觉Transformer在大型JFT-300M数据集上预训练时表现良好。与ResNet相比,ViT具有更少的视觉归纳偏置,那么数据集大小有多关键?我们进行了两个系列的实验。

首先,我们在不断增加大小的数据集上预训练ViT模型:ImageNet、ImageNet-21k和JFT-300M。为了提升在较小数据集上的性能,我们优化了三个基本正则化参数------权重衰减、dropout和标签平滑。图3显示了微调到ImageNet后的结果(其他数据集的结果显示在表5中) 2 {}^{2} 2。当在最小的数据集ImageNet上预训练时,尽管有(适度的)正则化,ViT-Large模型的表现仍低于ViT-Base模型。使用ImageNet-21k预训练时,它们的性能相似。只有使用JFT-300M,我们才能看到更大模型的全部优势。图3还显示了不同大小的BiT模型所跨越的性能区域。BiT CNN在ImageNet上优于ViT,但在更大的数据集上,ViT反超。
其次,我们在9M、30M和90M的随机子集以及完整的JFT-300M数据集上训练我们的模型。我们没有在较小的子集上执行额外的正则化,并对所有设置使用相同的超参数。通过这种方式,我们评估了内在的模型特性,而不是正则化的效果。然而,我们确实使用了早停,并报告了训练期间达到的最佳验证准确率。为了节省计算,我们报告少样本线性准确率而不是完整微调准确率。图4包含了结果。视觉Transformer在计算成本相当的小型数据集上比ResNet更容易过拟合。例如,ViT-B/32比ResNet50稍快;它在9M子集上表现差得多,但在90M+子集上表现更好。ResNet152x2和ViT-L/16也是如此。这一结果强化了直觉,即卷积归纳偏置对于较小数据集有用,但对于较大数据集,直接从数据中学习相关模式是足够的,甚至是有益的。
总体而言,ImageNet上的少样本结果(图4)以及VTAB上的低数据结果(表2)对于极低数据迁移似乎很有希望。进一步分析ViT的少样本特性是未来工作的一个令人兴奋的方向。
4.4 缩放研究
我们通过评估从JFT-300M迁移的性能,对不同模型进行了受控的缩放研究。在这种设置下,数据大小不会限制模型的性能,我们评估每个模型的性能与预训练成本。模型集包括:7个ResNet,R50x1、R50x2、R101x1、R152x1、R152x2,预训练7个周期,加上R152x2和R200x3预训练14个周期;6个视觉Transformer,ViT-B/32、B/16、L/32、L/16,预训练7个周期,加上L/16和H/14预训练14个周期;以及5个混合模型,R50+ViT-B/32、B/16、L/32、L/16预训练7个周期,加上R50+ViT-L/16预训练14个周期(对于混合模型,模型名称末尾的数字不代表块大小,而是ResNet主干中的总下采样率)。

图5包含了迁移性能与总预训练计算量的关系(计算成本详情见附录D.5)。每个模型的详细结果在附录的表6中提供。可以观察到几种模式。首先,视觉Transformer在性能/计算权衡上主导ResNet。ViT使用大约少2-4倍的计算量达到相同性能(5个数据集的平均值)。其次,混合模型在较小计算预算下略优于ViT,但对于较大模型,差异消失。这一结果有些令人惊讶,因为人们可能期望卷积局部特征处理在任何规模下都能辅助ViT。第三,视觉Transformer在尝试的范围内似乎没有饱和,激励未来的缩放努力。
4.5 检查视觉Transformer
为了开始理解视觉Transformer如何处理图像数据,我们分析了其内部表示。视觉Transformer的第一层将展平的块线性投影到低维空间(公式1)。图7(左)显示了学习到的嵌入滤波器的顶部主成分。这些成分类似于每个块内精细结构的低维表示的合理基函数。

投影后,学习到的位置嵌入被添加到块表示中。图7(中)显示模型学习在位置嵌入的相似性中对图像中的距离进行编码,即较近的块往往具有更相似的位置嵌入。此外,出现了行列结构;同一行/列中的块具有相似的嵌入。最后,对于较大的网格,有时会出现正弦结构(附录D)。位置嵌入学习表示2D图像拓扑解释了为什么手工制作的2D感知嵌入变体没有产生改进(附录D.4)。
自注意力允许ViT即使在最低层也能整合整个图像的信息。我们调查了网络在多大程度上利用这种能力。具体来说,我们基于注意力权重计算图像空间中信息被整合的平均距离(图7,右)。这种"注意力距离"类似于CNN中的感受野大小。
我们发现一些头部在最低层已经关注图像的大部分,表明整合全局信息的能力确实被模型使用。其他注意力头部在低层具有持续较小的注意力距离。这种高度局部化的注意力在应用ResNet before Transformer的混合模型中不太明显(图7,右),表明其可能服务于与CNN中早期卷积层类似的功能。此外,注意力距离随着网络深度增加。全局上,我们发现模型关注与分类语义相关的图像区域(图6)。

4.6 自监督
Transformer在NLP任务上显示出令人印象深刻的性能。然而,其大部分成功不仅源于出色的可扩展性,还源于大规模自监督预训练(Devlin et al., 2019; Radford et al., 2018)。我们还对用于自监督的掩码块预测进行了初步探索,模仿BERT中使用的掩码语言建模任务。通过自监督预训练,我们较小的ViT-B/16模型在ImageNet上达到79.9%的准确率,比从头训练显著提高了2%,但仍比监督预训练低4%。附录B.1.2包含进一步细节。我们将对比预训练的探索(Chen et al., 2020b; He et al., 2020; Bachman et al., 2019; Hénaff et al., 2020)留给未来工作。
5 结论
我们探索了将Transformer直接应用于图像识别。与先前在计算机视觉中使用自注意力的工作不同,除了初始块提取步骤外,我们没有将图像特定的归纳偏置引入架构。相反,我们将图像解释为块序列,并使用NLP中使用的标准Transformer编码器进行处理。这种简单但可扩展的策略在与大规模数据集预训练结合时效果出人意料地好。因此,视觉Transformer在许多图像分类数据集上匹配或超过了最先进技术,同时预训练相对便宜。
虽然这些初步结果令人鼓舞,但许多挑战仍然存在。一是将ViT应用于其他计算机视觉任务,例如检测和分割。我们的结果与Carion等人(2020)的结果结合表明了这种方法的前景。另一个挑战是继续探索自监督预训练方法。我们的初步实验显示自监督预训练有所改进,但自监督与大规模监督预训练之间仍然存在巨大差距。最后,进一步缩放ViT可能会带来改进的性能。