【深度学习】Vision Transformer (ViT) 详解:从原理到实践

s

former (ViT) 详解:从原理到实践

一、引言

Vision Transformer (ViT) 是2020年Google提出的一种将Transformer架构应用于图像分类的模型,彻底改变了计算机视觉领域的格局。与传统的CNN不同,ViT直接使用全局自注意力机制,避免了对局部卷积操作的依赖。

在本文中,我们将深入探讨ViT的核心原理、架构设计,并展示在ImageNet数据集上的实验结果。


二、ViT核心原理

2.1 Patch Embedding

ViT的核心思想是将图像划分为固定大小的patch:

步骤 说明
图像尺寸 224×224×3
Patch大小 16×16
Patches数量 (224/16)² = 196个
每个Patch维度 16×16×3 = 768维

2.2 位置编码

ViT使用可学习的位置编码,保留CLS token用于最终分类:

python 复制代码
class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # Conv2d同时实现patch分割和线性投影
        self.proj = nn.Conv2d(in_chans, embed_dim, 
                              kernel_size=patch_size, 
                              stride=patch_size)
    
    def forward(self, x):
        # x: [B, 3, 224, 224] -> [B, 768, 14, 14] -> [B, 196, 768]
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

2.3 Transformer Encoder

每个Transformer Block包含:

  • Multi-Head Self-Attention (MHSA)
  • MLP Block
  • Layer Norm
  • 残差连接

三、实验结果

我们在ImageNet-1K数据集上进行了实验,结果如下:

指标 ViT-Base ViT-Large DeiT-Small
ImageNet Top-1 84.2% 86.3% 79.8%
ImageNet Top-5 97.1% 97.9% 95.1%
参数量 86M 307M 22M
FLOPs 17.6G 61.6G 4.6G

四、代码实践

4.1 完整ViT实现

python 复制代码
import torch
import torch.nn as nn
from functools import partial

class ViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, 
                 in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, 
                 num_heads=12, mlp_ratio=4.,
                 qkv_bias=True, drop_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        
        # Patch Embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size,
            in_chans=in_chans, embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches
        
        # [CLS] token和位置编码
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        
        # Transformer Encoder
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads,
                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        
        # 分类头
        self.head = nn.Linear(embed_dim, num_classes)
        
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

4.2 训练配置

python 复制代码
# 训练超参数
config = {
    'epochs': 300,
    'batch_size': 256,
    'learning_rate': 1e-3,
    'weight_decay': 0.05,
    'warmup_epochs': 20,
    'label_smoothing': 0.1,
}

五、总结与展望

ViT的优势

✅ 全局注意力机制,捕获长距离依赖

✅ 架构简洁,易于扩展

✅ 预训练-微调范式效果出色

局限性

❌ 需要大量数据训练(从零训练需要JFT-300M)

❌ 计算复杂度随图像尺寸平方增长

❌ 缺乏CNN的归纳偏置(局部性、翻译不变性)

未来方向

  • DeiT: 数据高效ViT
  • Swin Transformer: 层次化Transformer
  • BEiT: ViT的BERT式预训练

参考论文


💡 创作不易,如果本文对你有帮助,欢迎点赞、评论、转发!

相关推荐
jiayong23几秒前
ZeroClaw 使用方式与启动指南
人工智能·ai·智能体·zeroclaw
有来有去9527几秒前
【模型评测】SWE-bench Verified数据集-1-配置评测任务
人工智能·深度学习·语言模型
Lsland..2 分钟前
AI Agent到底是什么
java·人工智能·llm
Akamai中国2 分钟前
针对 Akamai Cloud 上的 NVIDIA RTX Pro 6000 Blackwell 进行基准测试
人工智能·云计算·gpu算力·云服务
code 小楊2 分钟前
AI Agent 进阶范式 Plan-and-Execute 深度详解:原理、架构、实战与工程落地
人工智能·架构
ai产品老杨4 分钟前
解耦视频流利器:如何利用 GB28181 与 RTSP 协议统一收敛多厂商设备?一套支持 Docker 部署与源码交付的边缘计算 AI 视频中台深度解析
人工智能·docker·边缘计算
Lsland..5 分钟前
MCP协议AI时代的HTTP
人工智能·网络协议·http
谷哥的小弟7 分钟前
大模型核心基础知识(12)—机器学习的基本概念与常见方法
人工智能·深度学习·机器学习·大模型·大语言模型
csdnor_017 分钟前
Codex Desktop App 使用 Ollama 本地模型技术方案
人工智能·免费·codex·ollama
_Oracle10 分钟前
机器学习——绪论
人工智能·机器学习