【llm对话系统】大模型源码分析之 LLaMA 模型的 Masked Attention

在大型语言模型(LLM)中,注意力机制(Attention Mechanism)是核心组成部分。然而,在自回归(autoregressive)模型中,例如 LLaMA,我们需要对注意力进行屏蔽(Masking),以防止模型"偷看"未来的信息。本文将深入探讨 LLaMA 模型中 Masked Attention 的实现逻辑,并对比其他类型大模型中常用的 Masked Attention 方案。

1. 什么是 Masked Attention

1.1 为什么需要 Mask

在自回归模型中,模型的目标是根据已有的输入序列预测下一个词。在训练阶段,模型会接收整个输入序列,但在预测某个位置的词时,模型不应该看到该位置之后的信息。这就是 Masked Attention 的作用:它会屏蔽未来词对当前词的影响,确保模型只能依赖于过去的信息进行预测。

1.2 Mask 的类型

Mask 主要分为两种类型:

  1. Padding Mask: 用于处理变长序列,屏蔽 padding 部分对注意力计算的影响。
  2. Causal Mask: 用于自回归模型,屏蔽未来位置的信息,防止模型偷看未来。

2. LLaMA 中的 Masked Attention

LLaMA 模型主要关注自回归的生成任务,所以使用的是 Causal Mask

2.1 LLaMA 的实现逻辑

LLaMA 使用标准的多头自注意力机制(Multi-Head Self-Attention, MHA),并在计算注意力权重时应用 Causal Mask。具体流程如下:

  1. 线性变换: 将输入序列映射为查询(Query)、键(Key)和值(Value)向量。
  2. 计算注意力分数: 计算查询向量和键向量的点积,并进行缩放。
  3. 应用 Mask: 使用 Causal Mask 屏蔽未来位置的注意力分数。
  4. 计算注意力权重: 对屏蔽后的注意力分数进行 Softmax 归一化。
  5. 计算加权值向量: 使用注意力权重对值向量进行加权求和。

2.2 LLaMA 源码示例 (PyTorch)

