A Beginner-Friendly Vision Transformer (ViT) Tutorial: From Principles to CIFAR-10 Practice (PyTorch)

1. Why ViT Upended CNNs: From "Convolution Supremacy" to "Transformer Dominance"

Before 2020, computer vision was dominated by CNNs: from LeNet to ResNet, from MobileNet to EfficientNet, convolution, with its local receptive fields and parameter sharing, ruled nearly every task in image classification, detection, and segmentation. The Vision Transformer (ViT), proposed by Google in 2020, broke this pattern: it discards convolution entirely and, using only the Transformer's self-attention mechanism, matches or surpasses CNN performance on ImageNet and other benchmarks, later spawning powerful successors such as the Swin Transformer and larger variants like ViT-L/16.


ViT's core advantages:

  • Global modeling: a CNN must stack many layers to enlarge its receptive field, whereas ViT captures global dependencies across the whole image directly and is more sensitive to long-range feature relationships
  • Strong transferability: after large-scale pretraining, ViT transfers very well to other tasks (e.g. medical or remote-sensing imagery), often beating comparable CNNs
  • Parallel-friendly: self-attention reduces to large matrix multiplications that parallelize efficiently on GPUs, so training scales well on servers
  • Flexible structure: adjusting parameters such as the patch size and the number of attention heads adapts the model to images of different resolutions

2. ViT Core Principles, Deconstructed (with the Math)

At its core, ViT cuts an image into small patches and processes those patches with a Transformer encoder. The pipeline is: image patching → linear embedding → positional encoding → Transformer encoder → classification head.

2.1 Mathematical Principles of the Core Modules

(1) Patch Embedding: Turning an Image into "Words"

A CNN processes an image by sliding convolution kernels over its pixels; ViT instead starts by splitting the image into fixed-size, non-overlapping patches, much like splitting a sentence into words in NLP.

Assume the input image has size $H \times W \times C$ ($H$ = height, $W$ = width, $C$ = channels) and the patch size is $P \times P$. Then:

  • pixels per patch: $P \times P \times C$
  • number of patches: $N = \frac{H \times W}{P \times P}$ (which requires $H$ and $W$ to be divisible by $P$)

For example, with a $224 \times 224 \times 3$ input (the standard ImageNet size) and $16 \times 16$ patches (a runnable check follows this list):

  • pixels per patch: $16 \times 16 \times 3 = 768$
  • number of patches: $N = \frac{224 \times 224}{16 \times 16} = 196$
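
These numbers can be verified with plain tensor reshaping. The snippet below is a minimal sketch for illustration only (it is not part of the tutorial code in Section 3):

```python
import torch

# Split a 224x224x3 image into non-overlapping 16x16 patches
x = torch.randn(1, 3, 224, 224)                       # (B, C, H, W)
P = 16
patches = x.unfold(2, P, P).unfold(3, P, P)           # (1, 3, 14, 14, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, 3 * P * P)
print(patches.shape)                                  # torch.Size([1, 196, 768])
```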

Each patch is then flattened and mapped by a shared linear layer to a $D$-dimensional vector (the "patch embedding"):
$$\text{patch embedding}_i = \text{Linear}(P \times P \times C,\ D)(\text{patch}_i)$$

where $\text{Linear}(in\_dim, out\_dim)$ denotes a linear transformation whose weight matrix has shape $D \times (P \times P \times C)$. The result is $N$ vectors of dimension $D$, i.e. a tensor of shape $N \times D$.
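
In practice this per-patch linear layer is usually implemented as a strided convolution, which is exactly what the code in Section 3 does. A minimal sketch of the equivalence:

```python
import torch
import torch.nn as nn

# A Conv2d with kernel_size = stride = P applies the same linear map to every patch
patch_embed = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)
x = torch.randn(1, 3, 224, 224)
z = patch_embed(x)                 # (1, 768, 14, 14)
z = z.flatten(2).transpose(1, 2)   # (1, 196, 768): N patches, each a D-dim embedding
print(z.shape)
```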

(2) Positional Embedding: Telling the Model Where Each Patch Is

Self-attention is permutation-invariant (insensitive to the order of the input sequence), but the position of each patch matters in an image (a cat's head and a cat's tail at different locations carry different meanings). ViT therefore adds a positional encoding to inject location information into every patch.

ViT uses a learnable positional embedding (unlike the sinusoidal encoding of the original Transformer) with exactly the same shape as the patch embeddings ($N \times D$):
$$\text{encoded patches} = \text{patch embedding} + \text{positional embedding}$$

where "+" is element-wise addition (with broadcasting). The positional embedding is updated together with the rest of the model's parameters during training and learns to capture positional relationships between patches.

In addition, ViT prepends a special classification token (the class token) of shape $1 \times D$ to the sequence; its final state is used for classification. The sequence length becomes $N+1$, so the encoder input has shape $(N+1) \times D$.
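
Continuing the sketch above (illustrative only), prepending the class token and adding a learnable positional embedding looks like this:

```python
import torch
import torch.nn as nn

B, N, D = 1, 196, 768
z = torch.randn(B, N, D)                               # patch embeddings

cls_token = nn.Parameter(torch.zeros(1, 1, D))         # learnable class token
pos_embed = nn.Parameter(torch.zeros(1, N + 1, D))     # learnable positional embedding

z = torch.cat([cls_token.expand(B, -1, -1), z], dim=1) # (1, 197, 768)
z = z + pos_embed                                      # broadcast element-wise add
print(z.shape)                                         # torch.Size([1, 197, 768])
```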

(3) Transformer Encoder: ViT's "Brain"

The Transformer encoder is ViT's core component. Each encoder layer contains two sub-modules, multi-head self-attention (MHSA) and a multi-layer perceptron (MLP); each sub-module is preceded by layer normalization (LN) and wrapped in a residual connection.

① Multi-Head Self-Attention (MHSA): Capturing Relationships Between Patches

Self-attention computes how strongly each patch relates to every other patch (the attention weights) and then aggregates information from all patches according to those weights. It proceeds in four steps (a runnable shape check follows them).

  • Step 1: compute Q, K, V

    The input sequence $X$ (shape $(N+1) \times D$) is passed through three linear layers to produce queries (Q), keys (K) and values (V):
    $$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$

    where $W_Q, W_K, W_V$ are the weight matrices of the linear layers (each of shape $D \times D$); $Q, K, V$ all have shape $(N+1) \times D$.

  • Step 2: split into heads

    $Q, K, V$ are split along the feature dimension into $h$ parallel heads, each living in a subspace of dimension $d_k = D / h$. For example, with $h = 12$ and $D = 768$, each head has $d_k = 64$:
    $$Q = [Q_1, Q_2, \ldots, Q_h], \quad K = [K_1, K_2, \ldots, K_h], \quad V = [V_1, V_2, \ldots, V_h]$$

    where each $Q_i, K_i, V_i$ has shape $(N+1) \times d_k$.

  • Step 3: single-head attention

    For each head, compute the attention weights and the head output:
    $$\text{Attention}(Q_i, K_i, V_i) = \text{Softmax}\!\left( \frac{Q_i K_i^{\top}}{\sqrt{d_k}} \right) V_i$$

    where:

    • $Q_i K_i^{\top}$ measures the similarity between queries and keys (shape $(N+1) \times (N+1)$)
    • $\sqrt{d_k}$ is a scaling factor that prevents the similarities from growing so large that Softmax saturates and gradients vanish
    • $\text{Softmax}$ turns the similarities into attention weights that sum to 1 along each row
    • each head's output has shape $(N+1) \times d_k$

  • Step 4: concatenate the heads

    The $h$ head outputs are concatenated and projected back to dimension $D$ by an output linear layer:
    $$\text{MHSA}(X) = \text{Concat}\big(\text{Attention}(Q_1,K_1,V_1), \ldots, \text{Attention}(Q_h,K_h,V_h)\big)\, W_O$$

    where $W_O$ is the output projection matrix (shape $D \times D$); the final MHSA output has shape $(N+1) \times D$.
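
The four steps can be checked with PyTorch's built-in multi-head attention module. This is only a shape-check sketch; the tutorial code in Section 3 implements the same computation by hand:

```python
import torch
import torch.nn as nn

D, h, N = 768, 12, 197
mhsa = nn.MultiheadAttention(embed_dim=D, num_heads=h, batch_first=True)
x = torch.randn(1, N, D)
out, attn = mhsa(x, x, x)   # self-attention: Q = K = V = x
print(out.shape)            # torch.Size([1, 197, 768])
print(attn.shape)           # torch.Size([1, 197, 197]), averaged over heads
```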

② Multi-Layer Perceptron (MLP): Adding Non-Linear Capacity

After MHSA, a two-layer MLP applies a non-linear transformation (a compact sketch follows the list below):
$$\text{MLP}(X) = \text{GELU}(X W_1 + b_1)\, W_2 + b_2$$

where:

  • $W_1, b_1$: first linear layer (input $D$, output $4D$; the hidden width is conventionally 4× the embedding dimension)
  • $\text{GELU}$: the activation function, $\text{GELU}(x) = x \cdot \Phi(x)$, where $\Phi(x)$ is the cumulative distribution function of the standard normal distribution
  • $W_2, b_2$: second linear layer (input $4D$, output $D$)
  • the MLP output keeps the shape $(N+1) \times D$
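
The MLP block is just two linear layers around a GELU, which can be written compactly (illustrative sketch, matching the ViT-Base dimensions used above):

```python
import torch.nn as nn

D = 768
mlp = nn.Sequential(
    nn.Linear(D, 4 * D),   # expand to 4D = 3072
    nn.GELU(),
    nn.Linear(4 * D, D),   # project back to D
)
```
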
③ The Complete Encoder Layer

A single Transformer encoder layer computes (pre-norm form, as used by ViT and by the code in Section 3):
$$X_1 = X + \text{MHSA}(\text{LN}(X))$$
$$X_2 = X_1 + \text{MLP}(\text{LN}(X_1))$$

where $X$ is the layer input and $X_2$ the layer output. The residual connections ($X + \ldots$) keep training stable and layer normalization (LN) speeds up convergence. ViT stacks $L$ such encoder layers (e.g. $L = 12$), producing a feature sequence of shape $(N+1) \times D$.
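
PyTorch ships an equivalent pre-norm encoder layer out of the box. The sketch below is only for comparison (the tutorial code writes the layer by hand for clarity):

```python
import torch
import torch.nn as nn

D, h = 768, 12
layer = nn.TransformerEncoderLayer(
    d_model=D, nhead=h, dim_feedforward=4 * D,
    activation="gelu", batch_first=True, norm_first=True,  # pre-norm, as in ViT
)
encoder = nn.TransformerEncoder(layer, num_layers=12)
x = torch.randn(1, 197, D)
print(encoder(x).shape)          # torch.Size([1, 197, 768])
```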

(4) Classification Head: Producing Class Probabilities

Take the feature of the first token in the encoder output, the class token (shape $1 \times D$), map it to the number of classes $C$ with a linear layer, and apply Softmax to obtain class probabilities:
$$\text{logits} = X_{\text{class token}}\, W_{\text{cls}}^{\top} + b_{\text{cls}}$$
$$\text{prob} = \text{Softmax}(\text{logits})$$

where $W_{\text{cls}}$ is the classifier weight matrix (shape $C \times D$) and $\text{prob}$ the final class probabilities (shape $1 \times C$).
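
A minimal sketch of the head (illustrative only; for CIFAR-10 in Section 3, 10 classes replace the 1000 used here):

```python
import torch
import torch.nn as nn

D, num_classes = 768, 1000
encoder_out = torch.randn(1, 197, D)              # (B, N+1, D) from the encoder
cls_feature = encoder_out[:, 0, :]                # take the class token: (1, 768)
head = nn.Linear(D, num_classes)
prob = torch.softmax(head(cls_feature), dim=-1)   # (1, 1000), each row sums to 1
```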

2.2 ViT End-to-End Shape Table (ViT-Base as an Example)

| Module | Input shape | Operation | Output shape | Key parameters (ViT-Base) |
|---|---|---|---|---|
| Image input | (3, 224, 224) | standard ImageNet image (C=3, H=224, W=224) | (3, 224, 224) | - |
| Patch embedding | (3, 224, 224) | split into 16×16 patches + linear layer (768→768) | (196, 768) | patch_size=16, D=768 |
| Add class token | (196, 768) | concatenate 1 learnable token | (197, 768) | - |
| Positional embedding | (197, 768) | learnable positional embedding (element-wise add) | (197, 768) | pos_emb_shape=(197, 768) |
| Transformer encoder (L=12 layers) | (197, 768) | each layer: MHSA + MLP + LN + residual connections | (197, 768) | L=12, h=12, d_k=64 |
| Classification head | (1, 768) | take class token + linear layer (768→1000) + Softmax | (1, 1000) | C=1000 classes (ImageNet) |
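
These shapes can be sanity-checked against torchvision's reference implementation (a quick check, assuming torchvision ≥ 0.12 is installed; not part of the tutorial code):

```python
import torch
from torchvision.models import vit_b_16

model = vit_b_16()                                        # randomly initialized ViT-Base/16
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)                                     # torch.Size([1, 1000])
print(f"{sum(p.numel() for p in model.parameters()):,}")  # roughly 86M parameters
```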


3. Beginner-Friendly ViT Hands-On: CIFAR-10 Image Classification (Server-Ready)

This project uses the CIFAR-10 dataset (10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck). No manual download is needed (PyTorch downloads it automatically), all figure labels are in English, and the script is server-friendly: one command produces all results.

3.1 Environment Setup (Works on Any Server)

```bash
# Install dependencies (Python 3.7+, PyTorch 1.8+); seaborn is needed for the confusion-matrix plot
pip install torch torchvision numpy matplotlib tqdm scikit-learn seaborn
```

3.2 Full Code

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# ======================== ViT Model Definition ========================
class PatchEmbedding(nn.Module):
    """
    Convert image to patch embeddings + class token + positional embedding
    Input: (B, C, H, W)
    Output: (B, N+1, D) where N = (H*W)/(P*P), D = embedding dimension
    """
    def __init__(self, img_size=32, patch_size=4, in_ch=3, embed_dim=256):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        
        # Calculate number of patches
        self.num_patches = (img_size // patch_size) ** 2
        
        # Patch embedding: linear layer (implemented via Conv2d for efficiency)
        self.patch_embed = nn.Conv2d(
            in_channels=in_ch, 
            out_channels=embed_dim, 
            kernel_size=patch_size, 
            stride=patch_size  # No overlap between patches
        )
        
        # Class token: (1, 1, D) -> expand to (B, 1, D) in forward
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        
        # Positional embedding: (1, N+1, D) -> learnable
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
        
    def forward(self, x):
        B = x.shape[0]  # Batch size
        
        # Step 1: Patch embedding (B, C, H, W) -> (B, D, N^(1/2), N^(1/2))
        x = self.patch_embed(x)  # (B, 256, 8, 8) for img_size=32, patch_size=4
        
        # Step 2: Flatten patches (B, D, N^(1/2), N^(1/2)) -> (B, D, N)
        x = x.flatten(2)  # (B, 256, 64)
        
        # Step 3: Transpose to (B, N, D)
        x = x.transpose(1, 2)  # (B, 64, 256)
        
        # Step 4: Add class token (B, 1, D)
        class_token = self.class_token.expand(B, -1, -1)  # (B, 1, 256)
        x = torch.cat([class_token, x], dim=1)  # (B, 65, 256)
        
        # Step 5: Add positional embedding
        x = x + self.pos_embed  # (B, 65, 256)
        
        return x


class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head Self-Attention (MHSA) module
    Input: (B, N, D)
    Output: (B, N, D)
    """
    def __init__(self, embed_dim=256, num_heads=8, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # Dimension per head
        
        # Ensure embed_dim is divisible by num_heads
        assert self.head_dim * num_heads == embed_dim, "Embed dim must be divisible by num heads"
        
        # Linear layers for Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        
        # Output linear layer
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
        # Dropout layer
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B, N, D = x.shape  # (B, 65, 256)
        
        # Step 1: Compute Q, K, V (B, N, D)
        q = self.q_proj(x)  # (B, 65, 256)
        k = self.k_proj(x)  # (B, 65, 256)
        v = self.v_proj(x)  # (B, 65, 256)
        
        # Step 2: Split into multiple heads (B, num_heads, N, head_dim)
        q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)  # (B, 8, 65, 32)
        k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)  # (B, 8, 65, 32)
        v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)  # (B, 8, 65, 32)
        
        # Step 3: Compute attention scores (B, num_heads, N, N)
        scores = torch.matmul(q, k.transpose(-2, -1))  # (B, 8, 65, 65)
        scores = scores / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))  # Scale
        
        # Step 4: Softmax to get attention weights (B, num_heads, N, N)
        attn_weights = torch.softmax(scores, dim=-1)  # (B, 8, 65, 65)
        attn_weights = self.dropout(attn_weights)
        
        # Step 5: Compute weighted sum of V (B, num_heads, N, head_dim)
        attn_output = torch.matmul(attn_weights, v)  # (B, 8, 65, 32)
        
        # Step 6: Concatenate heads (B, N, D)
        attn_output = attn_output.transpose(1, 2).contiguous()  # (B, 65, 8, 32)
        attn_output = attn_output.view(B, N, D)  # (B, 65, 256)
        
        # Step 7: Linear projection
        output = self.out_proj(attn_output)  # (B, 65, 256)
        output = self.dropout(output)
        
        return output


class MLP(nn.Module):
    """
    Multi-Layer Perceptron for Transformer encoder
    Input: (B, N, D)
    Output: (B, N, D)
    """
    def __init__(self, embed_dim=256, mlp_dim=1024, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, mlp_dim)  # Expand to 4x dimension
        self.fc2 = nn.Linear(mlp_dim, embed_dim)  # Project back
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()  # Activation function
        
    def forward(self, x):
        x = self.fc1(x)  # (B, 65, 256) -> (B, 65, 1024)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)  # (B, 65, 1024) -> (B, 65, 256)
        x = self.dropout(x)
        return x


class TransformerEncoderLayer(nn.Module):
    """
    Single layer of Transformer encoder
    Input: (B, N, D)
    Output: (B, N, D)
    """
    def __init__(self, embed_dim=256, num_heads=8, mlp_dim=1024, dropout=0.1):
        super().__init__()
        # Layer normalization before MHSA
        self.ln1 = nn.LayerNorm(embed_dim)
        # Multi-Head Self-Attention
        self.mhsa = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        # Layer normalization before MLP
        self.ln2 = nn.LayerNorm(embed_dim)
        # MLP block
        self.mlp = MLP(embed_dim, mlp_dim, dropout)
        
    def forward(self, x):
        # Residual connection + MHSA
        x = x + self.mhsa(self.ln1(x))  # (B, 65, 256)
        # Residual connection + MLP
        x = x + self.mlp(self.ln2(x))  # (B, 65, 256)
        return x


class VisionTransformer(nn.Module):
    """
    Full Vision Transformer model for image classification
    Input: (B, C, H, W)
    Output: (B, num_classes)
    """
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_ch=3,
        embed_dim=256,
        num_heads=8,
        num_layers=6,
        mlp_dim=1024,
        num_classes=10,
        dropout=0.1
    ):
        super().__init__()
        # Patch embedding + class token + positional embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_ch, embed_dim)
        
        # Transformer encoder (stack multiple layers)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, mlp_dim, dropout)
            for _ in range(num_layers)
        ])
        
        # Layer normalization for encoder output
        self.ln = nn.LayerNorm(embed_dim)
        
        # Classification head (linear layer)
        self.classifier = nn.Linear(embed_dim, num_classes)
        
        # Initialize weights
        self.apply(self._init_weights)
        
    def _init_weights(self, m):
        """Initialize model weights (improve training stability)"""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
    
    def forward(self, x):
        # Step 1: Patch embedding + positional encoding (B, C, H, W) -> (B, N+1, D)
        x = self.patch_embed(x)  # (B, 65, 256)
        
        # Step 2: Pass through Transformer encoder layers
        for layer in self.encoder_layers:
            x = layer(x)  # (B, 65, 256)
        
        # Step 3: Layer normalization
        x = self.ln(x)  # (B, 65, 256)
        
        # Step 4: Extract class token feature (B, D)
        class_token_feature = x[:, 0, :]  # (B, 256)
        
        # Step 5: Classification head (B, num_classes)
        logits = self.classifier(class_token_feature)  # (B, 10)
        
        return logits


# ======================== Data Preparation ========================
def get_cifar10_dataloaders(batch_size=64, img_size=32):
    """
    Load CIFAR-10 dataset with data augmentation (for training)
    Returns: train_loader, val_loader, class_names
    """
    # Data transforms (server-friendly, no random crop for stability)
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),  # Data augmentation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],  # CIFAR-10 stats
                             std=[0.2023, 0.1994, 0.2010])
    ])
    
    val_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                             std=[0.2023, 0.1994, 0.2010])
    ])
    
    # Download CIFAR-10 dataset (auto-download if not exists)
    train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=train_transform
    )
    val_dataset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=val_transform
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True
    )
    
    # CIFAR-10 class names
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    
    print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")
    return train_loader, val_loader, class_names


# ======================== Training & Validation Functions ========================
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Train model for one epoch"""
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1} [Train]')
    for imgs, labels in pbar:
        # Move data to device (GPU/CPU)
        imgs, labels = imgs.to(device), labels.to(device)
        
        # Forward pass
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Calculate metrics
        total_loss += loss.item() * imgs.size(0)
        _, preds = torch.max(outputs, 1)
        total_correct += (preds == labels).sum().item()
        total_samples += imgs.size(0)
        
        # Update progress bar
        avg_loss = total_loss / total_samples
        avg_acc = total_correct / total_samples
        pbar.set_postfix({'Loss': f'{avg_loss:.4f}', 'Acc': f'{avg_acc:.4f}'})

    # Compute epoch-level averages after the loop (mirrors validate())
    avg_loss = total_loss / total_samples
    avg_acc = total_correct / total_samples
    return avg_loss, avg_acc


def validate(model, val_loader, criterion, device):
    """Validate model on validation set"""
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='[Validation]')
        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            
            # Calculate metrics
            total_loss += loss.item() * imgs.size(0)
            _, preds = torch.max(outputs, 1)
            total_correct += (preds == labels).sum().item()
            total_samples += imgs.size(0)
            
            # Collect preds and labels for confusion matrix
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            # Update progress bar
            avg_loss = total_loss / total_samples
            avg_acc = total_correct / total_samples
            pbar.set_postfix({'Loss': f'{avg_loss:.4f}', 'Acc': f'{avg_acc:.4f}'})
    
    avg_loss = total_loss / total_samples
    avg_acc = total_correct / total_samples
    return avg_loss, avg_acc, all_preds, all_labels


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs=20):
    """Full training pipeline"""
    # Record training history
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }
    
    print(f"Training started (Device: {device})")
    for epoch in range(epochs):
        # Train one epoch
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
        
        # Validate
        val_loss, val_acc, all_preds, all_labels = validate(model, val_loader, criterion, device)
        
        # Update scheduler
        scheduler.step(val_loss)
        
        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        # Print epoch summary
        print(f'Epoch [{epoch+1}/{epochs}] | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | '
              f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}')
    
    # Save trained model
    torch.save(model.state_dict(), 'vit_cifar10.pth')
    print(f"Model saved as 'vit_cifar10.pth'")
    
    return history, all_preds, all_labels


# ======================== Result Visualization (Server Compatible) ========================
def plot_training_history(history, save_path='training_history.png'):
    """Plot training/validation loss and accuracy curves"""
    plt.figure(figsize=(12, 4))
    
    # Loss curve
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss', linewidth=2, color='blue')
    plt.plot(history['val_loss'], label='Val Loss', linewidth=2, color='red', linestyle='--')
    plt.xlabel('Epoch', fontsize=10)
    plt.ylabel('Loss', fontsize=10)
    plt.title('Training & Validation Loss', fontsize=12)
    plt.legend()
    plt.grid(alpha=0.3)
    
    # Accuracy curve
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Accuracy', linewidth=2, color='blue')
    plt.plot(history['val_acc'], label='Val Accuracy', linewidth=2, color='red', linestyle='--')
    plt.xlabel('Epoch', fontsize=10)
    plt.ylabel('Accuracy', fontsize=10)
    plt.title('Training & Validation Accuracy', fontsize=12)
    plt.legend()
    plt.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Training history saved to {save_path}")


def plot_confusion_matrix(all_labels, all_preds, class_names, save_path='confusion_matrix.png'):
    """Plot confusion matrix for validation set"""
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.title('Confusion Matrix (CIFAR-10 Validation Set)', fontsize=14)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Confusion matrix saved to {save_path}")


def plot_predictions(model, val_loader, class_names, device, save_path='predictions.png'):
    """Plot sample predictions (correct and incorrect)"""
    model.eval()
    correct_samples = []
    incorrect_samples = []
    
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, preds = torch.max(outputs, 1)
            
            # Collect samples
            for img, label, pred in zip(imgs, labels, preds):
                if len(correct_samples) < 5 and label == pred:
                    correct_samples.append((img.cpu(), label.cpu(), pred.cpu()))
                elif len(incorrect_samples) < 5 and label != pred:
                    incorrect_samples.append((img.cpu(), label.cpu(), pred.cpu()))
            
            if len(correct_samples) >= 5 and len(incorrect_samples) >= 5:
                break
    
    # Plot
    plt.figure(figsize=(15, 10))
    
    # Correct predictions
    for i, (img, label, pred) in enumerate(correct_samples):
        # Denormalize image
        img = img.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
        img = img * torch.tensor([0.2023, 0.1994, 0.2010]) + torch.tensor([0.4914, 0.4822, 0.4465])
        img = torch.clip(img, 0, 1)
        
        plt.subplot(2, 5, i+1)
        plt.imshow(img)
        plt.title(f'Correct\nTrue: {class_names[label]}\nPred: {class_names[pred]}', fontsize=10)
        plt.axis('off')
    
    # Incorrect predictions
    for i, (img, label, pred) in enumerate(incorrect_samples):
        img = img.permute(1, 2, 0)
        img = img * torch.tensor([0.2023, 0.1994, 0.2010]) + torch.tensor([0.4914, 0.4822, 0.4465])
        img = torch.clip(img, 0, 1)
        
        plt.subplot(2, 5, i+6)
        plt.imshow(img)
        plt.title(f'Incorrect\nTrue: {class_names[label]}\nPred: {class_names[pred]}', fontsize=10)
        plt.axis('off')
    
    plt.suptitle('ViT Predictions on CIFAR-10 (Top: Correct, Bottom: Incorrect)', fontsize=16)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Sample predictions saved to {save_path}")


def plot_patch_attention(model, val_loader, class_names, device, save_path='attention_visualization.png'):
    """Visualize attention weights between class token and patches"""
    model.eval()
    
    # Get one batch of data
    imgs, labels = next(iter(val_loader))
    imgs, labels = imgs.to(device), labels.to(device)
    
    # Forward pass up to patch embedding
    patch_embed = model.patch_embed(imgs)  # (B, N+1, D)
    
    # Get attention weights from first Transformer layer
    first_encoder_layer = model.encoder_layers[0]
    ln1_output = first_encoder_layer.ln1(patch_embed)
    q = first_encoder_layer.mhsa.q_proj(ln1_output)
    k = first_encoder_layer.mhsa.k_proj(ln1_output)
    
    B, N, D = q.shape
    num_heads = first_encoder_layer.mhsa.num_heads
    head_dim = D // num_heads
    
    # Reshape Q and K for attention calculation
    q = q.view(B, N, num_heads, head_dim).transpose(1, 2)  # (B, H, N, d)
    k = k.view(B, N, num_heads, head_dim).transpose(1, 2)  # (B, H, N, d)
    
    # Compute attention scores
    scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
    attn_weights = torch.softmax(scores, dim=-1)  # (B, H, N, N)
    
    # Take class token (index 0) attention weights for first sample and first head
    sample_idx = 0
    head_idx = 0
    class_token_attn = attn_weights[sample_idx, head_idx, 0, 1:]  # (N,) -> exclude class token itself
    
    # Reshape attention weights to match patch grid
    patch_size = model.patch_embed.patch_size
    img_size = model.patch_embed.img_size
    grid_size = img_size // patch_size
    attn_grid = class_token_attn.view(grid_size, grid_size).cpu().numpy()
    
    # Get original image (denormalized)
    img = imgs[sample_idx].cpu().permute(1, 2, 0)
    img = img * torch.tensor([0.2023, 0.1994, 0.2010]) + torch.tensor([0.4914, 0.4822, 0.4465])
    img = torch.clip(img, 0, 1).numpy()
    
    # Plot image + attention heatmap
    plt.figure(figsize=(10, 5))
    
    # Original image
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.title(f'Original Image\nClass: {class_names[labels[sample_idx].cpu().item()]}', fontsize=12)
    plt.axis('off')
    
    # Attention heatmap
    plt.subplot(1, 2, 2)
    im = plt.imshow(attn_grid, cmap='hot', interpolation='bilinear')
    plt.colorbar(im, label='Attention Weight')
    plt.title('Class Token Attention Heatmap\n(Patch Importance)', fontsize=12)
    plt.axis('off')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Attention visualization saved to {save_path}")


# ======================== Main Function (One-Click Run) ========================
if __name__ == "__main__":
    # Configuration (adjust based on server resources)
    BATCH_SIZE = 64
    IMG_SIZE = 32
    EPOCHS = 20
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5  # Regularization to prevent overfitting
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Step 1: Load data
    train_loader, val_loader, class_names = get_cifar10_dataloaders(BATCH_SIZE, IMG_SIZE)
    
    # Step 2: Initialize ViT model
    model = VisionTransformer(
        img_size=IMG_SIZE,
        patch_size=4,
        in_ch=3,
        embed_dim=256,
        num_heads=8,
        num_layers=6,
        mlp_dim=1024,
        num_classes=10,
        dropout=0.1
    )
    model.to(DEVICE)
    print(f"ViT model initialized. Number of parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Step 3: Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()  # Classification task
    optimizer = optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )
    # Learning rate scheduler (reduce LR when val loss plateaus)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )
    
    # Step 4: Train model
    history, all_preds, all_labels = train_model(
        model, train_loader, val_loader, criterion, optimizer, scheduler, DEVICE, EPOCHS
    )
    
    # Step 5: Load best model (optional, since we save the final model)
    model.load_state_dict(torch.load('vit_cifar10.pth', map_location=DEVICE))
    
    # Step 6: Generate diverse result visualizations
    plot_training_history(history)
    plot_confusion_matrix(all_labels, all_preds, class_names)
    plot_predictions(model, val_loader, class_names, DEVICE)
    plot_patch_attention(model, val_loader, class_names, DEVICE)
    
    # Step 7: Print classification report
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=class_names))
    
    print("="*50)
    print("All processes completed successfully!")
    print(f"Generated files:")
    print(f"1. Trained model: vit_cifar10.pth")
    print(f"2. Training history: training_history.png")
    print(f"3. Confusion matrix: confusion_matrix.png")
    print(f"4. Sample predictions: predictions.png")
    print(f"5. Attention visualization: attention_visualization.png")
    print("="*50)

3.3 Key Points of the Code (Beginner-Friendly + Server-Ready)

  1. Modular design: ViT is broken into independent modules (PatchEmbedding, MultiHeadSelfAttention, TransformerEncoderLayer, ...), each with a single responsibility, so beginners can study them block by block
  2. Automatic dataset download: no manual CIFAR-10 download; PyTorch fetches and caches it under ./data, so the script runs as-is on a server
  3. English figure labels: all plot labels and titles are in English, avoiding garbled fonts on servers without CJK fonts installed
  4. Diverse result outputs: four kinds of figures (training curves, confusion matrix, sample predictions, attention heatmap) give a well-rounded evaluation of the model
  5. Server-oriented settings:
    • num_workers=0: avoids multi-process DataLoader errors in restricted server environments
    • pin_memory=True: speeds up host-to-GPU data transfer
    • high-resolution figures (dpi=300): results stay sharp when viewed remotely
  6. Overfitting countermeasures: weight decay, dropout, and a learning-rate scheduler keep training stable and reduce overfitting
  7. Detailed comments: every function and key step carries an English comment, so the logic is easy to follow

3.4 How to Run (One Command on a Server)

  1. Save the code as vit_cifar10_classification.py
  2. Log in to the server and change into the directory containing the script
  3. Run: python vit_cifar10_classification.py
  4. Wait for training to finish (20 epochs: roughly 30 minutes on a single GPU, about 2 hours on CPU, depending on your hardware)
  5. Inspect the five generated files:
    • vit_cifar10.pth: trained ViT model weights
    • training_history.png: training/validation loss and accuracy curves
    • confusion_matrix.png: confusion matrix showing per-class performance
    • predictions.png: correct vs. incorrect prediction samples (5 of each)
    • attention_visualization.png: attention heatmap showing which image regions ViT attends to

3.5 Interpreting the Results

(1) Training curves (training_history.png)
  • Left: training loss (blue) and validation loss (red dashed); if both decrease steadily and stay close together, the model is converging well
  • Right: training and validation accuracy; with these settings this from-scratch ViT typically reaches roughly 70-85% validation accuracy on CIFAR-10 (depending on epochs and augmentation), well above a very basic CNN such as LeNet (around 60%)
(2) Confusion matrix (confusion_matrix.png)
  • Rows: true classes; columns: predicted classes
  • Larger diagonal values mean better classification for that class
  • Off-diagonal values count confused samples (e.g. "cat" and "dog" are easily confused, so those cells tend to be larger)
(3) Sample predictions (predictions.png)
  • Top row: 5 correctly predicted samples with true and predicted labels; ViT recognizes most images correctly
  • Bottom row: 5 incorrect predictions, which expose the model's weak spots (visually similar classes get mixed up)
(4) Attention heatmap (attention_visualization.png)
  • Left: the original image (e.g. an airplane)
  • Right: the attention heatmap; warmer colors mark patches that contribute more to the classification (ViT attends to key regions such as the fuselage and wings)
  • This illustrates a core strength of ViT: it learns to focus on the most discriminative regions of the image rather than relying only on local convolutional filters

3.6 Next Steps for Beginners

  1. Tune hyperparameters:
    • increasing num_layers=12 (encoder depth) or embed_dim=512 (embedding dimension) may improve accuracy at the cost of longer training
    • increasing patch_size=8 (larger patches) reduces the number of patches and speeds up training
  2. Swap in another dataset: replace CIFAR-10 with a custom dataset (e.g. pet or fruit classification) by modifying only the get_cifar10_dataloaders function
  3. Make the model lighter: reduce embed_dim=128 and num_heads=4 to run on CPUs or low-end servers
  4. Fine-tune pretrained weights: load an ImageNet-pretrained ViT (requires the timm library) and reach much higher accuracy within a few epochs; see the sketch below
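
For suggestion 4, here is a minimal sketch of loading an ImageNet-pretrained ViT via timm and adapting it to CIFAR-10 (assumes `pip install timm`; the model tag is timm's, and the data pipeline and training loop from Section 3.2 are reused with IMG_SIZE set to 224):

```python
import timm
import torch.nn as nn
import torch.optim as optim

# Pretrained ViT-Base/16; timm replaces the classification head with 10 outputs
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=10)

# Note: this pretrained ViT expects 224x224 inputs, so set IMG_SIZE = 224 in the
# dataloaders from Section 3.2 (transforms.Resize((224, 224))).
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
# ...then reuse train_model(...) from Section 3.2 with a small number of epochs
```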

4. Summary

The core breakthrough of the Vision Transformer is replacing the CNN's local convolutions with the Transformer's global attention, ending the field's exclusive reliance on convolution. This article dissected ViT's core modules from the math down to the code, and the accompanying CIFAR-10 project needs no complicated setup: it runs end to end on a server with one command, so you can quickly experience what Transformers can do on image tasks.

Beyond the implementation itself, the project should leave you with the core pipeline of "patch the image → embed → add positions → model with attention". The same ideas carry over to ViT's descendants (Swin Transformer, MAE, and others), laying a solid foundation for studying more advanced vision Transformers.

