GQA分组注意力机制

一、目录

  1. 定义
  2. demo

二、实现

  1. 定义
    grouped query attention(GQA)
    1 GQA 原理与优点:将query 进行分组,每组query 参数共享一份key,value, 从而使key, value 矩阵变小。
    2. 优点: 降低内存读取模型权重的时间开销:由于Key矩阵和Value矩阵数量变少了,因此权重参数量也减少了,需要读取到内存的数量量少了,因此减少了读取权重的等待时间。
    3. 效果(并未降低模型性能):GQA通过设置合适的分组大小,可以和MQA的推理性能几乎相等,同时逼近MHA的模型性能。

  2. llama3 分组数为4, chatglm2 分组数为2 .


    参考:https://zhuanlan.zhihu.com/p/693928854
    demo

    import torch
    import torch.nn as nn
    import math

    #GQA
    bs=3
    seq_len =5
    hidden_size= 32
    n_heads=4
    n_kv_heads = 2
    head_dim = hidden_size//n_heads #
    groups = n_heads//n_kv_heads # 4/2
    print("groups=",groups)
    x=torch.randn((bs,seq_len,hidden_size))
    print("x:", x.shape)
    wq = nn.Linear(hidden_size,n_heads*head_dim,bias=False)
    wk = nn.Linear(hidden_size, n_kv_heads * head_dim, bias=False)
    wv = nn.Linear(hidden_size, n_kv_heads * head_dim, bias=False)
    xq,xk,xv=wq(x),wk(x),wv(x)
    xq = xq.view(bs,seq_len, n_heads, head_dim).transpose(1, 2)
    xk = xk.view(bs,seq_len, n_kv_heads, head_dim).transpose(1, 2)
    xv = xv.view(bs,seq_len, n_kv_heads, head_dim).transpose(1, 2)
    print("xq:",xq.shape) #[bs,n_heads,seq_len, head_dim]
    print("xk:", xk.shape)#[bs,n_kv_heads,seq_len, head_dim]
    print("xv:", xv.shape)#[bs,n_kv_heads,seq_len, head_dim]
    def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int):
    keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
    values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
    return keys, values
    #复制kv head
    key,val = repeat_kv(xk,xv, groups,dim=1)
    print("key:", key.shape)
    print("val:", val.shape)
    attn_weights = torch.matmul(xq, key.transpose(2, 3)) / math.sqrt(head_dim)
    print("attn_weights:", attn_weights.shape) #[bs,n_heads,seq_len,seq_len]
    attn_output = torch.matmul(attn_weights, val)
    print("attn_output:", attn_output.shape) # [bs,n_heads,seq_len,head_dim]

相关推荐
z小猫不吃鱼17 分钟前
15 InstructGPT 论文精读:SFT + RLHF 如何让模型听懂指令?
人工智能·深度学习·算法·机器学习·语言模型·自然语言处理·gpt-3
zcg194233 分钟前
如何在CV中使用transformer
人工智能·深度学习·transformer
SuperHeroWu743 分钟前
【MindSpore】MindSpore 开源深度学习框架
人工智能·深度学习·开源·框架·mindspore
weixin_468466851 小时前
Airtable 零基础快速上手与实战指南
数据库·人工智能·python·深度学习·ai·大模型
hsg771 小时前
简述:ResNet34/ResNet50及SENet改进模型
人工智能·深度学习
weixin_468466851 小时前
图像分割新手入门:从环境搭建到实战应用
图像处理·人工智能·深度学习·计算机视觉·ai
codefan※1 小时前
pytorch安装流程
人工智能·pytorch·python
笑脸惹桃花1 小时前
目标检测:YOLOv12环境配置,超详细,适合0基础纯小白
深度学习·yolo·目标检测·目标跟踪·yolov12
kTR2hD1qb1 小时前
深度学习进阶(二十五)RoPE:现代 NLP 的位置编码范式
人工智能·深度学习·自然语言处理
钓了猫的鱼儿1 小时前
基于深度学习+AI的无人机违规防控目标检测与预警系统(Python源码+数据集+UI可视化界面+YOLOv11训练结果)
人工智能·深度学习·无人机