Minimind项目源码解析(3)Attention模块(核心)

Attention机制代码详细解析

既然大家开始看LLM相关了内容了,那么大家一定对attention机制有了一定的了解,在此我就不对attention机制进行过于细致的讲解了,在此主要讲解一些具体实现和一些扩展

attention机制简要讲解

在大语言模型里,我们本质上是让 LLM 学习 token 与 token 之间的依赖关系,而提取这种关系的核心,正是 Attention 注意力机制。

Attention 的计算主要依靠三个关键矩阵 ------ 也就是大家熟知的 Q、K、V 矩阵:

Q(Query)查询向量:代表当前 token 想要 "查找" 什么信息

K(Key)键向量:代表每个 token 自身携带的 "特征标识"

V(Value)值向量:代表每个 token 真正要传递出去的信息内容

具体过程可以简单理解为:

用当前 token 的 Q,去和其他 token 的 K 做点积运算,算出它们之间的相似度 / 关联强度。 在 decoder 里,为了保证生成顺序合法,我们只会让当前 token 关注它之前出现的 token,而看不到未来的 token。 再用这个关联权重去加权求和所有对应的 V,最终得到的结果,就是当前 token 从所有相关 token 那里收集到的有效信息。 这就是注意力机制最核心的原理。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> </math>

这里的softmax是为了将所有值变成一个概率分布,没学过可以在B站上搜一下,讲解非常多。

话不多说,直接上代码,我会在代码中进行比较详细的讲解。

attention的具体实现代码

不过这里用到的,其实是多头注意力(Multi-Head Attention)。

肯定有人会问:什么是 "头"?

其实非常直观:

我们把 token 映射成 Q、K、V 时用到的一组线性变换矩阵 W,就可以看作一个注意力头。而多头注意力,就是同时使用多组这样的独立矩阵,也就是多个 "头"。 这么多头有什么用?

举个简单例子: 假设一个 token 向量是 16 维,我们用 4 个注意力头,就可以把这 16 维平均分成 4 份,每份 4 维,让每个注意力头只负责其中一小段维度。

这样一来:

不同的头可以关注不同的语义信息

有的头关注语法结构

有的头关注语义关联

有的头捕捉局部依赖

最后再把所有头的结果拼接起来 模型就能更细、更全面、更丰富地提取 token 之间的语义关系。 此外,这里为了节省参数,我们将KV头进行复用(GQA:分组查询注意力) 怎么理解呢? 还是假设一个token有16个维度,我们要4个头,也就是4个Q头,4个K头,4个V头。也就是将这16个维度分成了4个小组。

但是这样太耗费参数了,怎么办呢? 我们可以把这4个小组再组合一下,分成2个大组,不同小组之间使用不同的Q,不同大组之间使用不同的KV,这样我们就能减少参数的使用了! 具体可以看一下代码的注释

python 复制代码
class Attention(nn.Module):
    def __init__(self, args: MiniMindConfig):
        super().__init__()
        self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
        assert args.num_attention_heads % self.num_key_value_heads == 0 

        self.n_local_heads = args.num_attention_heads #我们需要的Q注意力头数
        self.n_local_kv_heads = self.num_key_value_heads #我们需要的K头和V头的数量
        self.n_rep = self.n_local_heads // self.n_local_kv_heads #每个大组中的小组数量
        self.head_dim = args.hidden_size // args.num_attention_heads #每个头需要关注的维度数量
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)#Wq,这里我们将多个头拼接在一起,下面同理。
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)#输出矩阵,用来将计算出来的注意力分数进行整合输出。
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)==0=
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn #flash Attention 建议自己学习一下,这里不细讲
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

    def forward(self,
                x: torch.Tensor,
                position_embeddings: Tuple[torch.Tensor, torch.Tensor],  # 修改为接收cos和sin
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache=False,
                attention_mask: Optional[torch.Tensor] = None):
        
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)


        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)#重塑一下尺寸

        cos, sin = position_embeddings
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)#应用RoPE旋转编码,具体可以看我讲RoPE的文章

        # kv_cache实现
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None
        #为什么要有KV cache呢,考虑一下LLM推理时的原理,LLM的推理是自回归的,也就是每次会生成一个token,然后再根据该token返回查询与以前所有token的关系,从而生成下一个token,那么如果我们不保存之前token的xk和xv,我们在每次查询的时候都要再算一遍,那如果把之前的token的k和v全部保存下来,我们是不是就不用再算了,也就大大加快了推理速度。

        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )
        '''为什么要这样呢,自己用笔写一下,如果不换维度,那么最后乘出来后两维是head_dim * head_dim的,因为我们要的其实是token和token之间的关系,所以我们要的实际上是seq_len * seq_len,我们通过维度的变换,就能够做到这样的结果了。'''
        
        if self.flash and (seq_len > 1) and (past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
            #flash_attn 自己学一下,我也不太会
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) 
            '''公式,除以根号是为了防止数值过大,如果不这样的话,假设有一个很大的数字,1e9,那么他经过softmax后的权重就是1,会导致梯度消失'''
            scores[:, :, :, -seq_len:] += torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=scores.device), diagonal=1)
            '''因果注意力掩码   
               会生成一个下三角矩阵,从而让每个token只能关注到自己前面的token'''

            if attention_mask is not None:
                extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
                scores = scores + extended_attention_mask

            '''attention_mask可以忽略掉某些无用的token,比如为了将一句话填充为seq_len长度的padding(填充),他只是为了充长度的,没有实际作用,我们就给他去除'''

            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.o_proj(output))#复原并输出
        return output, past_kv
相关推荐
拳打南山敬老院1 小时前
你的 Agent 可能并不需要过度工程化:一次从 LangGraph 到极简 Agent 的架构反思
人工智能·设计模式
Halo咯咯1 小时前
从 Vibe Coder 到 AI 工程师,差的就是这 15 个概念
人工智能
Mintopia1 小时前
Gemini 的发展之道:从多模态模型演进到工程化落地的技术路径
人工智能
童话名剑1 小时前
YOLO v1(学习笔记)
人工智能·深度学习·yolo·目标检测
洞见前行1 小时前
AI Agent 的外部连接层:MCP 协议原理、机制设计与实战开发
人工智能
陈广亮1 小时前
当 AI Agent 学会付钱:x402 协议与 Agent 支付基础设施全解析
人工智能
廋到被风吹走1 小时前
持续学习方向 AI工程化(TensorFlow Serving、MLflow)
人工智能·学习·tensorflow
Once_day1 小时前
AI实践(0)学习路线
人工智能·学习·ai实践
数据与后端架构提升之路1 小时前
论大模型应用架构(RAG/Agent)的设计与应用——以自动驾驶数据闭环平台为例
人工智能·架构·自动驾驶