MHD、MQA、GQA注意力机制详解

MHD、MQA、GQA注意力机制详解

注意力机制详解及代码

前言:

自回归解码器推理是 Transformer 模型的 一个严重瓶颈,因为在每个解码步骤中加 载解码器权重以及所有注意键和值会产生 内存带宽开销

下图为三种注意力机制的结构图和实验结果

MHA

多头注意力机制是Transformer模型中的核心组件。在其设计中,"多头"意味着该机制并不只计算一种注意力权重,而是并行计算多种权重,每种权重都从不同的"视角"捕获输入的不同信息。

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
python 复制代码
import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MutiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        ## 初始化Q、K、V投影矩阵
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        
        ## 输出线性层
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size()[0]
        
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)
        
        query = self.split_head(query)
        key = self.split_head(key)
        value = self.split_head(value)
        
        ## 计算注意力分数
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
        
        if attention_mask != None:
            attention_scores += attention_mask * -1e-9
        
        ## 对注意力分数进行归一化
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        output = torch.matmul(attention_probs, value)
        
        ## 对注意力输出进行拼接
        output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        
        output = self.o_linear(output)
        
        return output
 
        
    def split_head(self, x):
        batch_size = x.size()[0]
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)

MQA

多查询注意力(MQA)可能导致质量下降和训练不稳定,并且训练针对质量和推理优化的单独模型可能不可行。此外,虽然一些语言模型已经使用了多查询注意力,如PaLM但许多语言模型没有,包括公开可用的语言模型,如T5和LLaM.

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=1,v=1)。相当于多个query,即多查询。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
python 复制代码
## 多查询注意力
import torch
from torch import nn
class MutiQueryAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MutiQueryAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        ## 初始化Q、K、V投影矩阵
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.head_dim) ###
        self.v_linear = nn.Linear(hidden_size, self.head_dim) ###
        
        ## 输出线性层
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size()[0]
        
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)
        
        query = self.split_head(query)
        key = self.split_head(key, 1)
        value = self.split_head(value, 1)
        
        ## 计算注意力分数
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
        
        if attention_mask != None:
            attention_scores += attention_mask * -1e-9
        
        ## 对注意力分数进行归一化
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        output = torch.matmul(attention_probs, value)
        
        output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        
        output = self.o_linear(output)
        
        return output
        
        
        
        
    def split_head(self, x, head_num=None):
        
        batch_size = x.size()[0]
        
        if head_num == None:
            return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)
        else:
            return x.view(batch_size, -1, head_num, self.head_dim).transpose(1,2)
    
    

GQA

  • 使用 5% 的原始预训练 计算将现有的多头语言模型检查点训 练到具有 MQA 的模型中
  • 引入分组查询注意力 (GQA),这是多 头语言模型的泛化。查询注意力,它使用中间,多于一个,少于查询头数量的键值头。
  • 经过训练的GQA 实现了接近多头注意力 的质量,并且速度与 MQA 相当。
  • hidden_state经过线性层得到q、k、v

  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=group_num,v=group_num)。相当于把多头分组了,比如原先有10个头,那就是10个query,分成5组,每组2个query,1个value,1个key。

  • q、k计算注意力分数score

  • softmax对注意力分数进行归一化得到注意力权重attention_probs

  • 使用注意力权重和值计算输出:output

  • 对注意力输出进行拼接concat

    分组注意力查询

    import torch
    from torch import nn
    class MutiGroupAttention(torch.nn.Module):
    def init(self, hidden_size, num_heads, group_num):
    super(MutiGroupAttention, self).init()
    self.num_heads = num_heads
    self.head_dim = hidden_size // num_heads
    self.group_num = group_num

    复制代码
          ## 初始化Q、K、V投影矩阵
          self.q_linear = nn.Linear(hidden_size, hidden_size)
          self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)
          self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)
          
          ## 输出线性层
          self.o_linear = nn.Linear(hidden_size, hidden_size)
          
      def forward(self, hidden_state, attention_mask=None):
          batch_size = hidden_state.size()[0]
          
          query = self.q_linear(hidden_state)
          key = self.k_linear(hidden_state)
          value = self.v_linear(hidden_state)
          
          query = self.split_head(query)
          key = self.split_head(key, self.group_num)
          value = self.split_head(value, self.group_num)
          
          ## 计算注意力分数
          attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
          
          if attention_mask != None:
              attention_scores += attention_mask * -1e-9
          
          ## 对注意力分数进行归一化
          attention_probs = torch.softmax(attention_scores, dim=-1)
          
          output = torch.matmul(attention_probs, value)
          
          output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
          
          output = self.o_linear(output)
          
          return output
          
          
          
          
      def split_head(self, x, group_num=None):
          
          batch_size,seq_len = x.size()[:2]
          
          if group_num == None:
              return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)
          else:
              x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1,2)
              x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)
              return x
相关推荐
RUNTIME2 分钟前
大模型微调实操记录
llm
数据智能老司机7 分钟前
构建具备自主性的人工智能系统——探索协调者、工作者和委托者方法
深度学习·llm·aigc
数据智能老司机12 分钟前
构建具备自主性的人工智能系统——使代理能够使用工具和进行规划
深度学习·llm·aigc
量子位4 小时前
智能车速度刷新:仅 10 个月,首个纯端侧大模型上车量产!
人工智能·llm
技术你大飞哥6 小时前
【突破数据孤岛】MCP协议进化史:从 STDIO 到全双工流式 —— AI 应用开发效率提升 90% 的秘密武器
llm·ai编程·mcp
Goboy7 小时前
零基础搞定 Trae 智能体配置 + MySQL MCP 集成!手把手教学
llm·ai编程·trae
COOCC18 小时前
PyTorch 实战:Transformer 模型搭建全解析
人工智能·pytorch·python·深度学习·神经网络·目标检测·transformer
漫谈网络21 小时前
Ollama API 应用指南
ai·llm·aigc·api·ollama
数据智能老司机1 天前
使用 FastAPI 构建生成式 AI 服务——与生成模型的实时通信
llm·openai·fastapi
数据智能老司机1 天前
使用 FastAPI 构建生成式 AI 服务——AI集成与模型服务
llm·openai·fastapi