视觉Transformer实战 | Pooling-based Vision Transformer(PiT)详解与实现

视觉Transformer实战 | Pooling-based Vision Transformer(PiT)详解与实现

    • [0. 前言](#0. 前言)
    • [1. PiT 技术原理](#1. PiT 技术原理)
      • [1.1 核心思想](#1.1 核心思想)
      • [1.2 与传统 ViT 对比](#1.2 与传统 ViT 对比)
    • [2. PiT 网络架构](#2. PiT 网络架构)
    • [3. 使用 PyTorch 实现 PiT](#3. 使用 PyTorch 实现 PiT)
      • [3.1 模型构建](#3.1 模型构建)
      • [3.2 模型训练](#3.2 模型训练)
    • 相关链接

0. 前言

Vision Transformer (ViT)在计算机视觉领域取得了巨大成功,但标准的 ViT 架构在处理不同尺度的视觉特征时存在一定局限性。Pooling-based Vision Transformer (PiT) 通过引入池化操作来改进 ViT 架构,使其能够更有效地处理多尺度特征,同时减少计算复杂度。本节将详细介绍 PiT 的技术原理,并使用 PyTorch 从零开始实现 PiT 模型。

1. PiT 技术原理

1.1 核心思想

PiT 的核心创新点在于将卷积神经网络 (Convolutional Neural Network, CNN) 的金字塔结构设计思想引入 Pooling-based Vision Transformer (PiT) ,通过动态调整空间分辨率和通道维度,实现更高效的多尺度特征提取。PiT 架构如下所示,其核心思想具体体现在以下三个方面:

  • 空间池化层 (Pooling Layer) 的引入:传统 ViTtoken 数量在全程保持固定(即图像分割后的 patch 数量不变),而 PiTTransformer 块之间插入空间池化层,逐步减少 token 数量(即降低空间分辨率)。例如,输入图像经 patch embedding 后生成 14×14token 序列,通过池化可逐步降为 7×74×4
  • 通道维度的动态扩展:随着空间分辨率的降低,PiT 逐步增加每个 token 的通道维度(特征深度),形成类似 CNN 的"金字塔结构"(如 ResNet 的通道数随层数增加),这种设计平衡了计算开销与特征表达能力
  • 多尺度特征融合:通过分层降低空间分辨率,PiT 在不同尺度下捕获特征,浅层保留细节信息(高分辨率),深层提取语义信息(低分辨率),这种结构与 CNN 的层次化特征提取机制一致,更适合视觉任务

1.2 与传统 ViT 对比

ViT 模型使用固定 token 数量,缺乏多尺度建模能力,可能丢失局部细节。ViT 在第一个嵌入层将图像按块 (patch) 划分,并将其嵌入到 token 中。该结构不包括空间缩减层,并且在网络的整个层中保持相同数量的空间 token。虽然自注意操作不受空间距离的限制,但参与注意的空间区域的大小受特征的空间大小的影响。
PiT 模型通过池化实现层级结构,更高效地处理不同尺度的视觉模式。由于 ViT2D 矩阵而不是 3D 张量的形式处理神经元响应,因此池化层应该分离空间 token 并将它们重塑为具有空间结构的 3D 张量。在整形之后,通过深度卷积来执行空间大小减小和通道增加。

2. PiT 网络架构

PiT 的完整架构由以下关键组件构成:

  • Patch Embedding:输入图像被分割为固定大小的非重叠 patch (如 16×16 像素),每个 patch 通过线性投影(全连接层)映射为 token,初始通道维度为 C 1 C_1 C1,例如:224×224 图像 → 14×14patch196token,每个 token 维度为 C 1 C_1 C1
  • Pooling Transformer Block,每个 Block 包含两个核心操作:Transformer 层使用多头自注意力 (Multi-head Self Attention, MSA) 和 MLP 层,结构与标准 ViT 一致,但增加了深度可分离卷积;空间池化层 (Pooling) 在特定阶段对 token 序列进行空间池化。假设当前 token 排列为 H × W × C H×W×C H×W×C,池化窗口为 k × k k×k k×k,则输出分辨率降为 H k × W k \frac Hk×\frac Wk kH×kW,通道数扩展至 k 2 C k^2C k2C (通过调整 MLP 实现),池化操作通常采用平均池化或最大池化,论文中采用深度可分离卷积 (Depth-wise Convolution) 实现,兼顾位置信息保留与计算效率
  • Depth-wise Convolution 的位置编码:PiT 摒弃了 ViT 的固定位置编码,改用深度可分离卷积 (3×3 卷积,分组数为通道数)隐式编码位置信息,该卷积应用于每个 Transformer 块的 MLP 之前,增强局部性建模能力(类似 CNN 的局部感受野)
  • 分类头 (Classifier Head):最终阶段的 token 序列通过全局平均池化 (Global Average Pooling, GAP) 压缩为 1 × 1 × C n 1×1×C_n 1×1×Cn,然后接全连接层输出分类结果

3. 使用 PyTorch 实现 PiT

3.1 模型构建

(1) 首先,导入所需库:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

(2) 实现深度可分离卷积,用于位置编码:

python 复制代码
class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, 
                      groups=dim_in, stride=stride, bias=bias),
            nn.BatchNorm2d(dim_in),
            nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
        )
    
    def forward(self, x):
        return self.net(x)

(3) 实现 Transformer 中的前馈网络:

python 复制代码
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

(4) 实现多头注意力机制:

python 复制代码
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # 生成查询、键、值
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        
        # 计算注意力分数
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        
        # 注意力权重
        attn = self.attend(dots)
        attn = self.dropout(attn)
        
        # 应用注意力权重到值上
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

(5) 使用注意力机制和前馈网络,实现完整的 Transformer 块:

python 复制代码
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                nn.LayerNorm(dim),
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                nn.LayerNorm(dim),
                FeedForward(dim, mlp_dim, dropout=dropout)
            ]))
    
    def forward(self, x):
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x))
            x = x + ff(norm2(x))
        return x

(6) 构建带池化操作的 Transformer 块:

python 复制代码
class PoolingTransformer(nn.Module):
    def __init__(self, dim, dim_out, pool_size=3, stride=2, padding=1):
        super().__init__()
        self.pool = DepthWiseConv2d(dim, dim_out, pool_size, padding, stride)
        self.norm = nn.LayerNorm(dim_out)
    
    def forward(self, x):
        # x的形状: (batch_size, num_tokens, dim)
        # 转换为2D图像形式以应用池化
        h = w = int(x.shape[1] ** 0.5)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        
        # 应用池化
        x = self.pool(x)
        
        # 转换回序列形式
        h, w = x.shape[-2:]
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.norm(x)
        return x

(7) 构建基于池化的 Vision Transformer 完整模型:

python 复制代码
class PiT(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, 
                 pool_size=3, stride=2, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        
        # 参数设置
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2
        
        # Patch Embedding
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim, dim),
        )
        
        # 位置编码
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.dropout = nn.Dropout(emb_dropout)
        
        # 计算各阶段维度
        dim1 = dim
        dim2 = dim1 * (stride ** 2)  # 池化后通道数增加
        dim3 = dim2 * (stride ** 2)  # 再次池化后通道数增加
        
        # Transformer阶段
        self.transformer1 = Transformer(dim1, depth[0], heads[0], dim_head, mlp_dim, dropout)
        self.pooling1 = PoolingTransformer(dim1, dim2, pool_size, stride)
        
        self.transformer2 = Transformer(dim2, depth[1], heads[1], dim_head, mlp_dim, dropout)
        self.pooling2 = PoolingTransformer(dim2, dim3, pool_size, stride)
        
        self.transformer3 = Transformer(dim3, depth[2], heads[2], dim_head, mlp_dim, dropout)
        
        # 分类头
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim3),
            nn.Linear(dim3, num_classes)
        )
    
    def forward(self, img):
        # Patch Embedding
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        
        # 添加位置编码
        x += self.pos_embedding[:, :n]
        x = self.dropout(x)
        
        # Transformer阶段1
        x = self.transformer1(x)
        x = self.pooling1(x)
        
        # Transformer阶段2
        x = self.transformer2(x)
        x = self.pooling2(x)
        
        # Transformer阶段3
        x = self.transformer3(x)
        
        # 全局平均池化并分类
        x = x.mean(dim=1)
        return self.mlp_head(x)

