手撕MHA、MLA、MQA、GQA

多头自注意力机制(Multi-Head Attention, MHA)

通过并行计算多个注意力头,使模型能够同时关注输入序列中不同位置的特征。其核心思想是将输入映射到多个子空间,分别计算注意力权重并聚合结果,从而增强模型对复杂模式的捕捉能力。

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size  # 嵌入维度大小
        self.heads = heads              # 头的数量
        self.head_dim = embed_size // heads  # 每个头的维度

        # 确保嵌入维度可以被头的数量整除
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        # 定义线性变换层,用于将输入转换为查询、键和值
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # 最终输出的全连接层
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # 批量大小
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # 将嵌入向量分割成多个头
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # 对每个头进行线性变换
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # 计算点积注意力分数
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)

        # 应用掩码(如果存在)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # 对注意力分数应用softmax函数
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # 计算加权求和后的输出
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # out after matrix multiply: (N, query_len, heads, head_dim), then flatten last two dimensions

        # 通过全连接层得到最终输出
        out = self.fc_out(out)
        return out

# 示例用法:
if __name__ == "__main__":
    embed_size = 512  # 嵌入维度大小
    heads = 8         # 头的数量
    batch_size = 64   # 批量大小
    seq_length = 10   # 序列长度

    mha = MultiHeadAttention(embed_size, heads)
    values = keys = query = torch.randn(batch_size, seq_length, embed_size)

    print(mha.forward(values, keys, query, mask=None).shape)

多查询注意力机制(Multi-Query Attention,MQA)

Multi-Query Attention (MQA) 是对多头注意力(MHA)的高效改进版本,其核心思想是共享键(Key)和值(Value)的投影参数,仅对查询(Query)使用独立的头参数。这种方法显著减少了模型参数量和计算复杂度,同时保留了多头注意力的部分并行性优势。

python 复制代码
import torch
import torch.nn as nn
from thop import profile

class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        """
        Multi-Query Attention 的实现。

        Args:
            hidden_size (int): 输入特征的维度,也即 hidden_state 的最后一维。
            num_heads (int): 注意力头的数量。
            dropout (float): dropout 的概率,默认为 0.0。
        """
        super(MultiQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size 必须能被 num_heads 整除"
        
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads  # 每个头的维度
        
        # 定义线性变换层,用于生成 Q, K, V
        self.query = nn.Linear(hidden_size, hidden_size)  # 每个头独立的 Query
        self.key = nn.Linear(hidden_size, self.head_dim)   # 所有头共享的 Key
        self.value = nn.Linear(hidden_size, self.head_dim) # 所有头共享的 Value
        
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, query, key, value, mask=None):
        N = query.shape[0]  # 批量大小
        query_len = query.shape[1]
        key_len = key.shape[1]
        value_len = value.shape[1]

        # 将嵌入向量分割成多个头
        query = self.query(query).view(N, query_len, self.num_heads, self.head_dim)
        key = self.key(key).view(N, key_len, self.num_heads, self.head_dim)
        value = self.value(value).view(N, value_len, self.num_heads, self.head_dim)

        # 调整形状以便进行点积操作
        query = query.transpose(1, 2)  # shape: (N, heads, query_len, head_dim)
        key = key.transpose(1, 2)      # shape: (N, heads, key_len, head_dim)
        value = value.transpose(1, 2)    # shape: (N, heads, value_len, head_dim)

        # 计算点积注意力分数
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim))
        # scores shape: (N, heads, query_len, key_len)

        # 应用掩码(如果存在)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # 对注意力分数应用softmax函数
        attention_weights = torch.softmax(scores, dim=-1)
        # attention_weights shape: (N, heads, query_len, key_len)

        # 计算加权求和后的输出
        context = torch.matmul(attention_weights, value)
        # context shape: (N, heads, query_len, head_dim)

        # 调整形状以便通过全连接层
        context = context.transpose(1, 2).contiguous().view(N, query_len, self.hidden_size)
        # context shape: (N, query_len, hidden_size)

        # 通过全连接层得到最终输出
        output = self.out_projection(context)
        # output shape: (N, query_len, hidden_size)

        return output

