【深度学习】Vision Transformer (ViT) 详解:从原理到实践

s

former (ViT) 详解:从原理到实践

一、引言

Vision Transformer (ViT) 是2020年Google提出的一种将Transformer架构应用于图像分类的模型,彻底改变了计算机视觉领域的格局。与传统的CNN不同,ViT直接使用全局自注意力机制,避免了对局部卷积操作的依赖。

在本文中,我们将深入探讨ViT的核心原理、架构设计,并展示在ImageNet数据集上的实验结果。


二、ViT核心原理

2.1 Patch Embedding

ViT的核心思想是将图像划分为固定大小的patch:

步骤 说明
图像尺寸 224×224×3
Patch大小 16×16
Patches数量 (224/16)² = 196个
每个Patch维度 16×16×3 = 768维

2.2 位置编码

ViT使用可学习的位置编码,保留[CLS] token用于最终分类:

python 复制代码
class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Conv2d同时实现patch分割和线性投影
        self.proj = nn.Conv2d(in_chans, embed_dim, 
                              kernel_size=patch_size, 
                              stride=patch_size)
    
    def forward(self, x):
        # x: [B, 3, 224, 224] -> [B, 768, 14, 14] -> [B, 196, 768]
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

2.3 Transformer Encoder

每个Transformer Block包含:

  • Multi-Head Self-Attention (MHSA)
  • MLP Block
  • Layer Norm
  • 残差连接

三、实验结果

我们在ImageNet-1K数据集上进行了实验,结果如下:

指标 ViT-Base ViT-Large DeiT-Small
ImageNet Top-1 84.2% 86.3% 79.8%
ImageNet Top-5 97.1% 97.9% 95.1%
参数量 86M 307M 22M
FLOPs 17.6G 61.6G 4.6G

四、代码实践

4.1 完整ViT实现

python 复制代码
import torch
import torch.nn as nn
from functools import partial

class ViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, 
                 in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, 
                 num_heads=12, mlp_ratio=4.,
                 qkv_bias=True, drop_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        
        # Patch Embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size,
            in_chans=in_chans, embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches
        
        # [CLS] token和位置编码
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        
        # Transformer Encoder
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads,
                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        
        # 分类头
        self.head = nn.Linear(embed_dim, num_classes)
        
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

4.2 训练配置

python 复制代码
# 训练超参数
config = {
    'epochs': 300,
    'batch_size': 256,
    'learning_rate': 1e-3,
    'weight_decay': 0.05,
    'warmup_epochs': 20,
    'label_smoothing': 0.1,
}

五、总结与展望

ViT的优势

✅ 全局注意力机制,捕获长距离依赖

✅ 架构简洁,易于扩展

✅ 预训练-微调范式效果出色

局限性

❌ 需要大量数据训练(从零训练需要JFT-300M)

❌ 计算复杂度随图像尺寸平方增长

❌ 缺乏CNN的归纳偏置(局部性、翻译不变性)

未来方向

  • DeiT: 数据高效ViT
  • Swin Transformer: 层次化Transformer
  • BEiT: ViT的BERT式预训练

参考论文


💡 创作不易,如果本文对你有帮助,欢迎点赞、评论、转发!

相关推荐
fuquxiaoguang1 小时前
0.8W跑10B模型:端侧AI的“寒武纪爆发“与中间件的轻量进化
人工智能·中间件·端侧ai
XMAIPC_Robot1 小时前
基于RK3588 高算力,小尺寸,轻重量6T算力无人机AI模块,可接两路同步相机模组
运维·人工智能·深度学习·fpga开发·无人机·边缘计算
SuperHeroWu71 小时前
【AI大模型】Self-Attention:为什么它能取代 RNN 解决长距离依赖?
人工智能·rnn·深度学习·循环神经网络·自注意力机制·self-attention
数信云 DCloud1 小时前
人工智能安全观察:漫谈与AI新物种相处之道
人工智能·安全·ai·智能体
朝新_1 小时前
【LangChain】少样本提示(few-shorting) 掌握 Few-Shot 提示,让大模型按你的规则输出
java·人工智能·langchain
AI科技星1 小时前
全域数学(GM)体系终极逻辑闭环综述
人工智能·线性代数·机器学习·量子计算·agi
2zcode1 小时前
原创文档:基于MATLAB卷积神经网络的多颜色车牌识别系统设计与实现
深度学习·计算机视觉·cnn
XD7429716361 小时前
科技早报|2026年5月8日:AI 开始更深地进入手表、代码库和企业网关
人工智能·科技·开发者工具·科技早报
TEC_INO1 小时前
Linux48:rockx常用的API
人工智能·计算机视觉·目标跟踪