(8) 构建不同规模的 PiT 模型:

python 复制代码
def pit_tiny(num_classes=1000):
    """创建小型PiT模型"""
    return PiT(
        image_size=224,
        patch_size=16,
        num_classes=num_classes,
        dim=256,
        depth=[2, 4, 3],
        heads=[3, 6, 12],
        mlp_dim=512
    )

def pit_small(num_classes=1000):
    """创建中型PiT模型"""
    return PiT(
        image_size=224,
        patch_size=16,
        num_classes=num_classes,
        dim=384,
        depth=[2, 6, 4],
        heads=[6, 12, 24],
        mlp_dim=768
    )

def pit_base(num_classes=1000):
    """创建大型PiT模型"""
    return PiT(
        image_size=224,
        patch_size=16,
        num_classes=num_classes,
        dim=512,
        depth=[3, 6, 4],
        heads=[8, 16, 32],
        mlp_dim=1024
    )

3.2 模型训练

(1) 获取训练和验证数据加载器,本节使用 CIFAR-10 数据集:

python 复制代码
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

def get_dataloaders(batch_size=64):
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    
    transform_val = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True, 
                                           download=True, transform=transform_train)
    val_set = torchvision.datasets.CIFAR10(root='./data', train=False, 
                                         download=True, transform=transform_val)
    
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4)
    
    return train_loader, val_loader

