Saving a code snippet (PyTorch code for a simplified dynamic sparse Vision Transformer)
import torch
import torch.nn as nn
import torch.nn.functional as F


class DynamicSparseAttention(nn.Module):
    """Multi-head self-attention (this simplified version computes dense attention)."""

    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        # Project to q, k, v: (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # Scaled dot-product attention over all tokens
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class HierarchicalRoutingBlock(nn.Module):
    def __init__(self, dim, num_heads=8, mlp_ratio=4., dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = DynamicSparseAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
class DynamicSparseVisionTransformer(nn.Module):
    """Simplified model: patch embedding, positional embedding, stacked routing blocks, classification head."""

    def __init__(self, img_size=224, patch_size=16, num_classes=1000, dim=768, num_heads=8, depth=12, mlp_ratio=4., dropout=0.1):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, (img_size // patch_size) ** 2, dim))
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            HierarchicalRoutingBlock(dim, num_heads, mlp_ratio, dropout) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        # (B, 3, H, W) -> (B, num_patches, dim)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        x = x + self.pos_embed
        x = self.dropout(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # No [CLS] token is used, so pool over all patch tokens instead of taking x[:, 0]
        x = x.mean(dim=1)
        x = self.head(x)
        return x
# Usage
model = DynamicSparseVisionTransformer()
x = torch.randn(1, 3, 224, 224)
output = model(x)
print(output.shape)  # torch.Size([1, 1000])
Code explanation
DynamicSparseAttention: the attention module. In this simplified version it computes standard dense multi-head self-attention; a sketch of a genuinely sparse (top-k) variant is given after this list.
HierarchicalRoutingBlock: the hierarchical routing block, a pre-norm Transformer block combining the attention module and an MLP, each with a residual connection.
DynamicSparseVisionTransformer: the complete model, including patch embedding, positional embedding, the stack of hierarchical routing blocks, and the classification head.
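
For reference, below is a minimal sketch of what a dynamic sparse attention layer could look like, assuming a simple top-k selection rule: each query keeps only its topk highest-scoring keys and masks the rest before the softmax. The class name TopKSparseAttention and the topk parameter are hypothetical (not part of the code above), the sketch reuses the imports at the top of the file, and it only sparsifies the attention weights; the full N x N score matrix is still computed, so by itself it does not reduce compute.

class TopKSparseAttention(nn.Module):
    """Hypothetical sketch: per-query top-k sparse attention (masking only, no compute savings)."""

    def __init__(self, dim, num_heads=8, topk=16, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.topk = topk
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale      # (B, num_heads, N, N)
        k_eff = min(self.topk, N)
        # Keep each query's k_eff highest-scoring keys; mask the rest to -inf before softmax
        topk_vals, _ = attn.topk(k_eff, dim=-1)
        threshold = topk_vals[..., -1:]                    # k-th largest score per query
        attn = attn.masked_fill(attn < threshold, float('-inf'))
        attn = self.attn_drop(attn.softmax(dim=-1))
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

Such a module could be dropped into HierarchicalRoutingBlock in place of DynamicSparseAttention with the same (dim, num_heads, dropout) arguments; a production implementation would instead gather only the selected keys and values so that memory and compute actually shrink.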