# 示例用法:
if __name__ == "__main__":
    hidden_size = 512  # 嵌入维度大小
    num_heads = 8      # 头的数量
    batch_size = 64    # 批量大小
    seq_length = 10    # 序列长度

    mqa = MultiQueryAttention(hidden_size, num_heads)
    query = key = value = torch.randn(batch_size, seq_length, hidden_size)

    print(mqa.forward(query, key, value, mask=None).shape)

分组查询注意力机制(Grouped Query Attention,GQA)

Grouped Query Attention (GQA) 是对多头注意力(MHA)和多查询注意力(MQA)的折中优化方案。其核心思想是将查询头(Query Heads)划分为多个组(Group),每组内的查询头共享一组键(Key)和值(Value),从而在保留多头并行性的同时减少参数量和计算复杂度。GQA 在参数效率与模型性能之间取得了平衡,适用于大规模模型的高效部署。

python 复制代码
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, group_size=2, dropout=0.0):
        """
        Grouped Query Attention 实现。

        Args:
            hidden_size (int): 输入特征的维度。
            num_heads (int): 查询头的数量。
            group_size (int): 每个组中包含的查询头数量。
            dropout (float): dropout 的概率。
        """
        super(GroupedQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size 必须能被 num_heads 整除"
        assert num_heads % group_size == 0, "num_heads 必须能被 group_size 整除"
        
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.group_size = group_size
        self.group_num = num_heads // group_size
        self.head_dim = hidden_size // num_heads  # 每个头的维度
        
        # 定义线性变换层,用于生成 Q, K, V
        self.query = nn.Linear(hidden_size, hidden_size)  # 每个头独立的 Query
        self.key = nn.Linear(hidden_size, self.group_num * self.head_dim)  # 分组共享的 Key
        self.value = nn.Linear(hidden_size, self.group_num * self.head_dim)  # 分组共享的 Value
        
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, query, key, value, mask=None):
        N = query.shape[0]  # 批量大小
        query_len = query.shape[1]
        key_len = key.shape[1]
        value_len = value.shape[1]

        # 将嵌入向量分割成多个头
        query = self.query(query).view(N, query_len, self.num_heads, self.head_dim)
        key = self.key(key).view(N, key_len, self.group_num, self.head_dim)
        value = self.value(value).view(N, value_len, self.group_num, self.head_dim)

        # 调整形状以便进行点积操作
        query = query.transpose(1, 2)  # shape: (N, heads, query_len, head_dim)
        key = key.transpose(1, 2)      # shape: (N, group_num, key_len, head_dim)
        value = value.transpose(1, 2)    # shape: (N, group_num, value_len, head_dim)

        # 计算点积注意力分数
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim))
        # scores shape: (N, heads, query_len, key_len)

        # 应用掩码(如果存在)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # 对注意力分数应用softmax函数
        attention_weights = torch.softmax(scores, dim=-1)
        # attention_weights shape: (N, heads, query_len, key_len)

        # 计算加权求和后的输出
        context = torch.matmul(attention_weights, value)
        # context shape: (N, heads, query_len, head_dim)

        # 合并组内的头
        context = context.view(N, query_len, self.num_heads, self.head_dim)
        context = context.permute(0, 2, 1, 3).contiguous().view(N, query_len, self.hidden_size)
        # context shape: (N, query_len, hidden_size)

        # 通过全连接层得到最终输出
        output = self.out_projection(context)
        # output shape: (N, query_len, hidden_size)

        return output

