1. Why Did ViT Overturn CNNs? From "Convolutional Hegemony" to "Transformer Dominance"
Before 2020, computer vision was effectively a CNN monoculture: from LeNet to ResNet, from MobileNet to EfficientNet, convolution, with its **local receptive fields and parameter sharing**, dominated nearly every task, including image classification, detection, and segmentation. In 2020, Google's Vision Transformer (ViT) broke this pattern: it discards convolution entirely and, using only the Transformer's self-attention mechanism, matched or surpassed CNN performance on ImageNet and other benchmarks, and it inspired powerful successors such as Swin Transformer alongside large variants like ViT-L/16.

ViT's core advantages are:
- **Global modeling**: a CNN must stack many convolutional layers to enlarge its receptive field, whereas ViT captures dependencies across the whole image from the first layer, making it more sensitive to long-range feature relationships
- **Strong transferability**: after pre-training on large-scale datasets, ViT transfers very effectively to other tasks (e.g., medical or remote-sensing imagery), often beating comparable CNNs; without such large-scale pre-training, however, it tends to underperform CNNs on small datasets
- **Parallel-computation friendly**: self-attention reduces to dense matrix multiplications that parallelize well, so server-side training can be faster than for very deep CNNs
- **Flexible structure**: adapting to different image resolutions mostly amounts to adjusting parameters such as the patch size and the number of attention heads

2. ViT Core Principles in Depth (with the Math)
At its core, ViT "splits an image into small patches and processes them with a Transformer encoder". The pipeline is: **image patching → linear embedding → positional encoding → Transformer encoder → classification head**.

2.1 Mathematical Principles of the Core Modules
(1) Patch Embedding: turning an image into "words"
A CNN slides convolution kernels over the image pixel by pixel, whereas ViT's first step is to split the image into fixed-size, non-overlapping patches (much like splitting a sentence into words in NLP).
Given an input image of size $H \times W \times C$ ($H$ = height, $W$ = width, $C$ = channels) and a patch size of $P \times P$:
- Pixels per patch: $P \times P \times C$
- Number of patches: $N = \frac{H \times W}{P \times P}$ ($H$ and $W$ must be divisible by $P$)

For example, with a $224 \times 224 \times 3$ input (the standard ImageNet size) and $16 \times 16$ patches:
- Pixels per patch: $16 \times 16 \times 3 = 768$
- Number of patches: $N = \frac{224 \times 224}{16 \times 16} = 196$

Each flattened patch $\mathbf{x}_i \in \mathbb{R}^{P \times P \times C}$ is then mapped to a $D$-dimensional vector (its "patch embedding") by a linear layer:

$$\text{patch embedding}_i = \text{Linear}_{P \times P \times C \to D}(\mathbf{x}_i)$$

where the weight matrix of the linear layer has shape $D \times (P \times P \times C)$. The result is $N$ vectors of dimension $D$, i.e. a tensor of shape $N \times D$.
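To make these shapes concrete, here is a minimal sketch (not part of the training script in Section 3, which uses an equivalent `Conv2d`; tensor names such as `img` and `patches` are chosen for illustration) that reproduces the 196-patch / 768-dimension numbers above:

```python
import torch
import torch.nn as nn

# Standard ImageNet-sized input: batch of 1, 3 channels, 224x224
img = torch.randn(1, 3, 224, 224)
P, D = 16, 768                                           # patch size and embedding dimension

# Split into non-overlapping 16x16 patches and flatten each one
patches = img.unfold(2, P, P).unfold(3, P, P)            # (1, 3, 14, 14, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).flatten(3)   # (1, 14, 14, 768) per-patch pixels
patches = patches.flatten(1, 2)                          # (1, 196, 768): N = 196 patches

# Linear projection to the embedding dimension D (here 768 -> 768)
proj = nn.Linear(P * P * 3, D)
patch_embeddings = proj(patches)
print(patch_embeddings.shape)                            # torch.Size([1, 196, 768])
```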
(2) Positional Embedding: telling the model where each patch is
Self-attention is **permutation-invariant** (insensitive to the order of the input sequence), but the position of each patch matters: "the cat's head" and "the cat's tail" at different locations carry different semantics. ViT therefore adds a positional embedding to inject position information into every patch.
ViT uses a **learnable** positional embedding (unlike the sinusoidal encoding of the original Transformer). Its shape matches the patch embeddings exactly ($N \times D$):

$$\text{encoded patches} = \text{patch embedding} + \text{positional embedding}$$

where "+" is element-wise addition (with broadcasting). The positional embedding is updated together with the other model parameters during training and learns to capture positional dependencies between patches.
In addition, ViT prepends a special learnable **class token** of shape $1 \times D$ to the sequence; it is used for the final classification. The sequence length then becomes $N + 1$, giving an input of shape $(N+1) \times D$.
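Continuing the shape sketch above (names are again illustrative), prepending the class token and adding the learnable positional embedding looks like this:

```python
import torch
import torch.nn as nn

B, N, D = 1, 196, 768
patch_embeddings = torch.randn(B, N, D)

# Learnable class token (1, 1, D) and positional embedding (1, N+1, D)
class_token = nn.Parameter(torch.zeros(1, 1, D))
pos_embed = nn.Parameter(torch.zeros(1, N + 1, D))

tokens = torch.cat([class_token.expand(B, -1, -1), patch_embeddings], dim=1)  # (1, 197, 768)
encoded = tokens + pos_embed                                                  # broadcast add
print(encoded.shape)                                                          # torch.Size([1, 197, 768])
```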
(3) Transformer Encoder: the "brain" of ViT
The Transformer encoder is the core of ViT. Each layer consists of a **Multi-Head Self-Attention (MHSA)** block and a **Multi-Layer Perceptron (MLP)** block; a **Layer Normalization (LN)** precedes each block, and each block is wrapped in a **residual connection**.
① Multi-Head Self-Attention (MHSA): modeling relationships between patches
Self-attention computes how strongly each patch relates to every other patch (the attention weights) and then aggregates information from all patches according to those weights.

**Step 1: compute Q, K, V.**
The input sequence $X$ (shape $(N+1) \times D$) is passed through three linear layers to produce queries, keys, and values:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

where $W_Q, W_K, W_V$ are weight matrices of shape $D \times D$, so $Q$, $K$, and $V$ all have shape $(N+1) \times D$.

**Step 2: split into heads.**
$Q$, $K$, $V$ are split along the feature dimension into $h$ parallel heads, each of dimension $d_k = D / h$. For example, with $h = 12$ and $D = 768$, each head has $d_k = 64$:

$$Q = [Q_1, Q_2, \dots, Q_h], \quad K = [K_1, K_2, \dots, K_h], \quad V = [V_1, V_2, \dots, V_h]$$

where each $Q_i, K_i, V_i$ has shape $(N+1) \times d_k$.

**Step 3: single-head attention.**
For each head, the attention weights and output are:

$$\text{Attention}(Q_i, K_i, V_i) = \text{Softmax}\!\left( \frac{Q_i K_i^T}{\sqrt{d_k}} \right) V_i$$

where:
- $Q_i K_i^T$ measures the similarity between queries and keys (shape $(N+1) \times (N+1)$)
- $\sqrt{d_k}$ is a scaling factor that keeps the similarities from growing too large, which would push Softmax into regions with vanishing gradients
- $\text{Softmax}$ turns the similarities into attention weights that sum to 1 along each row
- the single-head output has shape $(N+1) \times d_k$

**Step 4: concatenate the heads.**
The $h$ head outputs are concatenated and projected back to dimension $D$ by an output linear layer:

$$\text{MHSA}(X) = \text{Concat}\big(\text{Attention}(Q_1,K_1,V_1), \dots, \text{Attention}(Q_h,K_h,V_h)\big)\,W_O$$

where $W_O$ has shape $D \times D$, so the MHSA output has shape $(N+1) \times D$.
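As a sanity check on these shapes, here is a minimal single-pass sketch of multi-head attention using plain tensor operations (a simplified stand-in for the `MultiHeadSelfAttention` class in Section 3; biases and dropout are omitted and the weights are random):

```python
import torch

B, N1, D, h = 1, 197, 768, 12        # batch, N+1 tokens, embedding dim, heads
d_k = D // h                         # 64 per head
X = torch.randn(B, N1, D)

W_q, W_k, W_v, W_o = (torch.randn(D, D) for _ in range(4))

Q = (X @ W_q).view(B, N1, h, d_k).transpose(1, 2)    # (1, 12, 197, 64)
K = (X @ W_k).view(B, N1, h, d_k).transpose(1, 2)
V = (X @ W_v).view(B, N1, h, d_k).transpose(1, 2)

scores = Q @ K.transpose(-2, -1) / d_k ** 0.5        # (1, 12, 197, 197)
weights = scores.softmax(dim=-1)                     # each row sums to 1
heads = weights @ V                                  # (1, 12, 197, 64)
out = heads.transpose(1, 2).reshape(B, N1, D) @ W_o  # concat heads -> (1, 197, 768)
print(out.shape)                                     # torch.Size([1, 197, 768])
```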
② Multi-Layer Perceptron (MLP): adding non-linear capacity
The MHSA output is passed through a two-layer MLP:

$$\text{MLP}(X) = \text{GELU}(XW_1 + b_1)\,W_2 + b_2$$

where:
- $W_1, b_1$: first linear layer (input $D$, output $4D$; the hidden dimension is conventionally expanded 4×)
- $\text{GELU}$: activation function, $\text{GELU}(x) = x \cdot \Phi(x)$, with $\Phi(x)$ the CDF of the standard normal distribution
- $W_2, b_2$: second linear layer (input $4D$, output $D$)
- the MLP output keeps the shape $(N+1) \times D$
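A minimal sketch of the same block (GELU between two linear layers with the 4× hidden expansion; names are illustrative and dropout is omitted):

```python
import torch
import torch.nn as nn

D = 768
mlp = nn.Sequential(
    nn.Linear(D, 4 * D),  # expand D -> 4D
    nn.GELU(),
    nn.Linear(4 * D, D),  # project back 4D -> D
)
x = torch.randn(1, 197, D)
print(mlp(x).shape)       # torch.Size([1, 197, 768])
```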
③ Full encoder layer
ViT uses the pre-norm arrangement (LayerNorm applied before each block), which is also what the code in Section 3 implements:

$$X' = X + \text{MHSA}(\text{LN}(X))$$
$$X'' = X' + \text{MLP}(\text{LN}(X'))$$

where $X$ is the layer input and $X''$ the layer output. The residual connections ($X + \dots$) keep training stable, and LayerNorm speeds up convergence. ViT stacks $L$ such layers (e.g., $L = 12$) and applies a final LayerNorm, producing a feature sequence of shape $(N+1) \times D$.
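The pre-norm layer translates almost line for line into code. The following compact sketch uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the custom MHSA class defined in Section 3 (which additionally applies dropout):

```python
import torch
import torch.nn as nn

D, h = 768, 12
ln1, ln2 = nn.LayerNorm(D), nn.LayerNorm(D)
mhsa = nn.MultiheadAttention(D, h, batch_first=True)                      # built-in MHSA
mlp = nn.Sequential(nn.Linear(D, 4 * D), nn.GELU(), nn.Linear(4 * D, D))

X = torch.randn(1, 197, D)
Xn = ln1(X)
X1 = X + mhsa(Xn, Xn, Xn, need_weights=False)[0]   # pre-norm: residual + attention
X2 = X1 + mlp(ln2(X1))                             # pre-norm: residual + MLP
print(X2.shape)                                    # torch.Size([1, 197, 768])
```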

(4) Classification head: producing class probabilities
The feature of the first token (the **class token**) in the encoder output, of shape $1 \times D$, is mapped to the number of classes $C$ by a linear layer, and Softmax turns the logits into class probabilities:

$$\text{logits} = W_{\text{cls}}\,\mathbf{x}_{\text{class token}} + b_{\text{cls}}$$
$$\text{prob} = \text{Softmax}(\text{logits})$$

where $W_{\text{cls}}$ is the classification weight matrix of shape $C \times D$ (acting on the class-token feature as a column vector) and $\text{prob}$ is the final class-probability vector of shape $1 \times C$.
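As a shape sketch of the head (the script in Section 3 uses $C = 10$ for CIFAR-10; here $C = 1000$ matches the ImageNet numbers above, and the names are illustrative):

```python
import torch
import torch.nn as nn

D, C = 768, 1000
encoder_out = torch.randn(1, 197, D)       # (N+1, D) encoder output for one image
cls_feature = encoder_out[:, 0, :]         # take the class token -> (1, 768)
head = nn.Linear(D, C)
prob = head(cls_feature).softmax(dim=-1)   # (1, 1000), each row sums to 1
print(prob.shape, prob.sum().item())       # torch.Size([1, 1000]) ~1.0
```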
2.2 ViT Architecture and Tensor Shapes (ViT-Base as an Example)
| Module | Input shape | Operation | Output shape | Key parameters (ViT-Base) |
|---|---|---|---|---|
| Image input | (3, 224, 224) | Standard ImageNet image (C=3, H=224, W=224) | (3, 224, 224) | - |
| Patch embedding | (3, 224, 224) | Split into 16×16 patches + linear layer (768→768) | (196, 768) | patch_size=16, D=768 |
| Add class token | (196, 768) | Prepend 1 learnable token | (197, 768) | - |
| Positional embedding | (197, 768) | Learnable positional embedding (element-wise addition) | (197, 768) | pos_emb_shape=(197, 768) |
| Transformer encoder (L=12 layers) | (197, 768) | Each layer: MHSA + MLP + LN + residual connections | (197, 768) | L=12, h=12, d_k=64 |
| Classification head | (1, 768) | Take class token + linear layer (768→1000) + Softmax | (1, 1000) | C=1000 classes (ImageNet) |
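As a rough cross-check of the table, the ViT-Base parameter count can be estimated directly from these shapes. This is a back-of-the-envelope sketch (the breakdown is an assumption based on the shapes above), and it lands near the commonly quoted figure of roughly 86M parameters:

```python
D, L, C, N = 768, 12, 1000, 196
patch_embed = (16 * 16 * 3) * D + D          # linear projection + bias
cls_and_pos = D + (N + 1) * D                # class token + positional embedding
per_layer = (
    4 * (D * D + D)                          # Q, K, V, output projections (+ biases)
    + (D * 4 * D + 4 * D) + (4 * D * D + D)  # MLP: D -> 4D and 4D -> D (+ biases)
    + 2 * 2 * D                              # two LayerNorms (weight + bias each)
)
head = D * C + C                             # classification head
final_ln = 2 * D
total = patch_embed + cls_and_pos + L * per_layer + head + final_ln
print(f"{total:,}")                          # about 86.6 million parameters
```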
3. Beginner-Friendly ViT in Practice: CIFAR-10 Image Classification (Runs Directly on a Server)
This hands-on project uses the CIFAR-10 dataset (10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck). No manual download is needed (PyTorch fetches it automatically), all plot labels are in English, and the script is server-friendly: a single command produces all the results.
3.1 Environment Setup (any server)
```bash
# Install dependencies (Python 3.7+, PyTorch 1.8+)
pip install torch torchvision numpy matplotlib tqdm scikit-learn seaborn
```
3.2 Full Code
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
# ======================== ViT Model Definition ========================
class PatchEmbedding(nn.Module):
"""
Convert image to patch embeddings + class token + positional embedding
Input: (B, C, H, W)
Output: (B, N+1, D) where N = (H*W)/(P*P), D = embedding dimension
"""
def __init__(self, img_size=32, patch_size=4, in_ch=3, embed_dim=256):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
# Calculate number of patches
self.num_patches = (img_size // patch_size) ** 2
# Patch embedding: linear layer (implemented via Conv2d for efficiency)
self.patch_embed = nn.Conv2d(
in_channels=in_ch,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size # No overlap between patches
)
# Class token: (1, 1, D) -> expand to (B, 1, D) in forward
self.class_token = nn.Parameter(torch.randn(1, 1, embed_dim))
# Positional embedding: (1, N+1, D) -> learnable
self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
def forward(self, x):
B = x.shape[0] # Batch size
# Step 1: Patch embedding (B, C, H, W) -> (B, D, N^(1/2), N^(1/2))
x = self.patch_embed(x) # (B, 256, 8, 8) for img_size=32, patch_size=4
# Step 2: Flatten patches (B, D, N^(1/2), N^(1/2)) -> (B, D, N)
x = x.flatten(2) # (B, 256, 64)
# Step 3: Transpose to (B, N, D)
x = x.transpose(1, 2) # (B, 64, 256)
# Step 4: Add class token (B, 1, D)
class_token = self.class_token.expand(B, -1, -1) # (B, 1, 256)
x = torch.cat([class_token, x], dim=1) # (B, 65, 256)
# Step 5: Add positional embedding
x = x + self.pos_embed # (B, 65, 256)
return x
class MultiHeadSelfAttention(nn.Module):
"""
Multi-Head Self-Attention (MHSA) module
Input: (B, N, D)
Output: (B, N, D)
"""
def __init__(self, embed_dim=256, num_heads=8, dropout=0.1):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads # Dimension per head
# Ensure embed_dim is divisible by num_heads
assert self.head_dim * num_heads == embed_dim, "Embed dim must be divisible by num heads"
# Linear layers for Q, K, V
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
# Output linear layer
self.out_proj = nn.Linear(embed_dim, embed_dim)
# Dropout layer
self.dropout = nn.Dropout(dropout)
def forward(self, x):
B, N, D = x.shape # (B, 65, 256)
# Step 1: Compute Q, K, V (B, N, D)
q = self.q_proj(x) # (B, 65, 256)
k = self.k_proj(x) # (B, 65, 256)
v = self.v_proj(x) # (B, 65, 256)
# Step 2: Split into multiple heads (B, num_heads, N, head_dim)
q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) # (B, 8, 65, 32)
k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) # (B, 8, 65, 32)
v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) # (B, 8, 65, 32)
# Step 3: Compute attention scores (B, num_heads, N, N)
scores = torch.matmul(q, k.transpose(-2, -1)) # (B, 8, 65, 65)
scores = scores / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)) # Scale
# Step 4: Softmax to get attention weights (B, num_heads, N, N)
attn_weights = torch.softmax(scores, dim=-1) # (B, 8, 65, 65)
attn_weights = self.dropout(attn_weights)
# Step 5: Compute weighted sum of V (B, num_heads, N, head_dim)
attn_output = torch.matmul(attn_weights, v) # (B, 8, 65, 32)
# Step 6: Concatenate heads (B, N, D)
attn_output = attn_output.transpose(1, 2).contiguous() # (B, 65, 8, 32)
attn_output = attn_output.view(B, N, D) # (B, 65, 256)
# Step 7: Linear projection
output = self.out_proj(attn_output) # (B, 65, 256)
output = self.dropout(output)
return output
class MLP(nn.Module):
"""
Multi-Layer Perceptron for Transformer encoder
Input: (B, N, D)
Output: (B, N, D)
"""
def __init__(self, embed_dim=256, mlp_dim=1024, dropout=0.1):
super().__init__()
self.fc1 = nn.Linear(embed_dim, mlp_dim) # Expand to 4x dimension
self.fc2 = nn.Linear(mlp_dim, embed_dim) # Project back
self.dropout = nn.Dropout(dropout)
self.gelu = nn.GELU() # Activation function
def forward(self, x):
x = self.fc1(x) # (B, 65, 256) -> (B, 65, 1024)
x = self.gelu(x)
x = self.dropout(x)
x = self.fc2(x) # (B, 65, 1024) -> (B, 65, 256)
x = self.dropout(x)
return x
class TransformerEncoderLayer(nn.Module):
"""
Single layer of Transformer encoder
Input: (B, N, D)
Output: (B, N, D)
"""
def __init__(self, embed_dim=256, num_heads=8, mlp_dim=1024, dropout=0.1):
super().__init__()
# Layer normalization before MHSA
self.ln1 = nn.LayerNorm(embed_dim)
# Multi-Head Self-Attention
self.mhsa = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
# Layer normalization before MLP
self.ln2 = nn.LayerNorm(embed_dim)
# MLP block
self.mlp = MLP(embed_dim, mlp_dim, dropout)
def forward(self, x):
# Residual connection + MHSA
x = x + self.mhsa(self.ln1(x)) # (B, 65, 256)
# Residual connection + MLP
x = x + self.mlp(self.ln2(x)) # (B, 65, 256)
return x
class VisionTransformer(nn.Module):
"""
Full Vision Transformer model for image classification
Input: (B, C, H, W)
Output: (B, num_classes)
"""
def __init__(
self,
img_size=32,
patch_size=4,
in_ch=3,
embed_dim=256,
num_heads=8,
num_layers=6,
mlp_dim=1024,
num_classes=10,
dropout=0.1
):
super().__init__()
# Patch embedding + class token + positional embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_ch, embed_dim)
# Transformer encoder (stack multiple layers)
self.encoder_layers = nn.ModuleList([
TransformerEncoderLayer(embed_dim, num_heads, mlp_dim, dropout)
for _ in range(num_layers)
])
# Layer normalization for encoder output
self.ln = nn.LayerNorm(embed_dim)
# Classification head (linear layer)
self.classifier = nn.Linear(embed_dim, num_classes)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, m):
"""Initialize model weights (improve training stability)"""
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
def forward(self, x):
# Step 1: Patch embedding + positional encoding (B, C, H, W) -> (B, N+1, D)
x = self.patch_embed(x) # (B, 65, 256)
# Step 2: Pass through Transformer encoder layers
for layer in self.encoder_layers:
x = layer(x) # (B, 65, 256)
# Step 3: Layer normalization
x = self.ln(x) # (B, 65, 256)
# Step 4: Extract class token feature (B, D)
class_token_feature = x[:, 0, :] # (B, 256)
# Step 5: Classification head (B, num_classes)
logits = self.classifier(class_token_feature) # (B, 10)
return logits
# ======================== Data Preparation ========================
def get_cifar10_dataloaders(batch_size=64, img_size=32):
"""
Load CIFAR-10 dataset with data augmentation (for training)
Returns: train_loader, val_loader, class_names
"""
# Data transforms (server-friendly, no random crop for stability)
train_transform = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.RandomHorizontalFlip(p=0.5), # Data augmentation
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], # CIFAR-10 stats
std=[0.2023, 0.1994, 0.2010])
])
val_transform = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010])
])
# Download CIFAR-10 dataset (auto-download if not exists)
train_dataset = datasets.CIFAR10(
root='./data', train=True, download=True, transform=train_transform
)
val_dataset = datasets.CIFAR10(
root='./data', train=False, download=True, transform=val_transform
)
# Create data loaders
train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True
)
val_loader = DataLoader(
val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True
)
# CIFAR-10 class names
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")
return train_loader, val_loader, class_names
# ======================== Training & Validation Functions ========================
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
"""Train model for one epoch"""
model.train()
total_loss = 0.0
total_correct = 0
total_samples = 0
pbar = tqdm(train_loader, desc=f'Epoch {epoch+1} [Train]')
for imgs, labels in pbar:
# Move data to device (GPU/CPU)
imgs, labels = imgs.to(device), labels.to(device)
# Forward pass
outputs = model(imgs)
loss = criterion(outputs, labels)
# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Calculate metrics
total_loss += loss.item() * imgs.size(0)
_, preds = torch.max(outputs, 1)
total_correct += (preds == labels).sum().item()
total_samples += imgs.size(0)
# Update progress bar
avg_loss = total_loss / total_samples
avg_acc = total_correct / total_samples
pbar.set_postfix({'Loss': f'{avg_loss:.4f}', 'Acc': f'{avg_acc:.4f}'})
return avg_loss, avg_acc
def validate(model, val_loader, criterion, device):
"""Validate model on validation set"""
model.eval()
total_loss = 0.0
total_correct = 0
total_samples = 0
all_preds = []
all_labels = []
with torch.no_grad():
pbar = tqdm(val_loader, desc='[Validation]')
for imgs, labels in pbar:
imgs, labels = imgs.to(device), labels.to(device)
# Forward pass
outputs = model(imgs)
loss = criterion(outputs, labels)
# Calculate metrics
total_loss += loss.item() * imgs.size(0)
_, preds = torch.max(outputs, 1)
total_correct += (preds == labels).sum().item()
total_samples += imgs.size(0)
# Collect preds and labels for confusion matrix
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
# Update progress bar
avg_loss = total_loss / total_samples
avg_acc = total_correct / total_samples
pbar.set_postfix({'Loss': f'{avg_loss:.4f}', 'Acc': f'{avg_acc:.4f}'})
avg_loss = total_loss / total_samples
avg_acc = total_correct / total_samples
return avg_loss, avg_acc, all_preds, all_labels
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs=20):
"""Full training pipeline"""
# Record training history
history = {
'train_loss': [], 'train_acc': [],
'val_loss': [], 'val_acc': []
}
print(f"Training started (Device: {device})")
for epoch in range(epochs):
# Train one epoch
train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
# Validate
val_loss, val_acc, all_preds, all_labels = validate(model, val_loader, criterion, device)
# Update scheduler
scheduler.step(val_loss)
# Save history
history['train_loss'].append(train_loss)
history['train_acc'].append(train_acc)
history['val_loss'].append(val_loss)
history['val_acc'].append(val_acc)
# Print epoch summary
print(f'Epoch [{epoch+1}/{epochs}] | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | '
f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}')
# Save trained model
torch.save(model.state_dict(), 'vit_cifar10.pth')
print(f"Model saved as 'vit_cifar10.pth'")
return history, all_preds, all_labels
# ======================== Result Visualization (Server Compatible) ========================
def plot_training_history(history, save_path='training_history.png'):
"""Plot training/validation loss and accuracy curves"""
plt.figure(figsize=(12, 4))
# Loss curve
plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train Loss', linewidth=2, color='blue')
plt.plot(history['val_loss'], label='Val Loss', linewidth=2, color='red', linestyle='--')
plt.xlabel('Epoch', fontsize=10)
plt.ylabel('Loss', fontsize=10)
plt.title('Training & Validation Loss', fontsize=12)
plt.legend()
plt.grid(alpha=0.3)
# Accuracy curve
plt.subplot(1, 2, 2)
plt.plot(history['train_acc'], label='Train Accuracy', linewidth=2, color='blue')
plt.plot(history['val_acc'], label='Val Accuracy', linewidth=2, color='red', linestyle='--')
plt.xlabel('Epoch', fontsize=10)
plt.ylabel('Accuracy', fontsize=10)
plt.title('Training & Validation Accuracy', fontsize=12)
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"Training history saved to {save_path}")
def plot_confusion_matrix(all_labels, all_preds, class_names, save_path='confusion_matrix.png'):
"""Plot confusion matrix for validation set"""
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.title('Confusion Matrix (CIFAR-10 Validation Set)', fontsize=14)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"Confusion matrix saved to {save_path}")
def plot_predictions(model, val_loader, class_names, device, save_path='predictions.png'):
"""Plot sample predictions (correct and incorrect)"""
model.eval()
correct_samples = []
incorrect_samples = []
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
outputs = model(imgs)
_, preds = torch.max(outputs, 1)
# Collect samples
for img, label, pred in zip(imgs, labels, preds):
if len(correct_samples) < 5 and label == pred:
correct_samples.append((img.cpu(), label.cpu(), pred.cpu()))
elif len(incorrect_samples) < 5 and label != pred:
incorrect_samples.append((img.cpu(), label.cpu(), pred.cpu()))
if len(correct_samples) >= 5 and len(incorrect_samples) >= 5:
break
# Plot
plt.figure(figsize=(15, 10))
# Correct predictions
for i, (img, label, pred) in enumerate(correct_samples):
# Denormalize image
img = img.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
img = img * torch.tensor([0.2023, 0.1994, 0.2010]) + torch.tensor([0.4914, 0.4822, 0.4465])
img = torch.clip(img, 0, 1)
plt.subplot(2, 5, i+1)
plt.imshow(img)
plt.title(f'Correct\nTrue: {class_names[label]}\nPred: {class_names[pred]}', fontsize=10)
plt.axis('off')
# Incorrect predictions
for i, (img, label, pred) in enumerate(incorrect_samples):
img = img.permute(1, 2, 0)
img = img * torch.tensor([0.2023, 0.1994, 0.2010]) + torch.tensor([0.4914, 0.4822, 0.4465])
img = torch.clip(img, 0, 1)
plt.subplot(2, 5, i+6)
plt.imshow(img)
plt.title(f'Incorrect\nTrue: {class_names[label]}\nPred: {class_names[pred]}', fontsize=10)
plt.axis('off')
plt.suptitle('ViT Predictions on CIFAR-10 (Top: Correct, Bottom: Incorrect)', fontsize=16)
plt.tight_layout()
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"Sample predictions saved to {save_path}")
def plot_patch_attention(model, val_loader, class_names, device, save_path='attention_visualization.png'):
"""Visualize attention weights between class token and patches"""
model.eval()
# Get one batch of data
imgs, labels = next(iter(val_loader))
imgs, labels = imgs.to(device), labels.to(device)
# Forward pass up to patch embedding
patch_embed = model.patch_embed(imgs) # (B, N+1, D)
# Get attention weights from first Transformer layer
first_encoder_layer = model.encoder_layers[0]
ln1_output = first_encoder_layer.ln1(patch_embed)
q = first_encoder_layer.mhsa.q_proj(ln1_output)
k = first_encoder_layer.mhsa.k_proj(ln1_output)
B, N, D = q.shape
num_heads = first_encoder_layer.mhsa.num_heads
head_dim = D // num_heads
# Reshape Q and K for attention calculation
q = q.view(B, N, num_heads, head_dim).transpose(1, 2) # (B, H, N, d)
k = k.view(B, N, num_heads, head_dim).transpose(1, 2) # (B, H, N, d)
# Compute attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
attn_weights = torch.softmax(scores, dim=-1) # (B, H, N, N)
# Take class token (index 0) attention weights for first sample and first head
sample_idx = 0
head_idx = 0
class_token_attn = attn_weights[sample_idx, head_idx, 0, 1:] # (N,) -> exclude class token itself
# Reshape attention weights to match patch grid
patch_size = model.patch_embed.patch_size
img_size = model.patch_embed.img_size
grid_size = img_size // patch_size
attn_grid = class_token_attn.view(grid_size, grid_size).cpu().numpy()
# Get original image (denormalized)
img = imgs[sample_idx].cpu().permute(1, 2, 0)
img = img * torch.tensor([0.2023, 0.1994, 0.2010]) + torch.tensor([0.4914, 0.4822, 0.4465])
img = torch.clip(img, 0, 1).numpy()
# Plot image + attention heatmap
plt.figure(figsize=(10, 5))
# Original image
plt.subplot(1, 2, 1)
plt.imshow(img)
plt.title(f'Original Image\nClass: {class_names[labels[sample_idx].cpu().item()]}', fontsize=12)
plt.axis('off')
# Attention heatmap
plt.subplot(1, 2, 2)
im = plt.imshow(attn_grid, cmap='hot', interpolation='bilinear')
plt.colorbar(im, label='Attention Weight')
plt.title('Class Token Attention Heatmap\n(Patch Importance)', fontsize=12)
plt.axis('off')
plt.tight_layout()
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"Attention visualization saved to {save_path}")
# ======================== Main Function (One-Click Run) ========================
if __name__ == "__main__":
# Configuration (adjust based on server resources)
BATCH_SIZE = 64
IMG_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5 # Regularization to prevent overfitting
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Step 1: Load data
train_loader, val_loader, class_names = get_cifar10_dataloaders(BATCH_SIZE, IMG_SIZE)
# Step 2: Initialize ViT model
model = VisionTransformer(
img_size=IMG_SIZE,
patch_size=4,
in_ch=3,
embed_dim=256,
num_heads=8,
num_layers=6,
mlp_dim=1024,
num_classes=10,
dropout=0.1
)
model.to(DEVICE)
print(f"ViT model initialized. Number of parameters: {sum(p.numel() for p in model.parameters()):,}")
# Step 3: Define loss function and optimizer
criterion = nn.CrossEntropyLoss() # Classification task
optimizer = optim.AdamW(
model.parameters(),
lr=LEARNING_RATE,
weight_decay=WEIGHT_DECAY
)
# Learning rate scheduler (reduce LR when val loss plateaus)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=3, verbose=True
)
# Step 4: Train model
history, all_preds, all_labels = train_model(
model, train_loader, val_loader, criterion, optimizer, scheduler, DEVICE, EPOCHS
)
# Step 5: Load best model (optional, since we save the final model)
model.load_state_dict(torch.load('vit_cifar10.pth', map_location=DEVICE))
# Step 6: Generate diverse result visualizations
plot_training_history(history)
plot_confusion_matrix(all_labels, all_preds, class_names)
plot_predictions(model, val_loader, class_names, DEVICE)
plot_patch_attention(model, val_loader, class_names, DEVICE)
# Step 7: Print classification report
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=class_names))
print("="*50)
print("All processes completed successfully!")
print(f"Generated files:")
print(f"1. Trained model: vit_cifar10.pth")
print(f"2. Training history: training_history.png")
print(f"3. Confusion matrix: confusion_matrix.png")
print(f"4. Sample predictions: predictions.png")
print(f"5. Attention visualization: attention_visualization.png")
print("="*50)
```

3.3 Code Highlights (beginner-friendly + server-ready)
- **Modular design**: ViT is decomposed into independent modules (`PatchEmbedding`, `MultiHeadSelfAttention`, `TransformerEncoderLayer`, ...), each with a single responsibility, so beginners can study them one block at a time
- **Automatic dataset download**: no manual download of CIFAR-10; PyTorch downloads and caches it under `./data`, so the script runs directly on a server
- **English figure labels**: all plot labels and titles are in English, avoiding garbled fonts on servers without CJK fonts
- **Diverse result outputs**: four kinds of result figures (training curves, confusion matrix, prediction samples, attention heatmap) give an all-round view of the model
- **Server-oriented settings**:
  - `num_workers=0`: avoids multi-process data-loading errors in some server environments
  - `pin_memory=True`: speeds up host-to-GPU data transfers
  - high-resolution output (`dpi=300`): figures stay sharp when viewed remotely
- **Anti-overfitting measures**: weight decay, dropout, and a learning-rate scheduler keep training stable
- **Detailed comments**: every function and key step is commented in English, so beginners can follow the logic
3.4 How to Run (one command on the server)
1. Save the code as `vit_cifar10_classification.py`
2. Log in to the server and change into the directory containing the script
3. Run: `python vit_cifar10_classification.py`
4. Wait for training to finish (20 epochs: roughly 30 minutes on a GPU, around 2 hours on CPU; GPU acceleration makes a big difference)
5. Inspect the five generated files:
   - `vit_cifar10.pth`: trained ViT model weights
   - `training_history.png`: training/validation loss and accuracy curves
   - `confusion_matrix.png`: confusion matrix (per-class classification performance)
   - `predictions.png`: correct vs. incorrect prediction samples (5 of each)
   - `attention_visualization.png`: attention heatmap (which image regions ViT focuses on)
3.5 Reading the Results (beginner-friendly)
(1) Training curves (training_history.png)
- Left: training loss (blue) and validation loss (red dashed); if both decrease steadily and stay close together, the model is converging well
- Right: training and validation accuracy (same color scheme); this ViT can reach roughly 85%+ validation accuracy on CIFAR-10, far above a basic CNN such as LeNet (around 60%)

(2) Confusion matrix (confusion_matrix.png)
- Rows: true classes; columns: predicted classes
- Larger values on the diagonal mean better performance for that class
- Off-diagonal values count confused samples (e.g., "cat" and "dog" are easily confused, so the corresponding cell tends to be large)

(3) Prediction samples (predictions.png)
- Top row: 5 correctly classified samples with their true and predicted labels; ViT recognizes most images accurately
- Bottom row: 5 misclassified samples, which expose the model's weak spots (visually similar classes are confused most often)

(4) Attention heatmap (attention_visualization.png)
- Left: the original image (e.g., an airplane)
- Right: the attention heatmap; warmer colors mark patches that contribute more to the classification (e.g., ViT attends to the fuselage and wings of the airplane)
- This illustrates a core strength of ViT: it learns to focus on the most discriminative regions of the image globally, rather than building up features only from local neighborhoods as convolutions do
3.6 Next Steps for Beginners
- **Tune the hyperparameters** (see the sketch after this list):
  - Increase `num_layers=12` (encoder depth) or `embed_dim=512` (embedding dimension); accuracy may improve at the cost of longer training
  - Use a larger `patch_size=8` to reduce the number of patches and speed up training
- **Switch datasets**: replace CIFAR-10 with a custom dataset (e.g., pet or fruit classification) by modifying only the `get_cifar10_dataloaders` function
- **Shrink the model**: reduce `embed_dim=128` and `num_heads=4` for CPU-only or low-spec servers
- **Fine-tune pretrained weights**: load ImageNet-pretrained ViT weights (requires installing the `timm` library); a few epochs of fine-tuning usually reach much higher accuracy
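For example, the `VisionTransformer` class defined above can be re-instantiated with a larger configuration, and (if the `timm` package is installed) an ImageNet-pretrained ViT can be adapted to CIFAR-10 in a few lines. This is a sketch under those assumptions, not part of the main script:

```python
# Option 1: a bigger from-scratch model using the VisionTransformer class defined above
model = VisionTransformer(
    img_size=32, patch_size=8,   # larger patches -> fewer tokens, faster training
    embed_dim=512, num_heads=8, num_layers=12, mlp_dim=2048,
    num_classes=10, dropout=0.1,
)

# Option 2: fine-tune an ImageNet-pretrained ViT with timm (pip install timm)
import timm
pretrained = timm.create_model(
    "vit_base_patch16_224",      # ViT-Base/16 pretrained weights
    pretrained=True,
    num_classes=10,              # replace the 1000-class head with a 10-class head
)
# Note: CIFAR-10 images must be resized to 224x224 in the transforms before using this model.
```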
4. Conclusion
Vision Transformer's key breakthrough is replacing the local convolutions of CNNs with the Transformer's global attention, ending the assumption that vision models must be built on convolution. This article walked through ViT's core modules from the math down to the code, and the accompanying CIFAR-10 project needs no complicated setup: even a beginner can run it on a server with one command and experience the power of Transformers on image tasks.
Beyond the implementation itself, the project teaches the core pipeline of "patching → embedding → positional encoding → attention-based modeling". The same ideas carry over to ViT's descendants (Swin Transformer, MAE, and others) and lay a solid foundation for studying more advanced vision Transformers.