以下是 LLaMA 模型中 Masked Attention 的核心代码(简化版):

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class LlamaAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # 线性变换
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        # 线性变换
        q = self.Wq(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        k = self.Wk(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.Wv(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # 应用 Mask
        if mask is not None:
          scores = scores.masked_fill(mask==0, float('-inf'))

        # 计算注意力权重
        attn_weights = F.softmax(scores, dim=-1)

        # 计算加权值向量
        attn_output = torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # 输出线性变换
        output = self.Wo(attn_output)
        return output


def generate_causal_mask(seq_len):
    mask = torch.ones((seq_len, seq_len), dtype=torch.bool).triu(diagonal=1)
    return mask.bool()

# 示例
import math
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2

attention_layer = LlamaAttention(d_model, num_heads)
input_tensor = torch.randn(batch_size, seq_len, d_model)
causal_mask = generate_causal_mask(seq_len)

output = attention_layer(input_tensor, mask=causal_mask)
print("Output Shape:", output.shape) # 输出: torch.Size([2, 10, 512])

代码解释:

  1. LlamaAttention 类:
    • 初始化线性变换层 Wq, Wk, Wv, 和 Wo
    • forward 方法中,首先对输入进行线性变换,并进行多头分割。
    • 计算注意力分数,将 qk 进行点积运算,并进行缩放。
    • 应用 Causal Mask。
    • 使用 Softmax 对分数进行归一化,得到注意力权重。
    • 使用注意力权重对 v 进行加权求和。
  2. generate_causal_mask 函数:
    • 生成一个下三角矩阵,其中对角线以上的位置为 True(即需要 Mask 的位置)。
    • 将 Mask 返回为布尔类型,方便后续使用 masked_fill 函数进行填充。
  3. 示例:
    • 使用随机的输入张量,构造 causal_mask。
    • 调用注意力层,得到输出。

2.3 Mask 应用细节

在代码中,我们使用了 scores.masked_fill(mask == 0, float('-inf')) 来应用 Mask。masked_fill 函数会将 mask 中为 False (也就是需要mask的位置) 的位置填充为 -inf。在 Softmax 计算时, -inf 将会被转换为 0,从而有效地屏蔽了未来的信息。

3. 与其他大模型 Masked Attention 方案的对比

3.1 GPT 系列模型

GPT 系列模型也使用 Causal Mask,其实现方式与 LLaMA 类似。主要区别在于:

  • 实现方式: GPT 系列模型通常使用 torch.triu() 函数来生成上三角 Mask,然后使用 masked_fill 函数填充。
  • 结构: GPT 模型主要使用单向 Transformer 结构,而 LLaMA 模型使用双向 Transformer 结构(encoder-decoder 结构)。

3.2 BERT 系列模型

BERT 系列模型主要用于理解任务,使用了双向注意力机制。BERT 使用两种 Mask:

  • Padding Mask: 用于处理变长序列,屏蔽 padding 部分的影响。
  • Attention Mask (随机mask): 在预训练阶段,随机 mask 输入序列中的一部分词,让模型预测被屏蔽的词。

3.3 对比总结

模型 Mask 类型 Mask 实现方式 适用场景
LLaMA Causal Mask masked_fill 自回归生成
GPT 系列 Causal Mask torch.triu() + masked_fill 自回归生成
BERT 系列 Padding Mask & Attention Mask masked_fill 理解任务

4. 训练与推理时的 Mask

4.1 训练时

在训练阶段,我们会为每个输入序列都生成相应的 Causal Mask。Mask 的形状取决于输入序列的长度,确保模型只能看到当前位置之前的输入。

4.2 推理时

在推理阶段(生成文本时),我们需要动态更新 Mask。每生成一个新词,我们都会追加到当前序列,并为新的序列生成相应的 Causal Mask。LLaMA 模型为了提升推理效率,做了很多优化,比如KV Cache,增量式的更新mask,加速推理。

5. 总结

本文深入分析了 LLaMA 模型中 Masked Attention 的实现逻辑,并对比了其他类型大模型的 Masked Attention 方案。通过了解 Mask 的原理和具体实现,我们能更好地理解自回归模型的工作方式。希望本文能帮助你更好地理解大模型中的注意力机制!

6. 参考资料

相关推荐
徐赛俊17 分钟前
# 自动定时运行Python爬虫脚本教程(Windows任务计划程序)
windows·爬虫·python
GitLqr17 分钟前
AI洞察 | 好酷!国产模型在 电影、3D、TTS 领域取得巨大进步!
aigc·ai编程·虚拟现实
暴躁的大熊24 分钟前
LLM大模型时代:生活服务领域的“生存革命“与新生态重构
人工智能
程序员秘密基地1 小时前
基于html,css,jquery,django,lstm,cnn,tensorflow,bert,推荐算法,mysql数据库
python·cnn·tensorflow·lstm·推荐算法
Blossom.1181 小时前
基于深度学习的医学图像分析:使用MobileNet实现医学图像分类
人工智能·深度学习·yolo·机器学习·分类·数据挖掘·迁移学习
德育处主任1 小时前
「豆包」加「PromptPilot」等于「优秀员工」
人工智能·llm·aigc
字节跳动安全中心1 小时前
猎影计划:从密流中捕获 Cobalt Strike 的隐秘身影
人工智能·安全·llm
技术炼丹人1 小时前
从RNN为什么长依赖遗忘到注意力机制的解决方案以及并行
人工智能·python·算法
FreeBuf_2 小时前
AI Agents漏洞百出,恶意提示等安全缺陷令人担忧
人工智能·安全
hqxstudying2 小时前
Java开发时出现的问题---语言特性与基础机制陷阱
java·jvm·python