简化的动态稀疏视觉Transformer的PyTorch代码

存一串代码(简化的动态稀疏视觉Transformer的PyTorch代码)

复制代码
import torch 
import torch.nn  as nn 
import torch.nn.functional  as F 
 
class DynamicSparseAttention(nn.Module): 
    def __init__(self, dim, num_heads=8, dropout=0.1): 
        super().__init__() 
        self.num_heads  = num_heads 
        self.head_dim  = dim // num_heads 
        self.scale  = self.head_dim  ** -0.5 
 
        self.qkv  = nn.Linear(dim, dim * 3, bias=False) 
        self.attn_drop  = nn.Dropout(dropout) 
        self.proj  = nn.Linear(dim, dim) 
        self.proj_drop  = nn.Dropout(dropout) 
 
    def forward(self, x): 
        B, N, C = x.shape  
        qkv = self.qkv(x).reshape(B,  N, 3, self.num_heads,  self.head_dim).permute(2,  0, 3, 1, 4) 
        q, k, v = qkv.unbind(0)  
 
        attn = (q @ k.transpose(-2,  -1)) * self.scale  
        attn = attn.softmax(dim=-1)  
        attn = self.attn_drop(attn)  
 
        x = (attn @ v).transpose(1, 2).reshape(B, N, C) 
        x = self.proj(x)  
        x = self.proj_drop(x)  
        return x 
 
class HierarchicalRoutingBlock(nn.Module): 
    def __init__(self, dim, num_heads=8, mlp_ratio=4., dropout=0.1): 
        super().__init__() 
        self.norm1  = nn.LayerNorm(dim) 
        self.attn  = DynamicSparseAttention(dim, num_heads, dropout) 
        self.norm2  = nn.LayerNorm(dim) 
        self.mlp  = nn.Sequential( 
            nn.Linear(dim, int(dim * mlp_ratio)), 
            nn.GELU(), 
            nn.Dropout(dropout), 
            nn.Linear(int(dim * mlp_ratio), dim), 
            nn.Dropout(dropout) 
        ) 
 
    def forward(self, x): 
        x = x + self.attn(self.norm1(x))  
        x = x + self.mlp(self.norm2(x))  
        return x 
 
class DynamicSparseVisionTransformer(nn.Module): 
    def __init__(self, img_size=224, patch_size=16, num_classes=1000, dim=768, num_heads=8, depth=12, mlp_ratio=4., dropout=0.1): 
        super().__init__() 
        self.patch_embed  = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size) 
        self.pos_embed  = nn.Parameter(torch.zeros(1,  (img_size // patch_size) ** 2, dim)) 
        self.dropout  = nn.Dropout(dropout) 
 
        self.blocks  = nn.ModuleList([HierarchicalRoutingBlock(dim, num_heads, mlp_ratio, dropout) for _ in range(depth)]) 
        self.norm  = nn.LayerNorm(dim) 
 
        self.head  = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() 
 
    def forward(self, x): 
        x = self.patch_embed(x).flatten(2).transpose(1,  2) 
        x = x + self.pos_embed  
        x = self.dropout(x)  
 
        for blk in self.blocks:  
            x = blk(x) 
 
        x = self.norm(x)  
        x = x[:, 0] 
        x = self.head(x)  
        return x 
 
# 使用 
model = DynamicSparseVisionTransformer() 
x = torch.randn(1,  3, 224, 224) 
output = model(x) 
print(output.shape)  

代码解释

DynamicSparseAttention:实现动态稀疏注意力模块。

HierarchicalRoutingBlock:实现层次化路由块,包含注意力模块和多层感知机。

DynamicSparseVisionTransformer:实现完整的动态稀疏视觉Transformer模型,包括补丁嵌入、位置嵌入、层次化路由块和分类头。

相关推荐
2601_961963381 分钟前
从“电子化”到“自动化”:2026年智能合约与电子合同融合的技术逻辑与法律适配
网络·人工智能·区块链·智能合约·政务
米小虾12 分钟前
AI Skills 工程化:当每个开发者都有一支「AI 小队」,你该怎么管理?
人工智能
DisonTangor21 分钟前
谷歌开源首个扩散大语言模型——DiffusionGemma
人工智能·语言模型·自然语言处理·开源·aigc·transformer
冬奇Lab25 分钟前
每日一个开源项目(第129篇):OpenMed - 永不离开设备的医疗 NLP
人工智能·开源·资讯
冬奇Lab27 分钟前
Agent 系列(19):Harness 完整体系——8 层防护框架全景
人工智能·llm·agent
米小虾27 分钟前
Claude Fable 5 系统提示词被扒出来了:1586 行代码背后,藏着 AI 产品工程的终极哲学
人工智能·agent
云烟成雨TD29 分钟前
Spring AI Alibaba 1.x 系列【77】执行取消
java·人工智能·spring
Teacher.chenchong30 分钟前
AI-Agent2.0 科研全链路实战营:LLM+NotebookLM + 自动化编程 + 文献管理 + 论文写作,搭建本地科研智能体
人工智能·自动化
weberCd35 分钟前
ChatGPT 实用技巧总结(国内)
人工智能·chatgpt
我爱cope39 分钟前
【Agent智能体26 | 多智能体-多智能体工作流】
人工智能·设计模式·语言模型·职场和发展