MHD、MQA、GQA注意力机制详解

MHD、MQA、GQA注意力机制详解

注意力机制详解及代码

前言:

自回归解码器推理是 Transformer 模型的 一个严重瓶颈,因为在每个解码步骤中加 载解码器权重以及所有注意键和值会产生 内存带宽开销

下图为三种注意力机制的结构图和实验结果

MHA

多头注意力机制是Transformer模型中的核心组件。在其设计中,"多头"意味着该机制并不只计算一种注意力权重,而是并行计算多种权重,每种权重都从不同的"视角"捕获输入的不同信息。

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
python 复制代码
import torch
from torch import nn
class MutiHeadAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MutiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        ## 初始化Q、K、V投影矩阵
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        
        ## 输出线性层
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size()[0]
        
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)
        
        query = self.split_head(query)
        key = self.split_head(key)
        value = self.split_head(value)
        
        ## 计算注意力分数
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
        
        if attention_mask != None:
            attention_scores += attention_mask * -1e-9
        
        ## 对注意力分数进行归一化
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        output = torch.matmul(attention_probs, value)
        
        ## 对注意力输出进行拼接
        output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        
        output = self.o_linear(output)
        
        return output
 
        
    def split_head(self, x):
        batch_size = x.size()[0]
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)

MQA

多查询注意力(MQA)可能导致质量下降和训练不稳定,并且训练针对质量和推理优化的单独模型可能不可行。此外,虽然一些语言模型已经使用了多查询注意力,如PaLM但许多语言模型没有,包括公开可用的语言模型,如T5和LLaM.

  • hidden_state经过线性层得到q、k、v
  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=1,v=1)。相当于多个query,即多查询。
  • q、k计算注意力分数score
  • softmax对注意力分数进行归一化得到注意力权重attention_probs
  • 使用注意力权重和值计算输出:output
  • 对注意力输出进行拼接concat
python 复制代码
## 多查询注意力
import torch
from torch import nn
class MutiQueryAttention(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(MutiQueryAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        ## 初始化Q、K、V投影矩阵
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.head_dim) ###
        self.v_linear = nn.Linear(hidden_size, self.head_dim) ###
        
        ## 输出线性层
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, hidden_state, attention_mask=None):
        batch_size = hidden_state.size()[0]
        
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)
        
        query = self.split_head(query)
        key = self.split_head(key, 1)
        value = self.split_head(value, 1)
        
        ## 计算注意力分数
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
        
        if attention_mask != None:
            attention_scores += attention_mask * -1e-9
        
        ## 对注意力分数进行归一化
        attention_probs = torch.softmax(attention_scores, dim=-1)
        
        output = torch.matmul(attention_probs, value)
        
        output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
        
        output = self.o_linear(output)
        
        return output
        
        
        
        
    def split_head(self, x, head_num=None):
        
        batch_size = x.size()[0]
        
        if head_num == None:
            return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)
        else:
            return x.view(batch_size, -1, head_num, self.head_dim).transpose(1,2)
    
    

GQA

  • 使用 5% 的原始预训练 计算将现有的多头语言模型检查点训 练到具有 MQA 的模型中
  • 引入分组查询注意力 (GQA),这是多 头语言模型的泛化。查询注意力,它使用中间,多于一个,少于查询头数量的键值头。
  • 经过训练的GQA 实现了接近多头注意力 的质量,并且速度与 MQA 相当。
  • hidden_state经过线性层得到q、k、v

  • q、k、v经过split后增加一个维度:num_heads(q = num_heads,k=group_num,v=group_num)。相当于把多头分组了,比如原先有10个头,那就是10个query,分成5组,每组2个query,1个value,1个key。

  • q、k计算注意力分数score

  • softmax对注意力分数进行归一化得到注意力权重attention_probs

  • 使用注意力权重和值计算输出:output

  • 对注意力输出进行拼接concat

    分组注意力查询

    import torch
    from torch import nn
    class MutiGroupAttention(torch.nn.Module):
    def init(self, hidden_size, num_heads, group_num):
    super(MutiGroupAttention, self).init()
    self.num_heads = num_heads
    self.head_dim = hidden_size // num_heads
    self.group_num = group_num

          ## 初始化Q、K、V投影矩阵
          self.q_linear = nn.Linear(hidden_size, hidden_size)
          self.k_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)
          self.v_linear = nn.Linear(hidden_size, self.group_num * self.head_dim)
          
          ## 输出线性层
          self.o_linear = nn.Linear(hidden_size, hidden_size)
          
      def forward(self, hidden_state, attention_mask=None):
          batch_size = hidden_state.size()[0]
          
          query = self.q_linear(hidden_state)
          key = self.k_linear(hidden_state)
          value = self.v_linear(hidden_state)
          
          query = self.split_head(query)
          key = self.split_head(key, self.group_num)
          value = self.split_head(value, self.group_num)
          
          ## 计算注意力分数
          attention_scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.head_dim))
          
          if attention_mask != None:
              attention_scores += attention_mask * -1e-9
          
          ## 对注意力分数进行归一化
          attention_probs = torch.softmax(attention_scores, dim=-1)
          
          output = torch.matmul(attention_probs, value)
          
          output = output.transpose(-1, -2).contiguous().view(batch_size, -1, self.head_dim * self.num_heads)
          
          output = self.o_linear(output)
          
          return output
          
          
          
          
      def split_head(self, x, group_num=None):
          
          batch_size,seq_len = x.size()[:2]
          
          if group_num == None:
              return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1,2)
          else:
              x = x.view(batch_size, -1, group_num, self.head_dim).transpose(1,2)
              x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads // group_num * group_num, seq_len, self.head_dim)
              return x
    
相关推荐
Guofu_Liao3 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
gz7seven5 小时前
BLIP-2模型的详解与思考
大模型·llm·多模态·blip·多模态大模型·blip-2·q-former
不爱说话郭德纲10 小时前
探索LLM前沿,共话科技未来
人工智能·算法·llm
AI_小站12 小时前
RAG 示例:使用 langchain、Redis、llama.cpp 构建一个 kubernetes 知识库问答
人工智能·程序人生·langchain·kubernetes·llama·知识库·rag
Guofu_Liao12 小时前
Llama模型文件介绍
人工智能·llama
Donvink16 小时前
多模态大语言模型——《动手学大模型》实践教程第六章
人工智能·深度学习·语言模型·自然语言处理·llama
我爱学Python!17 小时前
解决复杂查询难题:如何通过 Self-querying Prompting 提高 RAG 系统效率?
人工智能·程序人生·自然语言处理·大模型·llm·大语言模型·rag
rommel rain19 小时前
SpecInfer论文阅读
人工智能·语言模型·transformer
Donvink19 小时前
大模型安全和越狱攻击——《动手学大模型》实践教程第五章
深度学习·安全·语言模型·llama
Donvink20 小时前
大模型智能体安全——《动手学大模型》实践教程第七章
深度学习·安全·语言模型·prompt·llama