[深度学习]Vision Transformer

Pytorch实现Vision Transformer

python 复制代码
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        # 使用卷积层实现patch分割和嵌入
        self.proj = nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=embed_dim,
                        kernel_size=patch_size,
                        stride=patch_size
                    )

    def forward(self, x):
        # 输入x形状: [batch_size, in_channels, img_size, img_size]
        # 输出形状: [batch_size, n_patches, embed_dim]
        x = self.proj(x)  # [batch_size, embed_dim, n_patches^0.5, n_patches^0.5]
        x = x.flatten(2)  # [batch_size, embed_dim, n_patches]
        x = x.transpose(1, 2)  # [batch_size, n_patches, embed_dim]
        return x

class PositionEmbedding(nn.Module):
    def __init__(self, n_patches, embed_dim, dropout=0.1):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))  # +1 for class token
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x形状: [batch_size, n_patches+1, embed_dim]
        x = x + self.pos_embed # 添加位置编码
        x = self.dropout(x)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "Embedding dimension must be divisible by number of heads"
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)  # 同时计算Q,K,V
        self.attn_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        batch_size, n_patches, embed_dim = x.shape
        # 计算Q,K,V [batch_size, n_patches, num_heads, head_dim]
        qkv = self.qkv(x).reshape(batch_size, n_patches, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # 计算注意力分数 [batch_size, num_heads, n_patches, n_patches]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        # 应用注意力权重到V上 [batch_size, num_heads, n_patches, head_dim]
        out = attn @ v
        out = out.transpose(1, 2).reshape(batch_size, n_patches, embed_dim)
        # 线性投影和dropout
        out = self.proj(out)
        out = self.proj_dropout(out)
        return out

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(
        in_features=embed_dim,
        hidden_features=embed_dim * mlp_ratio,
        out_features=embed_dim,
        dropout=dropout
        )

    def forward(self, x):
        # 残差连接和层归一化
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        n_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        dropout=0.1
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches
        # 分类token和位置编码
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = PositionEmbedding(n_patches, embed_dim, dropout)
        # Transformer编码器
        self.blocks = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth)])
        # 分类头
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)
        # 初始化权重
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        batch_size = x.shape[0]
        # 生成patch嵌入
        x = self.patch_embed(x)  # [batch_size, n_patches, embed_dim]
        # 添加class token
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)  # [batch_size, n_patches+1, embed_dim]
        # 添加位置编码
        x = self.pos_embed(x)
        # 通过Transformer编码器
        x = self.blocks(x)
        # 分类
        x = self.norm(x)
        cls_token_final = x[:, 0]  # 只取class token对应的输出
        x = self.head(cls_token_final)
        return x

if __name__ == '__main__':
    x = torch.rand(1, 3, 224, 224)
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
    )
    y = model(x)
    print('y.shape = ', y.shape)
    print(y)

参考资料

  1. 博客
  2. B站教程
相关推荐
波动几何19 小时前
通用自然语言任务执行器:设计理念与实现思路
人工智能
mit6.82419 小时前
trubble shotting
人工智能
向量引擎19 小时前
AI Agent 安全元年:OpenClaw 投毒事件如何改变整个生态安全标准,
运维·人工智能·安全·自动化·aigc·api调用
Kel19 小时前
从Prompt到Response:大模型推理端到端核心链路深度拆解
人工智能·算法·架构
亦暖筑序19 小时前
Message 四分天下:Spring AI 如何统一消息格式
java·人工智能
tinygone19 小时前
OpenClaw通过ACPX调用Claude Code实现飞书操作CC
人工智能·飞书·ai编程
2501_9333295519 小时前
AI驱动媒介宣发:Infoseek舆情系统的技术架构与公关实战
数据仓库·人工智能·重构·数据库开发
ZC跨境爬虫20 小时前
极验滑动验证码自动化实战(ddddocr免费方案):本地缺口识别与Playwright滑动模拟
前端·爬虫·python·自动化
云栖梦泽20 小时前
【AI】AI安全工具:常用AI安全检测工具的使用教程
大数据·人工智能·安全
海兰20 小时前
【AI网关】阿里开源的Higress(OpenAPI-to-MCP工具)
人工智能·架构·开源·银行系统