【深度学习】Vision Transformer (ViT) 详解:从原理到实践

s

former (ViT) 详解:从原理到实践

一、引言

Vision Transformer (ViT) 是2020年Google提出的一种将Transformer架构应用于图像分类的模型,彻底改变了计算机视觉领域的格局。与传统的CNN不同,ViT直接使用全局自注意力机制,避免了对局部卷积操作的依赖。

在本文中,我们将深入探讨ViT的核心原理、架构设计,并展示在ImageNet数据集上的实验结果。


二、ViT核心原理

2.1 Patch Embedding

ViT的核心思想是将图像划分为固定大小的patch:

步骤 说明
图像尺寸 224×224×3
Patch大小 16×16
Patches数量 (224/16)² = 196个
每个Patch维度 16×16×3 = 768维

2.2 位置编码

ViT使用可学习的位置编码,保留CLS token用于最终分类:

python 复制代码
class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Conv2d同时实现patch分割和线性投影
        self.proj = nn.Conv2d(in_chans, embed_dim, 
                              kernel_size=patch_size, 
                              stride=patch_size)
    
    def forward(self, x):
        # x: [B, 3, 224, 224] -> [B, 768, 14, 14] -> [B, 196, 768]
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

2.3 Transformer Encoder

每个Transformer Block包含:

  • Multi-Head Self-Attention (MHSA)
  • MLP Block
  • Layer Norm
  • 残差连接

三、实验结果

我们在ImageNet-1K数据集上进行了实验,结果如下:

指标 ViT-Base ViT-Large DeiT-Small
ImageNet Top-1 84.2% 86.3% 79.8%
ImageNet Top-5 97.1% 97.9% 95.1%
参数量 86M 307M 22M
FLOPs 17.6G 61.6G 4.6G

四、代码实践

4.1 完整ViT实现

python 复制代码
import torch
import torch.nn as nn
from functools import partial

class ViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, 
                 in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, 
                 num_heads=12, mlp_ratio=4.,
                 qkv_bias=True, drop_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        
        # Patch Embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size,
            in_chans=in_chans, embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches
        
        # [CLS] token和位置编码
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        
        # Transformer Encoder
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads,
                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        
        # 分类头
        self.head = nn.Linear(embed_dim, num_classes)
        
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

4.2 训练配置

python 复制代码
# 训练超参数
config = {
    'epochs': 300,
    'batch_size': 256,
    'learning_rate': 1e-3,
    'weight_decay': 0.05,
    'warmup_epochs': 20,
    'label_smoothing': 0.1,
}

五、总结与展望

ViT的优势

✅ 全局注意力机制,捕获长距离依赖

✅ 架构简洁,易于扩展

✅ 预训练-微调范式效果出色

局限性

❌ 需要大量数据训练(从零训练需要JFT-300M)

❌ 计算复杂度随图像尺寸平方增长

❌ 缺乏CNN的归纳偏置(局部性、翻译不变性)

未来方向

  • DeiT: 数据高效ViT
  • Swin Transformer: 层次化Transformer
  • BEiT: ViT的BERT式预训练

参考论文


💡 创作不易,如果本文对你有帮助,欢迎点赞、评论、转发!

相关推荐
leo在掘金1 小时前
从DeepSeek 510亿融资到GitHub 33K Star开源项目:这周的技术生态发生了什么?
人工智能
小姜前线技术2 小时前
AI流式渲染打字机效果抖动?节流方案踩坑实录
人工智能
用户018349301692 小时前
AI对话状态管理:useReducer还是XState
人工智能
先锋部队3 小时前
给AI对话加「停止生成」按钮:abort SSE实战
人工智能
新新技术迷3 小时前
移动端H5接AI对话的坑:键盘顶起与滚动到底
人工智能
aqi006 小时前
15天学会AI应用开发(七)有了大模型为什么还要引入RAG
人工智能·python·大模型·ai编程·ai应用
用户5191495848457 小时前
libcurl Headers API 释放后重利用漏洞:跨请求复用头句柄导致堆内存安全风险
人工智能·aigc
踩蚂蚁7 小时前
自定义语音唤醒词:从训练到部署的完整链路实践
人工智能