[深度学习]Vision Transformer

Pytorch实现Vision Transformer

python 复制代码
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        # 使用卷积层实现patch分割和嵌入
        self.proj = nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=embed_dim,
                        kernel_size=patch_size,
                        stride=patch_size
                    )

    def forward(self, x):
        # 输入x形状: [batch_size, in_channels, img_size, img_size]
        # 输出形状: [batch_size, n_patches, embed_dim]
        x = self.proj(x)  # [batch_size, embed_dim, n_patches^0.5, n_patches^0.5]
        x = x.flatten(2)  # [batch_size, embed_dim, n_patches]
        x = x.transpose(1, 2)  # [batch_size, n_patches, embed_dim]
        return x

class PositionEmbedding(nn.Module):
    def __init__(self, n_patches, embed_dim, dropout=0.1):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))  # +1 for class token
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x形状: [batch_size, n_patches+1, embed_dim]
        x = x + self.pos_embed # 添加位置编码
        x = self.dropout(x)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "Embedding dimension must be divisible by number of heads"
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)  # 同时计算Q,K,V
        self.attn_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        batch_size, n_patches, embed_dim = x.shape
        # 计算Q,K,V [batch_size, n_patches, num_heads, head_dim]
        qkv = self.qkv(x).reshape(batch_size, n_patches, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # 计算注意力分数 [batch_size, num_heads, n_patches, n_patches]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        # 应用注意力权重到V上 [batch_size, num_heads, n_patches, head_dim]
        out = attn @ v
        out = out.transpose(1, 2).reshape(batch_size, n_patches, embed_dim)
        # 线性投影和dropout
        out = self.proj(out)
        out = self.proj_dropout(out)
        return out

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(
        in_features=embed_dim,
        hidden_features=embed_dim * mlp_ratio,
        out_features=embed_dim,
        dropout=dropout
        )

    def forward(self, x):
        # 残差连接和层归一化
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        n_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        dropout=0.1
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches
        # 分类token和位置编码
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = PositionEmbedding(n_patches, embed_dim, dropout)
        # Transformer编码器
        self.blocks = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth)])
        # 分类头
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)
        # 初始化权重
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        batch_size = x.shape[0]
        # 生成patch嵌入
        x = self.patch_embed(x)  # [batch_size, n_patches, embed_dim]
        # 添加class token
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)  # [batch_size, n_patches+1, embed_dim]
        # 添加位置编码
        x = self.pos_embed(x)
        # 通过Transformer编码器
        x = self.blocks(x)
        # 分类
        x = self.norm(x)
        cls_token_final = x[:, 0]  # 只取class token对应的输出
        x = self.head(cls_token_final)
        return x

if __name__ == '__main__':
    x = torch.rand(1, 3, 224, 224)
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
    )
    y = model(x)
    print('y.shape = ', y.shape)
    print(y)

参考资料

  1. 博客
  2. B站教程
相关推荐
Web3VentureView2 小时前
目标:覆盖全网主流公链,SYNBO 正式开启公链生态媒体合作矩阵计划
大数据·网络·人工智能·区块链·媒体·加密货币
weixin_395448912 小时前
average_weights.py
pytorch·python·深度学习
香芋Yu2 小时前
【深度学习教程——02_优化与正则(Optimization)】09_为什么Dropout能防止过拟合?正则化的本质
人工智能·深度学习
易营宝2 小时前
Yandex广告投放效果怎么样?B2B外贸品牌实测报告
人工智能·seo
会飞的老朱2 小时前
专精特新科技企业,如何用数智化打通管理全链路?
人工智能·科技·oa协同办公
AI_56782 小时前
Git冲突治理白皮书:智能标记与可视化协同的下一代解决方案
大数据·人工智能·git·机器学习
蒜香拿铁2 小时前
【第一章】爬虫概述
爬虫·python
ID_180079054732 小时前
Python调用淘宝评论API:从入门到首次采集全流程
服务器·数据库·python
小猪咪piggy2 小时前
【Python】(2) 执行顺序控制语句
开发语言·python