# 示例用法:
if __name__ == "__main__":
    hidden_size = 512  # 嵌入维度大小
    num_heads = 8      # 头的数量
    group_size = 2     # 每个组中包含的查询头数量
    batch_size = 64    # 批量大小
    seq_length = 10    # 序列长度

    gqa = GroupedQueryAttention(hidden_size, num_heads, group_size)
    query = key = value = torch.randn(batch_size, seq_length, hidden_size)

    print(gqa.forward(query, key, value, mask=None).shape)

多头潜在注意力(Multi-Head Latent Attention, MLA)

Multi-Head Latent Attention (MLA) 是一种结合低秩参数化与旋转位置编码(RoPE)的高效注意力机制。其核心思想是通过低秩投影压缩查询(Q)、键(K)、值(V)的维度,并在注意力计算中解耦内容与位置信息,从而减少计算复杂度,同时保留长距离依赖建模能力。MLA 特别适用于大规模模型的部署,平衡了效率与性能。

python 复制代码
import torch
import torch.nn as nn
import math

class RotaryEmbedding(nn.Module):
    def __init__(self, hidden_size, num_heads, base=10000, max_len=512):
        """
        RoPE位置编码模块

        Args:
            hidden_size (int): 模型维度
            num_heads (int): 注意力头数量
            base (int): 频率基值
            max_len (int): 最大序列长度
        """
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len
        self.cos_pos_cache, self.sin_pos_cache = self._compute_pos_emb()

    def _compute_pos_emb(self):
        # 计算频率因子
        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        # 创建位置索引
        positions = torch.arange(self.max_len)
        # 计算位置编码
        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)
        # 计算cos和sin的位置编码
        cos_pos = pos_emb.sin().repeat_interleave(2, dim=-1)
        sin_pos = pos_emb.cos().repeat_interleave(2, dim=-1)
        return cos_pos, sin_pos

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.size(1)
        # 获取缓存中的位置编码
        cos_pos = self.cos_pos_cache[:seq_len].unsqueeze(1)
        sin_pos = self.sin_pos_cache[:seq_len].unsqueeze(1)

        # 分割x为奇数和偶数部分
        x1, x2 = x.chunk(2, dim=-1)
        # 应用旋转位置编码
        rotated_x1 = x1 * cos_pos - x2 * sin_pos
        rotated_x2 = x1 * sin_pos + x2 * cos_pos
        # 合并结果
        rotated_x = torch.cat([rotated_x1, rotated_x2], dim=-1)
        return rotated_x

# 示例用法:
if __name__ == "__main__":
    hidden_size = 512  # 模型维度
    num_heads = 8      # 注意力头数量
    batch_size = 64    # 批量大小
    seq_length = 10    # 序列长度

    rope = RotaryEmbedding(hidden_size, num_heads)
    x = torch.randn(batch_size, seq_length, hidden_size)

    print(rope(x).shape)
相关推荐
麦当_10 小时前
基于 Shadcn 的可配置表单解决方案
前端·javascript·面试
天天摸鱼的java工程师12 小时前
商品详情页 QPS 达 10 万,如何设计缓存架构降低数据库压力?
java·后端·面试
天天摸鱼的java工程师12 小时前
设计一个分布式 ID 生成器,要求全局唯一、趋势递增、支持每秒 10 万次生成,如何实现?
java·后端·面试
掘金安东尼15 小时前
9 个【宝藏工具】精选,大幅提升效率与灵感!
前端·面试·github
_一条咸鱼_16 小时前
Android Runtime二进制镜像(ART Image)生成原理(44)
android·面试·android jetpack
顾林海16 小时前
Android线程栈优化全解析:从创建流程到内存管控的深度实践
android·面试·性能优化
有仙则茗16 小时前
JS 迭代器是什么东西
前端·javascript·面试
_一条咸鱼_17 小时前
Android Runtime全局优化与跨函数分析原理(43)
android·面试·android jetpack
Spirited_Away17 小时前
什么你还不会用navigation来管理导航?
前端·javascript·面试