理解Attention,MHA、MQA、GQA理论知识和代码实现

理论知识链接:理解Attention:从起源到MHA,MQA和GQA | Linsight

现有模型升级方法:https://blog.nghuyong.top/2023/09/10/NLP/llm-attention/

pytorch代码实现:

复制代码
class BaseAttention(torch.nn.Module):
    def __init__(self):
        super(BaseAttention, self).__init__()
        self.softmax = torch.nn.Softmax(dim=-1)

    def attention(self, q, k, v, mask=None, dropout=None):
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])

        if mask is not None:
            attn = attn + mask
        
        attn = self.softmax(attn)
        if dropout is not None:
            attn = dropout(attn)
        output = torch.matmul(attn, v)
        return output


class Attention(BaseAttention):

    def __init__(self, hidden_size, dropout=None):
        super(Attention, self).__init__()
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.softmax = torch.nn.Softmax(dim=-1)
        
        if dropout is not None:
            self.dropout = torch.nn.Dropout(p=dropout)
        else:
            self.dropout = None
    
    def forward(self, x, mask=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        output = self.attention(q, k, v, mask, self.dropout)
        return output


class MHAttention(BaseAttention):

    def __init__(self, hidden_size, num_heads=32, dropout=None):
        super(MHAttention, self).__init__()
        self.num_heads = num_heads
        self.softmax = torch.nn.Softmax(dim=-1)
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
        
        if dropout is not None:
            self.dropout = torch.nn.Dropout(p=dropout)
    
    def forward(self, x, mask=None):
        bs, seq_len, hidden_size = x.shape

        q = self.q_proj(x).view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        k = self.k_proj(x).view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        v = self.v_proj(x).view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        output = self.attention(q, k, v, mask, self.dropout)
        output = output.view(bs, seq_len, hidden_size)
        return output


class MQAttention(BaseAttention):

    def __init__(self, hidden_size, num_heads=32, dropout=None):
        super(MQAttention, self).__init__()
        self.num_heads = num_heads
        self.softmax = torch.nn.Softmax(dim=-1)
        assert hidden_size % num_heads == 0
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size // num_heads)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size // num_heads)
        
        if dropout is not None:
            self.dropout = torch.nn.Dropout(p=dropout)
    
    def forward(self, x, mask=None):
        bs, seq_len, hidden_size = x.shape

        q = self.q_proj(x).view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        k = self.k_proj(x).view(bs, seq_len, -1, hidden_size // self.num_heads).transpose(1, 2)
        v = self.v_proj(x).view(bs, seq_len, -1, hidden_size // self.num_heads).transpose(1, 2)
        output = self.attention(q, k, v, mask, self.dropout)
        output = output.view(bs, seq_len, hidden_size)
        return output


class GQAttention(BaseAttention):

    def __init__(self, hidden_size, num_heads=32, num_kv_heads=8, dropout=None):
        super(GQAttention, self).__init__()
        assert hidden_size % num_heads == 0 and num_heads % num_kv_heads == 0

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_group = num_heads // num_kv_heads
        self.softmax = torch.nn.Softmax(dim=-1)
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size // num_heads * num_kv_heads)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size // num_heads * num_kv_heads)
        
        if dropout is not None:
            self.dropout = torch.nn.Dropout(p=dropout)
    
    def repeat_kv(self, feature, num_group): #llama2源码
        bs, num_kv_heads, seq_len, head_dims = feature.shape
        if num_group == 1:
            return feature
        feature = feature[:, :, None, :, :].expand(bs, num_kv_heads, num_group, seq_len, head_dims)
        return feature.reshape(bs, num_kv_heads * num_group, seq_len, head_dims)

    def forward(self, x, mask=None):
        bs, seq_len, hidden_size = x.shape

        q = self.q_proj(x).view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        k = self.k_proj(x).view(bs, seq_len, -1, hidden_size // self.num_heads).transpose(1, 2)
        v = self.v_proj(x).view(bs, seq_len, -1, hidden_size // self.num_heads).transpose(1, 2)
        k, v = self.repeat_kv(k, self.num_group), self.repeat_kv(v, self.num_group)
        output = self.attention(q, k, v, mask, self.dropout)
        output = output.view(bs, seq_len, hidden_size)
        return output
        

model = Attention(hidden_size=4096, dropout=0.1)
model = MHAttention(hidden_size=4096, num_heads=32, dropout=0.1)
model = MQAttention(hidden_size=4096, num_heads=32, dropout=0.1)
model = GQAttention(hidden_size=4096, num_heads=32, num_kv_heads=4, dropout=0.1)
input_data = torch.randn(1, 20, 4096)
output = model(input_data)
print()
相关推荐
板面华仔12 分钟前
机器学习入门(三)——决策树(Decision Tree)
人工智能·决策树·机器学习
GAOJ_K25 分钟前
滚珠花键的无预压、间隙调整与过盈配合“场景适配型”
人工智能·科技·机器人·自动化·制造
ai_xiaogui29 分钟前
【开源探索】Panelai:重新定义AI服务器管理面板,助力团队私有化算力部署与模型运维
人工智能·开源·私有化部署·docker容器化·panelai·ai服务器管理面板·comfyui集群管理
源于花海34 分钟前
迁移学习的前沿知识(AI与人类经验结合、传递式、终身、在线、强化、可解释性等)
人工智能·机器学习·迁移学习·迁移学习前沿
机 _ 长37 分钟前
YOLO26 改进 | 基于特征蒸馏 | 知识蒸馏 (Response & Feature-based Distillation)
python·深度学习·机器学习
king of code porter1 小时前
百宝箱企业版搭建智能体应用-平台概述
人工智能·大模型·智能体
愚公搬代码1 小时前
【愚公系列】《AI短视频创作一本通》004-AI短视频的准备工作(创作AI短视频的基本流程)
人工智能·音视频
物联网软硬件开发-轨物科技1 小时前
【轨物洞见】告别“被动维修”!预测性运维如何重塑老旧电站的资产价值?
运维·人工智能
电商API_180079052471 小时前
第三方淘宝商品详情 API 全维度调用指南:从技术对接到生产落地
java·大数据·前端·数据库·人工智能·网络爬虫
小杨互联网1 小时前
LLM应用三大隐形风险与防护方案详解
llm