视觉Transformer实战——Vision Transformer(ViT)详解与实现

视觉Transformer实战------Vision Transformer

    • [0. 前言](#0. 前言)
    • [1. ViT 技术原理](#1. ViT 技术原理)
      • [1.1 核心思想](#1.1 核心思想)
      • [1.2 使用 Transformer 处理图像数据](#1.2 使用 Transformer 处理图像数据)
    • [2. ViT 关键组件](#2. ViT 关键组件)
      • [2.1 图像分块](#2.1 图像分块)
      • [2.2 patch 嵌入](#2.2 patch 嵌入)
      • [2.3 位置编码](#2.3 位置编码)
      • [2.4 分类 token](#2.4 分类 token)
    • [3. 使用 PyTorch 实现 ViT](#3. 使用 PyTorch 实现 ViT)
      • [3.1 模型构建](#3.1 模型构建)
      • [3.2 模型训练](#3.2 模型训练)

0. 前言

在计算机视觉领域,卷积神经网络 (Convolutional Neural Network, CNN)长期以来一直是处理图像任务的主流架构。然而,随着 Transformer 在自然语言处理领域的巨大成功,研究人员开始探索将这种基于自注意力机制的架构应用于视觉任务。Vision Transformer (ViT) 是这一探索的重要里程碑,它首次证明了纯 Transformer 架构在图像分类任务上可以超越最先进的 CNN 模型。本文将详细介绍 ViT 的技术原理,并使用 PyTorch 从零开始构建 ViT 模型用于图像分类任务。

1. ViT 技术原理

1.1 核心思想

Vision Transformer (ViT) 的核心思想是将图像分割成固定大小的小块 (patch),将这些 patch 线性嵌入后加上位置编码,然后像自然语言处理 (Natuarl Language Processing, NLP)中的词元 (token) 一样将这些 patch 序列输入标准的 Transformer 编码器中进行处理。

1.2 使用 Transformer 处理图像数据

Transformer 非常擅长处理时间序列数据,图像在某种程度上也可以视为时间序列。例如,将图像分解成大小为 16 x 16 的小块,如果我们按顺序将这些图像块依次输入模型,那么这些块也具有序列格式。这与卷积神经网络非常相似,在卷积神经网络 (Convolutional Neural Network, CNN) 中,我们也将图像视为多个小块,并在块上应用卷积核(即创建一个卷积核并在图像上移动)。Transformer 会在在此基础上,增加一个基于全连接层的嵌入 (embedding) 层,这将使得每个块的大小不再是 16 x 16,而是该图像部分的密集表示,此外,还需要添加位置嵌入 (positional embedding)。

这些模型也可以仅包含编码器。例如,可以在每个操作的开头添加一个额外的词元,以创建整个图像的表示。在分类过程中,我们可以使用该词元将整个图像分类为给定的类别。ViT 架构如下图所示:

架构的其余部分与 Transformer 编码器块相同。ViT 架构的主要思想是分块并在图像块上应用位置嵌入。

2. ViT 关键组件

ViT 的成功依赖于几个精心设计的核心组件,这些组件共同实现了将 Transformer 架构有效应用于图像数据的创新方法。接下来,我们将深入剖析每个关键组件的设计原理和实现细节。

2.1 图像分块

Transformer 原本是为序列数据设计的,而图像是 2D 结构,图像分块 (Image Patching) 是将 2D 图像转换为 1D 序列的最直接方法,每个块 (patch) 相当于 NLP 中的一个 token。假设输入图像尺寸为 H × W × C (高度×宽度×通道),patch 大小为 P × P (通常 16×16),那么分块数量为 N = H W / P 2 N=HW/P^2 N=HW/P2。可以通过使用卷积实现高效分块:

python 复制代码
self.proj = nn.Conv2d(in_channels, embed_dim, 
                     kernel_size=patch_size, 
                     stride=patch_size)

较大的 patch 会丢失局部细节但计算效率高,较小的 patch 保留更多细节但增加序列长度。

2.2 patch 嵌入

patch 嵌入 (patch Embedding) 将每个 patch 展平并通过线性投影映射到 D 维空间,类似于 NLP 中的词嵌入,包括展平 patch ( P × P × C → P 2 C P×P×C → P²C P×P×C→P2C 维向量)和线性投影( P 2 C → D P²C → D P2C→D,通常 D=768),在 PyTorch 中可以使用以下代码实现:

python 复制代码
x = x.flatten(2).transpose(1,2)  # [B, N, P²C]
self.proj = nn.Linear(P²C, D)

除此之外,也可以直接使用卷积层实现。

2.3 位置编码

Transformer 本身是排列不变的,因此必须注入空间位置信息,不同于 Transformer 的固定编码,ViT 使用可学习的位置编码 (position Embedding),形状为 N+1 × D (Npatche + 1 个分类 token),在 PyTorch 中可以使用以下代码实现:

python 复制代码
self.pos_embed = nn.Parameter(torch.zeros(1, N+1, D))
nn.init.trunc_normal_(self.pos_embed, std=0.02)

2.4 分类 token

分类 token (Class Token) 类似 BERT[CLS] token,用于分类任务,作为整个图像的表征,通过自注意力聚合全局信息,在 PyTorch 中可以使用以下代码添加分类 token

python 复制代码
self.cls_token = nn.Parameter(torch.zeros(1, 1, D))

3. 使用 PyTorch 实现 ViT

接下来,下面我们将从零开始实现 ViT 模型,并使用 CIFAR-10 数据集训练模型。ViT 工作流程如下:

  • 输入图像 H×W×C
  • 分割为 NP×P×Cpatch ( N = H W / P 2 N = HW/P² N=HW/P2)
  • 每个 patch 展平为 P 2 C P²C P2C 维向量
  • 通过线性投影映射到 D 维 (Patch Embedding)
  • 添加位置编码和分类 token
  • 输入 L 层的 Transformer 编码器
  • 使用分类 token 对应的输出进行分类

3.1 模型构建

(1) 首先,导入所需库:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import matplotlib.pyplot as plt
from tqdm import tqdm

(2) 将图像分割为 patch 并线性嵌入到 D 维空间:

python 复制代码
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        # 使用卷积层实现patch分割和嵌入
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
    
    def forward(self, x):
        # 输入x形状: [batch_size, in_channels, img_size, img_size]
        # 输出形状: [batch_size, n_patches, embed_dim]
        x = self.proj(x)  # [batch_size, embed_dim, n_patches^0.5, n_patches^0.5]
        x = x.flatten(2)  # [batch_size, embed_dim, n_patches]
        x = x.transpose(1, 2)  # [batch_size, n_patches, embed_dim]
        return x

(3) 实现位置编码:

python 复制代码
class PositionEmbedding(nn.Module):
    def __init__(self, n_patches, embed_dim, dropout=0.1):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))  # +1 for class token
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # x形状: [batch_size, n_patches+1, embed_dim]
        x = x + self.pos_embed # 添加位置编码
        x = self.dropout(x)
        return x

(4) 实现多头注意力机制:

python 复制代码
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        assert self.head_dim * num_heads == embed_dim, "Embedding dimension must be divisible by number of heads"
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)  # 同时计算Q,K,V
        self.attn_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5
    
    def forward(self, x):
        batch_size, n_patches, embed_dim = x.shape
        
        # 计算Q,K,V [batch_size, n_patches, num_heads, head_dim]
        qkv = self.qkv(x).reshape(batch_size, n_patches, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # 计算注意力分数 [batch_size, num_heads, n_patches, n_patches]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        
        # 应用注意力权重到V上 [batch_size, num_heads, n_patches, head_dim]
        out = attn @ v
        out = out.transpose(1, 2).reshape(batch_size, n_patches, embed_dim)
        
        # 线性投影和dropout
        out = self.proj(out)
        out = self.proj_dropout(out)
        return out

(5) 实现多层感知机 (Multilayer Perceptron, MLP) 模块,自注意力机制后进行非线性特征变换和维度扩展/收缩:

python 复制代码
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

(6) 实现 Transformer 编码器模块 TransformerBlock

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(
            in_features=embed_dim,
            hidden_features=embed_dim * mlp_ratio,
            out_features=embed_dim,
            dropout=dropout
        )
    
    def forward(self, x):
        # 残差连接和层归一化
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

(7) 实现 ViT 模型:

python 复制代码
class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        n_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        dropout=0.1
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches
        
        # 分类token和位置编码
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = PositionEmbedding(n_patches, embed_dim, dropout)
        
        # Transformer编码器
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        # 分类头
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)
        
        # 初始化权重
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        
    def forward(self, x):
        batch_size = x.shape[0]
        
        # 生成patch嵌入
        x = self.patch_embed(x)  # [batch_size, n_patches, embed_dim]
        
        # 添加class token
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)  # [batch_size, n_patches+1, embed_dim]
        
        # 添加位置编码
        x = self.pos_embed(x)
        
        # 通过Transformer编码器
        for block in self.blocks:
            x = block(x)
        
        # 分类
        x = self.norm(x)
        cls_token_final = x[:, 0]  # 只取class token对应的输出
        x = self.head(cls_token_final)
        
        return x

3.2 模型训练

(1) 实现模型训练与评估函数:

python 复制代码
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 统计信息
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

def evaluate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

(2) 定义模型超参数:

python 复制代码
img_size = 224
patch_size = 16
batch_size = 32
num_epochs = 20
learning_rate = 0.0001
num_classes = 10  # CIFAR-10有10个类别

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

(3) 加载 CIFAR-10 数据集,并进行数据预处理:

python 复制代码
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

(4) 初始化模型、损失函数和优化器:

python 复制代码
model = VisionTransformer(
    img_size=img_size,
    patch_size=patch_size,
    n_classes=num_classes,
    embed_dim=768,
    depth=6,  # 减少深度以加快训练
    num_heads=8
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

(5) 训练模型 20epoch

python 复制代码
train_losses, train_accs = [], []
test_losses, test_accs = [], []

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    
    # 训练
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    
    # 评估
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    test_losses.append(test_loss)
    test_accs.append(test_acc)
    
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
    print()

(6) 绘制模型训练过程中损失值和分类性能变化曲线:

python 复制代码
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.legend()
plt.title('Loss')

plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Acc')
plt.plot(test_accs, label='Test Acc')
plt.legend()
plt.title('Accuracy')

plt.show()

可以看到,从零开始训练的 ViTCIFAR-10 数据集上的准确率大约在 67% 左右,在小规模数据集上从头训练时,ViT 的表现通常不如 CNN,这是由于ViT 的核心是全局自注意力机制,它需要足够多的数据来学习长距离依赖关系,在小规模数据集(如 CIFAR-10,仅 5 万张 32×32 图像)上,ViT 容易过拟合,无法有效学习有意义的特征映射。而使用在 ImageNet 上预训练的 ViT 进行微调,在 CIFAR-10 上可达到 98.5% 的准确率。

相关推荐
风象南11 小时前
普通人用AI加持赚到的第一个100块
人工智能·后端
牛奶11 小时前
2026年大模型怎么选?前端人实用对比
前端·人工智能·ai编程
牛奶11 小时前
前端人为什么要学AI?
前端·人工智能·ai编程
罗西的思考14 小时前
AI Agent框架探秘:拆解 OpenHands(10)--- Runtime
人工智能·算法·机器学习
冬奇Lab15 小时前
OpenClaw 源码精读(2):Channel & Routing——一条消息如何找到它的 Agent?
人工智能·开源·源码阅读
冬奇Lab15 小时前
一天一个开源项目(第38篇):Claude Code Telegram - 用 Telegram 远程用 Claude Code,随时随地聊项目
人工智能·开源·资讯
格砸16 小时前
从入门到辞职|从ChatGPT到OpenClaw,跟上智能时代的进化
前端·人工智能·后端
可观测性用观测云16 小时前
可观测性 4.0:教系统如何思考
人工智能
sunny86517 小时前
Claude Code 跨会话上下文恢复:从 8 次纠正到 0 次的工程实践
人工智能·开源·github
小笼包包仔17 小时前
OpenClaw 多Agent软件开发最佳实践指南
人工智能