Multi-Query Attention (MQA) PyTorch 实现

和多头注意力机制的唯一区别:K、V在不同的head之间实现了复用,而对于不同的头,Q依然不同。

因此这里的代码和标准多头注意力的实现也是几乎完全一样:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # 查询、键、值投影
        self.q_proj = nn.Linear(embed_dim, embed_dim)  # 多头查询
        self.k_proj = nn.Linear(embed_dim, self.head_dim)  # 单头键
        self.v_proj = nn.Linear(embed_dim, self.head_dim)  # 单头值
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # 投影
        q = self.q_proj(x)  # (batch, seq_len, embed_dim)
        k = self.k_proj(x)  # (batch, seq_len, head_dim)
        v = self.v_proj(x)  # (batch, seq_len, head_dim)

        # 重塑查询为多头
        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # (batch, num_heads, seq_len, head_dim)
        
        # 键和值保持单头,扩展到多头维度
        k = k.unsqueeze(1)  # (batch, 1, seq_len, head_dim)
        v = v.unsqueeze(1)  # (batch, 1, seq_len, head_dim)

        # 注意力计算
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # (batch, num_heads, seq_len, seq_len)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # (batch, num_heads, seq_len, head_dim)

        # 合并多头
        out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        out = self.out_proj(out)  # (batch, seq_len, embed_dim)

        return out

# 示例用法
embed_dim = 64
num_heads = 8
model = MultiQueryAttention(embed_dim, num_heads)
x = torch.randn(2, 10, embed_dim)  # (batch, seq_len, embed_dim)
output = model(x)
print(output.shape)  # torch.Size([2, 10, 64])
相关推荐
轻竹办公PPT1 分钟前
AI一键生成年终总结PPT
人工智能·python·powerpoint
是Dream呀1 分钟前
昇腾平台 PyTorch 迁移实操:从环境搭建到精度达标的完整步骤
人工智能·pytorch·python·昇腾
lxmyzzs6 分钟前
【图像算法 - 36】医疗应用:基于 YOLOv12 与 OpenCV 的高精度脑肿瘤检测系统实现
python·深度学习·opencv·yolo·计算机视觉·脑肿瘤检测
工藤学编程7 分钟前
零基础学AI大模型之Milvus实战:Attu可视化安装+Python整合全案例
人工智能·python·milvus
学不完了是吧12 分钟前
“小白专属”python字符串处理文档
开发语言·python
☆光之梦☆14 分钟前
《openGauss全密态与防篡改账本数据库:云上数据安全与可信的新范式》
数据库·python
Maya动画技术14 分钟前
python的py转pyd方法(cython)
开发语言·python·spring
找了一圈尾巴15 分钟前
Python 学习-深入理解 Python 进程、线程与协程(上)
python·学习·并发
ZAz_15 分钟前
DAY 26 机器学习流水线
python
方知我16 分钟前
【GoogLeNet】基本原理
人工智能·pytorch·深度学习·神经网络·cnn