【llm对话系统】大模型源码分析之 LLaMA 模型的 Masked Attention

在大型语言模型(LLM)中,注意力机制(Attention Mechanism)是核心组成部分。然而,在自回归(autoregressive)模型中,例如 LLaMA,我们需要对注意力进行屏蔽(Masking),以防止模型"偷看"未来的信息。本文将深入探讨 LLaMA 模型中 Masked Attention 的实现逻辑,并对比其他类型大模型中常用的 Masked Attention 方案。

1. 什么是 Masked Attention

1.1 为什么需要 Mask

在自回归模型中,模型的目标是根据已有的输入序列预测下一个词。在训练阶段,模型会接收整个输入序列,但在预测某个位置的词时,模型不应该看到该位置之后的信息。这就是 Masked Attention 的作用:它会屏蔽未来词对当前词的影响,确保模型只能依赖于过去的信息进行预测。

1.2 Mask 的类型

Mask 主要分为两种类型:

  1. Padding Mask: 用于处理变长序列,屏蔽 padding 部分对注意力计算的影响。
  2. Causal Mask: 用于自回归模型,屏蔽未来位置的信息,防止模型偷看未来。

2. LLaMA 中的 Masked Attention

LLaMA 模型主要关注自回归的生成任务,所以使用的是 Causal Mask

2.1 LLaMA 的实现逻辑

LLaMA 使用标准的多头自注意力机制(Multi-Head Self-Attention, MHA),并在计算注意力权重时应用 Causal Mask。具体流程如下:

  1. 线性变换: 将输入序列映射为查询(Query)、键(Key)和值(Value)向量。
  2. 计算注意力分数: 计算查询向量和键向量的点积,并进行缩放。
  3. 应用 Mask: 使用 Causal Mask 屏蔽未来位置的注意力分数。
  4. 计算注意力权重: 对屏蔽后的注意力分数进行 Softmax 归一化。
  5. 计算加权值向量: 使用注意力权重对值向量进行加权求和。

2.2 LLaMA 源码示例 (PyTorch)

