LeetCode - Google LLM Campus Recruitment, 10 Problems - Day 1: Attention Summary (3 Problems)

Welcome to follow my CSDN blog: https://spike.blog.csdn.net/

Article URL: https://spike.blog.csdn.net/article/details/145368666


GroupQueryAttention (grouped-query attention) and KVCache (key-value cache) are common components of large language model architectures. GroupQueryAttention is a variant of the attention mechanism: the query heads are split into groups, and each group shares the same key (Key) and value (Value) heads. This reduces computation and memory while keeping the model's attention over the input effective, which suits large-scale data and complex tasks. KVCache is a caching mechanism that stores previously computed key-value pairs (KV); when the model processes a new input (Q), it reads the cached KV directly instead of recomputing it, which significantly speeds up inference. GQA and KVCache both improve model performance and resource utilization, and combining them further improves behavior in practice.
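To make the resource argument concrete, here is a rough back-of-the-envelope sketch of KV-cache memory under MHA versus GQA. All hyperparameters below (layer count, head counts, sequence length) are illustrative assumptions, not taken from any specific model:

```python
# Rough KV-cache size: 2 (K and V) * layers * kv_heads * d_k * seq_len * batch * bytes_per_value.
# All hyperparameters here are illustrative assumptions (fp16 values, 2 bytes each).
def kv_cache_bytes(layers, kv_heads, d_k, seq_len, batch, bytes_per_value=2):
    return 2 * layers * kv_heads * d_k * seq_len * batch * bytes_per_value

layers, d_k, seq_len, batch = 32, 64, 4096, 1
mha = kv_cache_bytes(layers, kv_heads=32, d_k=d_k, seq_len=seq_len, batch=batch)  # MHA: kv_heads == heads
gqa = kv_cache_bytes(layers, kv_heads=8, d_k=d_k, seq_len=seq_len, batch=batch)   # GQA: fewer KV heads
print(f"MHA cache: {mha / 2**20:.0f} MiB, GQA cache: {gqa / 2**20:.0f} MiB, ratio: {mha / gqa:.0f}x")
# MHA cache: 1024 MiB, GQA cache: 256 MiB, ratio: 4x
```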

From MHA to GQA, and then to GQA + KVCache, with simple implementations. Reference:

Scaled Dot-Product Attention, also called single-head self-attention, is defined by:
$$
Attention(Q,K,V) = softmax\left(\frac{QK^{\top}}{\sqrt{d_{k}}}\right)V
$$
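As a warm-up, a minimal single-head sketch that implements this formula directly (the function name and shapes here are illustrative and not part of the problems below):

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # [bs, s, s]
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)   # block masked positions
    return F.softmax(scores, dim=-1) @ v               # [bs, s, d_k]

bs, s, d_k = 2, 10, 64
q, k, v = (torch.randn(bs, s, d_k) for _ in range(3))
print(scaled_dot_product_attention(q, k, v).shape)  # torch.Size([2, 10, 64])
```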

1. MultiHeadAttention

MultiHeadAttention (multi-head attention), 43 lines in total:

  1. __init__ initialization (10 lines):
    • Inputs: heads (number of heads), d_model (model dimension), dropout (applied to the attention scores)
    • Compute d_k, the per-head dimension, i.e. $d_{model} = heads \times d_{k}$
    • Linear layers for Q, K, V, O, plus a Dropout layer
  2. attention (10 lines):
    • $q$ has shape [bs,h,s,d_k]; $k^{\top}$ has shape [bs,h,d_k,s]; after the matmul, scores has shape [bs,h,s,s]
    • mask has shape [bs,s,s]; unsqueeze(1) turns it into [bs,1,s,s]
    • The $softmax(QK^{\top}/\sqrt{d_k})V$ computation, with optional dropout on the attention weights
  3. forward (12 lines):
    • The Q/K/V linear outputs are reshaped to [bs,s,h,d_k], then transposed to [bs,h,s,d_k]
    • Compute attn with shape [bs,h,s,d_k]
    • Transpose back to [bs,s,h,d_k], call contiguous(), then merge $h \times d_{k} = d_{model}$
    • Finally pass through the output projection O
  4. Test (11 lines):
    • Build inputs with torch.randn
    • Causal mask via torch.tril(torch.ones(bs, s, s)) (lower-triangular; see the short mask sketch after the code below)

The implementation:

```python
import math
import torch
import torch.nn.functional as F
from torch import nn
class MultiHeadAttention(nn.Module):
    """
    多头自注意力机制 MultiHeadAttention
    """
    def __init__(self, heads, d_model, dropout=0.1):  # 10行
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):  # 10 lines
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # Mask out padded/future positions so their softmax weights become ~0
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output
    def forward(self, q, k, v, mask=None):  # 12 lines
        bs = q.size(0)
        # Linear projections, then split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # Transpose so that the head dimension comes before the sequence dimension
        k = k.transpose(1, 2)  # [bs,h,s,d] = [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # Compute attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        print(f"[Info] attn: {attn.shape}")
        # Concatenate the heads and feed into the final linear layer
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output
def main():
    # Hyperparameters
    bs, s, h, d = 2, 10, 8, 512
    dropout_rate = 0.1
    # Create a MultiHeadAttention instance
    attention = MultiHeadAttention(h, d, dropout_rate)
    # Random input tensors
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # Optional: causal mask (lower-triangular matrix)
    mask = torch.tril(torch.ones(bs, s, s))
    # Test without a mask
    output_no_mask = attention(q, k, v)
    print("Output shape without mask:", output_no_mask.shape)
    # Test with a mask
    output_with_mask = attention(q, k, v, mask)
    print("Output shape with mask:", output_with_mask.shape)
    # Check that the outputs have the expected shapes
    assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"
    assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"
    print("Test passed!")
if __name__ == '__main__':
    main()
```
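A quick look at the causal mask used in main() above: torch.tril keeps the lower triangle, so query position i can only attend to key positions 0..i. A minimal sketch, with a small sequence length chosen just for printing:

```python
import torch

s = 4
mask = torch.tril(torch.ones(1, s, s))  # lower-triangular causal mask, shape [bs, s, s]
print(mask[0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
# Row i (query position i) has ones only for key positions <= i,
# so masked_fill(mask == 0, -1e9) blocks attention to future positions.
```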

2. GroupQueryAttention

GroupQueryAttention (grouped-query attention), compared with MHA (see also torch.nn.functional.scaled_dot_product_attention):

  1. __init__: add a kv_heads parameter (the number of KV heads); the output dimension of the K/V linear layers changes to kv_heads * self.d_k accordingly.
  2. forward: use repeat_interleave to expand the KV heads to match the query heads (see the short sketch below); everything else stays the same, 3 extra lines.
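A minimal sketch of what repeat_interleave(group, dim=2) does to the KV head dimension (the tensor below is a tiny stand-in with d_k = 1, chosen only for readability; the real code uses [bs, s, kv_heads, d_k]):

```python
import torch

heads, kv_heads = 8, 4
group = heads // kv_heads  # each KV head serves `group` query heads
kv = torch.arange(kv_heads).view(1, 1, kv_heads, 1)  # stand-in for [bs, s, kv_heads, d_k]
expanded = kv.repeat_interleave(group, dim=2)         # -> [1, 1, heads, 1]
print(kv.flatten().tolist())        # [0, 1, 2, 3]
print(expanded.flatten().tolist())  # [0, 0, 1, 1, 2, 2, 3, 3]: KV head i serves query heads 2i and 2i+1
```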

The implementation:

```python
import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):
    """
    Grouped-query attention (Group Query Attention)
    """
    def __init__(self, heads, d_model, kv_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.kv_heads = kv_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):
        # [2, 8, 10, 64] x [2, 8, 64, 10] = [2, 8, 10, 10]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # Mask out padded/future positions so their softmax weights become ~0
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output
    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        # Linear projections
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 10, 8, 64]
        k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 10, 4, 64]
        v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)
        # Repeat the KV heads to match the number of query heads
        group = self.h // self.kv_heads
        k = k.repeat_interleave(group, dim=2)  # [2, 10, 4, 64] -> [2, 10, 8, 64]
        v = v.repeat_interleave(group, dim=2)
        # Transpose so that the head dimension comes first
        k = k.transpose(1, 2)  # [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # Compute attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        # Concatenate the heads and feed into the final linear layer
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output
def main():
    # Hyperparameters: 8 query heads with 4 KV heads, so each KV head is shared by 8 // 4 = 2 query heads
    bs, s, h, d, kv_heads = 2, 10, 8, 512, 4
    dropout_rate = 0.1
    # Create a GroupQueryAttention instance
    attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)
    # Random input tensors
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # Optional: causal mask (lower-triangular matrix)
    mask = torch.tril(torch.ones(bs, s, s))
    # Test without a mask
    output_no_mask = attention(q, k, v)
    print("Output shape without mask:", output_no_mask.shape)
    # Test with a mask
    output_with_mask = attention(q, k, v, mask)
    print("Output shape with mask:", output_with_mask.shape)
    # Check that the outputs have the expected shapes
    assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"
    assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"
    print("Test passed!")
if __name__ == '__main__':
    main()
```
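Since this section points to torch.nn.functional.scaled_dot_product_attention as a reference, here is a small sanity-check sketch comparing the attention() staticmethod above with the built-in. It assumes the GroupQueryAttention class defined above is in scope and uses no mask and no dropout, so the two results should match numerically:

```python
import torch
import torch.nn.functional as F

bs, h, s, d_k = 2, 8, 10, 64
q, k, v = (torch.randn(bs, h, s, d_k) for _ in range(3))

ours = GroupQueryAttention.attention(q, k, v, d_k)  # no mask, no dropout
builtin = F.scaled_dot_product_attention(q, k, v)   # default scale is 1/sqrt(d_k)
print(torch.allclose(ours, builtin, atol=1e-6))     # True
```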

3. GQA + KVCache

GroupQueryAttention + KVCache, which adds a KV cache on top of GQA:

  1. forward: add a kv_cache parameter, concatenate [cached_k, new_k] (and likewise for V), and also return new_kv_cache for the next iteration; 5 extra lines.
  2. In the test, build cur_q / cur_k / cur_v and cur_mask for each step and iterate over the sequence dimension s (the mask slicing is sketched below); 8 lines in total.
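The per-step mask slice is the subtle part: at step i, the single query row i may attend to keys 0..i. A minimal sketch of mask[:, i:i+1, :i+1], with a small s just for printing:

```python
import torch

bs, s = 1, 4
mask = torch.tril(torch.ones(bs, s, s))   # full causal mask
for i in range(s):
    cur_mask = mask[:, i:i+1, :i+1]       # query row i vs. keys 0..i
    print(i, cur_mask.shape, cur_mask[0, 0].tolist())
# 0 torch.Size([1, 1, 1]) [1.0]
# 1 torch.Size([1, 1, 2]) [1.0, 1.0]
# 2 torch.Size([1, 1, 3]) [1.0, 1.0, 1.0]
# 3 torch.Size([1, 1, 4]) [1.0, 1.0, 1.0, 1.0]
# The slice is all ones: in incremental decoding, the causal constraint is already
# enforced by only having past keys in the cache.
```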

The implementation:

```python
import math
import torch
import torch.nn.functional as F
from torch import nn
class GroupQueryAttention(nn.Module):
    """
    Grouped-query attention (Group Query Attention) with KV Cache
    """
    def __init__(self, heads, d_model, kv_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.kv_heads = kv_heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.v_linear = nn.Linear(d_model, kv_heads * self.d_k)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):
        # [2, 8, 1, 64] x [2, 8, 64, 10] = [2, 8, 1, 10]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # Mask out padded/future positions so their softmax weights become ~0
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output
    def forward(self, q, k, v, mask=None, kv_cache=None):
        bs = q.size(0)
        # Linear projections
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)  # [2, 1, 8, 64]
        new_k = self.k_linear(k).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]
        new_v = self.v_linear(v).view(bs, -1, self.kv_heads, self.d_k)  # [2, 1, 4, 64]
        # Handle the KV cache: prepend the cached K/V to the new K/V
        if kv_cache is not None:
            cached_k, cached_v = kv_cache
            new_k = torch.cat([cached_k, new_k], dim=1)
            new_v = torch.cat([cached_v, new_v], dim=1)
        # Repeat the KV heads to match the number of query heads
        group = self.h // self.kv_heads
        k = new_k.repeat_interleave(group, dim=2)  # at the last step: [2, 10, 4, 64] -> [2, 10, 8, 64]
        v = new_v.repeat_interleave(group, dim=2)
        # Transpose so that the head dimension comes first
        # At the last KV-cache step: q -> [2, 8, 1, 64], k -> [2, 8, 10, 64], v -> [2, 8, 10, 64]
        k = k.transpose(1, 2)  # [2, 8, 10, 64]
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # Compute attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)  # [2, 8, 1, 64]
        print(f"[Info] attn: {attn.shape}")
        # Concatenate the heads and feed into the final linear layer
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        # Update the KV cache
        new_kv_cache = (new_k, new_v)  # the KV cache to pass into the next step
        return output, new_kv_cache
def main():
    # Hyperparameters
    bs, s, h, d, kv_heads = 2, 10, 8, 512, 4
    dropout_rate = 0.1
    # Create a GroupQueryAttention instance
    attention = GroupQueryAttention(h, d, kv_heads, dropout_rate)
    # Random input tensors
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # Causal mask (lower-triangular matrix)
    mask = torch.tril(torch.ones(bs, s, s))
    # Simulate step-by-step generation to test the KV cache
    print("Testing KV Cache...")
    kv_cache, output = None, None
    for i in range(s):
        cur_q = q[:, i:i+1, :]
        cur_k = k[:, i:i+1, :]
        cur_v = v[:, i:i+1, :]
        cur_mask = mask[:, i:i+1, :i+1]   # the query is row i (i:i+1); the keys are positions 0..i (:i+1)
        output, kv_cache = attention(cur_q, cur_k, cur_v, cur_mask, kv_cache)
        print(f"Output shape at step {i}:", output.shape)
    # Check that the final output has the expected shape
    assert output.shape == (bs, 1, d), "Output shape is incorrect when using KV Cache"
    print("Test passed!")
if __name__ == "__main__":
    main()
```
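As a final sanity check, a minimal sketch that compares incremental decoding with the KV cache against a single full-sequence pass. It assumes the GroupQueryAttention class with KV cache defined above is in scope, and calls eval() to disable dropout so the two paths are deterministic:

```python
import torch

bs, s, h, d, kv_heads = 2, 10, 8, 512, 4
attention = GroupQueryAttention(h, d, kv_heads, dropout=0.1).eval()  # eval() disables dropout
x = torch.randn(bs, s, d)
mask = torch.tril(torch.ones(bs, s, s))

with torch.no_grad():
    # Full-sequence pass (no cache) with the full causal mask
    full_out, _ = attention(x, x, x, mask, kv_cache=None)
    # Incremental pass, one token at a time, reusing the KV cache
    kv_cache, steps = None, []
    for i in range(s):
        cur = x[:, i:i+1, :]
        out, kv_cache = attention(cur, cur, cur, mask[:, i:i+1, :i+1], kv_cache)
        steps.append(out)
    step_out = torch.cat(steps, dim=1)

print(torch.allclose(full_out, step_out, atol=1e-5))  # True: caching does not change the result
```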