import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        # A single conv layer performs both patch splitting and embedding
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x):
        # Input shape:  [batch_size, in_channels, img_size, img_size]
        # Output shape: [batch_size, n_patches, embed_dim]
        x = self.proj(x)       # [batch_size, embed_dim, img_size // patch_size, img_size // patch_size]
        x = x.flatten(2)       # [batch_size, embed_dim, n_patches]
        x = x.transpose(1, 2)  # [batch_size, n_patches, embed_dim]
        return x
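
# A Conv2d whose kernel_size equals its stride is equivalent to slicing the image
# into non-overlapping patches and applying one shared linear projection to each
# flattened patch. A minimal shape sanity check (an illustrative addition, not part
# of the original listing):
#
#   pe = PatchEmbedding()
#   out = pe(torch.rand(2, 3, 224, 224))
#   assert out.shape == (2, 196, 768)  # 196 = (224 // 16) ** 2
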
class PositionEmbedding(nn.Module):
    def __init__(self, n_patches, embed_dim, dropout=0.1):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))  # +1 for the class token
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x shape: [batch_size, n_patches + 1, embed_dim]
        x = x + self.pos_embed  # add the positional encoding
        x = self.dropout(x)
        return x
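
# ViT uses learned position embeddings rather than the sinusoidal encodings of the
# original Transformer: one trainable row per token, with the extra row reserved for
# the class token. They start at zero here; many implementations instead initialize
# them with trunc_normal_(std=0.02), which would also be reasonable.
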
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "Embedding dimension must be divisible by number of heads"
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)  # compute Q, K, V with one projection
        self.attn_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        # n_tokens = n_patches + 1, since the sequence includes the class token
        batch_size, n_tokens, embed_dim = x.shape
        # Compute Q, K, V: qkv has shape [3, batch_size, num_heads, n_tokens, head_dim]
        qkv = self.qkv(x).reshape(batch_size, n_tokens, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Attention scores: [batch_size, num_heads, n_tokens, n_tokens]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        # Apply the attention weights to V: [batch_size, num_heads, n_tokens, head_dim]
        out = attn @ v
        out = out.transpose(1, 2).reshape(batch_size, n_tokens, embed_dim)
        # Output projection and dropout
        out = self.proj(out)
        out = self.proj_dropout(out)
        return out
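
# The matmul/softmax/matmul sequence above computes softmax(Q K^T / sqrt(head_dim)) @ V.
# On PyTorch >= 2.0 (an assumption about your environment), the same computation is
# available as a fused kernel:
#
#   p = self.attn_dropout.p if self.training else 0.0
#   out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=p)
#
# Shown as an optional substitution, not part of the original listing.
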
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
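
# As in the ViT paper, the feed-forward hidden width is mlp_ratio times the embedding
# dimension (4 * 768 = 3072 for ViT-Base) and the activation is GELU.
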
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(
            in_features=embed_dim,
            hidden_features=embed_dim * mlp_ratio,
            out_features=embed_dim,
            dropout=dropout
        )

    def forward(self, x):
        # Residual connections around layer-normalized sub-layers
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
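
# This is the pre-norm arrangement: LayerNorm is applied before each sub-layer and
# the residual path stays unnormalized, matching the ViT paper and generally training
# more stably than the post-norm layout of the original Transformer.
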
class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        n_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        dropout=0.1
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches
        # Class token and positional encoding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = PositionEmbedding(n_patches, embed_dim, dropout)
        # Transformer encoder
        self.blocks = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth)])
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)
        # Weight initialization
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        batch_size = x.shape[0]
        # Compute patch embeddings
        x = self.patch_embed(x)  # [batch_size, n_patches, embed_dim]
        # Prepend the class token
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)  # [batch_size, n_patches + 1, embed_dim]
        # Add the positional encoding
        x = self.pos_embed(x)
        # Run the Transformer encoder
        x = self.blocks(x)
        # Classify
        x = self.norm(x)
        cls_token_final = x[:, 0]  # take only the output at the class-token position
        x = self.head(cls_token_final)
        return x
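
# The defaults (patch_size=16, embed_dim=768, depth=12, num_heads=12) correspond to
# the ViT-Base/16 configuration from Dosovitskiy et al., "An Image is Worth 16x16
# Words: Transformers for Image Recognition at Scale" (2020).
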
if __name__ == '__main__':
    x = torch.rand(1, 3, 224, 224)
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
    )
    y = model(x)
    print('y.shape = ', y.shape)
    print(y)
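    # Optional sanity checks (illustrative additions, not part of the original listing):
    # one logit vector of n_classes per image, and roughly 86M parameters for ViT-Base/16.
    assert y.shape == (1, 1000)
    n_params = sum(p.numel() for p in model.parameters())
    print(f'parameters: {n_params / 1e6:.1f}M')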