以下是 LLaMA 模型中 Masked Attention 的核心代码(简化版):

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class LlamaAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # 线性变换
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        # 线性变换
        q = self.Wq(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        k = self.Wk(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.Wv(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # 应用 Mask
        if mask is not None:
          scores = scores.masked_fill(mask==0, float('-inf'))

        # 计算注意力权重
        attn_weights = F.softmax(scores, dim=-1)

        # 计算加权值向量
        attn_output = torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # 输出线性变换
        output = self.Wo(attn_output)
        return output


def generate_causal_mask(seq_len):
    mask = torch.ones((seq_len, seq_len), dtype=torch.bool).triu(diagonal=1)
    return mask.bool()

# 示例
import math
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2

attention_layer = LlamaAttention(d_model, num_heads)
input_tensor = torch.randn(batch_size, seq_len, d_model)
causal_mask = generate_causal_mask(seq_len)

output = attention_layer(input_tensor, mask=causal_mask)
print("Output Shape:", output.shape) # 输出: torch.Size([2, 10, 512])

代码解释:

  1. LlamaAttention 类:
    • 初始化线性变换层 Wq, Wk, Wv, 和 Wo
    • forward 方法中,首先对输入进行线性变换,并进行多头分割。
    • 计算注意力分数,将 qk 进行点积运算,并进行缩放。
    • 应用 Causal Mask。
    • 使用 Softmax 对分数进行归一化,得到注意力权重。
    • 使用注意力权重对 v 进行加权求和。
  2. generate_causal_mask 函数:
    • 生成一个下三角矩阵,其中对角线以上的位置为 True(即需要 Mask 的位置)。
    • 将 Mask 返回为布尔类型,方便后续使用 masked_fill 函数进行填充。
  3. 示例:
    • 使用随机的输入张量,构造 causal_mask。
    • 调用注意力层,得到输出。

2.3 Mask 应用细节

在代码中,我们使用了 scores.masked_fill(mask == 0, float('-inf')) 来应用 Mask。masked_fill 函数会将 mask 中为 False (也就是需要mask的位置) 的位置填充为 -inf。在 Softmax 计算时, -inf 将会被转换为 0,从而有效地屏蔽了未来的信息。

3. 与其他大模型 Masked Attention 方案的对比

3.1 GPT 系列模型

GPT 系列模型也使用 Causal Mask,其实现方式与 LLaMA 类似。主要区别在于:

  • 实现方式: GPT 系列模型通常使用 torch.triu() 函数来生成上三角 Mask,然后使用 masked_fill 函数填充。
  • 结构: GPT 模型主要使用单向 Transformer 结构,而 LLaMA 模型使用双向 Transformer 结构(encoder-decoder 结构)。

3.2 BERT 系列模型

BERT 系列模型主要用于理解任务,使用了双向注意力机制。BERT 使用两种 Mask:

  • Padding Mask: 用于处理变长序列,屏蔽 padding 部分的影响。
  • Attention Mask (随机mask): 在预训练阶段,随机 mask 输入序列中的一部分词,让模型预测被屏蔽的词。

3.3 对比总结

模型 Mask 类型 Mask 实现方式 适用场景
LLaMA Causal Mask masked_fill 自回归生成
GPT 系列 Causal Mask torch.triu() + masked_fill 自回归生成
BERT 系列 Padding Mask & Attention Mask masked_fill 理解任务

4. 训练与推理时的 Mask

4.1 训练时

在训练阶段,我们会为每个输入序列都生成相应的 Causal Mask。Mask 的形状取决于输入序列的长度,确保模型只能看到当前位置之前的输入。

4.2 推理时

在推理阶段(生成文本时),我们需要动态更新 Mask。每生成一个新词,我们都会追加到当前序列,并为新的序列生成相应的 Causal Mask。LLaMA 模型为了提升推理效率,做了很多优化,比如KV Cache,增量式的更新mask,加速推理。

5. 总结

本文深入分析了 LLaMA 模型中 Masked Attention 的实现逻辑,并对比了其他类型大模型的 Masked Attention 方案。通过了解 Mask 的原理和具体实现,我们能更好地理解自回归模型的工作方式。希望本文能帮助你更好地理解大模型中的注意力机制!

6. 参考资料

相关推荐
--fancy1 小时前
股票预测情感分析研究案例分析
python
shughui1 小时前
PyCharm 完整教程(旧版本卸载+旧/新版本下载安装+基础使用,2026最新版附安装包)
ide·python·pycharm
AI机器学习算法1 小时前
深度学习模型演进:6个里程碑式CNN架构
人工智能·深度学习·cnn·大模型·ai学习路线
Ztopcloud极拓云视角2 小时前
从 OpenRouter 数据看中美 AI 调用量反转:统计口径、模型路由与多云应对方案
人工智能·阿里云·大模型·token·中美ai
AI医影跨模态组学2 小时前
如何将深度学习MTSR与膀胱癌ITGB8/TGF-β/WNT机制建立关联,并进一步解释其与患者预后及肿瘤侵袭、免疫抑制的生物学联系
人工智能·深度学习·论文·医学影像
小糖学代码2 小时前
LLM系列:1.python入门:15.JSON 数据处理与操作
开发语言·python·json·aigc
yejqvow122 小时前
CSS如何控制placeholder文字的颜色_使用--placeholder伪元素
jvm·数据库·python
搬砖的前端2 小时前
AI编辑器开源主模型搭配本地模型辅助对标GPT5.2/GPT5.4/Claude4.6(前端开发专属)
人工智能·开源·claude·mcp·trae·qwen3.6·ops4.6
m0_743623922 小时前
HTML怎么创建多语言切换器_HTML语言选择下拉结构【指南】
jvm·数据库·python
pele2 小时前
Angular 表单中基于下拉选择动态启用字段必填校验的完整实现
jvm·数据库·python