(2) 定义模型训练和评估函数:

python 复制代码
from matplotlib import pyplot as plt

def train_model(model, train_loader, val_loader, epochs=20, lr=0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    best_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        
        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total
        
        # 验证阶段
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)
        
        # 学习率调整
        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        # 打印统计信息
        print(f'Epoch [{epoch+1}/{epochs}], '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        
        # 保存最佳模型
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
    
    print(f'Training finished. Best validation accuracy: {best_acc:.2f}%')
    # 绘制训练曲线
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    
    plt.show()
    return model
    
def evaluate_model(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    val_loss = running_loss / len(val_loader)
    val_acc = 100. * correct / total
    return val_loss, val_acc

(3) 初始化模型,并训练模型:

python 复制代码
# 获取数据加载器
train_loader, val_loader = get_dataloaders(batch_size=64)
# 初始化模型
model = pit_tiny(num_classes=10)  # CIFAR-10有10个类别
# 训练模型
trained_model = train_model(model, train_loader, val_loader, epochs=50, lr=0.0001)

训练过程模型损失和准确率变化情况如下:

相关链接

视觉Transformer实战 | Transformer详解与实现
视觉Transformer实战 | Vision Transformer(ViT)详解与实现
视觉Transformer实战 | Token-to-Token Vision Transformer(T2T-ViT)详解与实现

相关推荐
feifeigo1232 小时前
MATLAB微光图像增强综合实现
开发语言·计算机视觉·matlab
有Li2 小时前
PISCO:用于改进动态MRI神经隐式k空间表示的自监督k空间正则化文献速递-医疗影像分割与目标检测最新技术
论文阅读·深度学习·文献·医学生
Zhuanshan_2 小时前
服务器连接及训练问题
服务器·深度学习
童园管理札记2 小时前
融传统文化于幼儿日常 育根魂少年于启蒙之时
经验分享·深度学习·创业创新·学习方法·微信公众平台
shayudiandian3 小时前
数据增强(Data Augmentation)策略大全
人工智能·深度学习·计算机视觉
这张生成的图像能检测吗3 小时前
(论文速读)GAT:图注意神经网络
人工智能·深度学习·神经网络·图神经网络·注意力机制
Das13 小时前
【计算机视觉】08_识别分类
人工智能·计算机视觉·分类
m0_692457103 小时前
图像添加水印
图像处理·opencv·计算机视觉
AndrewHZ3 小时前
【图像处理基石】VR的眩晕感是如何产生的?
图像处理·算法·计算机视觉·vr·cv·立体视觉·眩晕感