Multi-Query Attention (MQA) PyTorch 实现

和多头注意力机制的唯一区别:K、V在不同的head之间实现了复用,而对于不同的头,Q依然不同。

因此这里的代码和标准多头注意力的实现也是几乎完全一样:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # 查询、键、值投影
        self.q_proj = nn.Linear(embed_dim, embed_dim)  # 多头查询
        self.k_proj = nn.Linear(embed_dim, self.head_dim)  # 单头键
        self.v_proj = nn.Linear(embed_dim, self.head_dim)  # 单头值
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # 投影
        q = self.q_proj(x)  # (batch, seq_len, embed_dim)
        k = self.k_proj(x)  # (batch, seq_len, head_dim)
        v = self.v_proj(x)  # (batch, seq_len, head_dim)

        # 重塑查询为多头
        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # (batch, num_heads, seq_len, head_dim)
        
        # 键和值保持单头,扩展到多头维度
        k = k.unsqueeze(1)  # (batch, 1, seq_len, head_dim)
        v = v.unsqueeze(1)  # (batch, 1, seq_len, head_dim)

        # 注意力计算
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # (batch, num_heads, seq_len, seq_len)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # (batch, num_heads, seq_len, head_dim)

        # 合并多头
        out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        out = self.out_proj(out)  # (batch, seq_len, embed_dim)

        return out

# 示例用法
embed_dim = 64
num_heads = 8
model = MultiQueryAttention(embed_dim, num_heads)
x = torch.randn(2, 10, embed_dim)  # (batch, seq_len, embed_dim)
output = model(x)
print(output.shape)  # torch.Size([2, 10, 64])
相关推荐
Python图像识别6 小时前
75_基于深度学习的咖啡叶片病害检测系统(yolo11、yolov8、yolov5+UI界面+Python项目源码+模型+标注好的数据集)
python·深度学习·yolo
闲人编程6 小时前
Python游戏开发入门:Pygame实战
开发语言·python·游戏·pygame·毕设·codecapsule
雍凉明月夜7 小时前
人工智能学习中深度学习之python基础之 类
python·学习
Geo_V7 小时前
OpenAI 大模型 API 使用示例
python·chatgpt·openai·大模型应用·llm 开发
Hello_WOAIAI7 小时前
2.4 python装饰器在 Web 框架和测试中的实战应用
开发语言·前端·python
百锦再7 小时前
第1章 Rust语言概述
java·开发语言·人工智能·python·rust·go·1024程序员节
tokepson8 小时前
chatgpt-to-md优化并重新复习
python·ai·技术·pypi·记录
Victory_orsh8 小时前
“自然搞懂”深度学习(基于Pytorch架构)——010203
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
java1234_小锋8 小时前
PyTorch2 Python深度学习 - 模型保存与加载
开发语言·python·深度学习·pytorch2
Python图像识别8 小时前
74_基于深度学习的垃圾桶垃圾溢出检测系统(yolo11、yolov8、yolov5+UI界面+Python项目源码+模型+标注好的数据集)
python·深度学习·yolo