和多头注意力机制的唯一区别:K、V在不同的head之间实现了复用,而对于不同的头,Q依然不同。
因此这里的代码和标准多头注意力的实现也是几乎完全一样:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiQueryAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5
# 查询、键、值投影
self.q_proj = nn.Linear(embed_dim, embed_dim) # 多头查询
self.k_proj = nn.Linear(embed_dim, self.head_dim) # 单头键
self.v_proj = nn.Linear(embed_dim, self.head_dim) # 单头值
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
batch_size, seq_len, embed_dim = x.shape
# 投影
q = self.q_proj(x) # (batch, seq_len, embed_dim)
k = self.k_proj(x) # (batch, seq_len, head_dim)
v = self.v_proj(x) # (batch, seq_len, head_dim)
# 重塑查询为多头
q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# (batch, num_heads, seq_len, head_dim)
# 键和值保持单头,扩展到多头维度
k = k.unsqueeze(1) # (batch, 1, seq_len, head_dim)
v = v.unsqueeze(1) # (batch, 1, seq_len, head_dim)
# 注意力计算
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
# (batch, num_heads, seq_len, seq_len)
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v) # (batch, num_heads, seq_len, head_dim)
# 合并多头
out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
out = self.out_proj(out) # (batch, seq_len, embed_dim)
return out
# 示例用法
embed_dim = 64
num_heads = 8
model = MultiQueryAttention(embed_dim, num_heads)
x = torch.randn(2, 10, embed_dim) # (batch, seq_len, embed_dim)
output = model(x)
print(output.shape) # torch.Size([2, 10, 64])