【llm对话系统】大模型源码分析之 LLaMA 模型的 Masked Attention

在大型语言模型(LLM)中,注意力机制(Attention Mechanism)是核心组成部分。然而,在自回归(autoregressive)模型中,例如 LLaMA,我们需要对注意力进行屏蔽(Masking),以防止模型"偷看"未来的信息。本文将深入探讨 LLaMA 模型中 Masked Attention 的实现逻辑,并对比其他类型大模型中常用的 Masked Attention 方案。

1. 什么是 Masked Attention

1.1 为什么需要 Mask

在自回归模型中,模型的目标是根据已有的输入序列预测下一个词。在训练阶段,模型会接收整个输入序列,但在预测某个位置的词时,模型不应该看到该位置之后的信息。这就是 Masked Attention 的作用:它会屏蔽未来词对当前词的影响,确保模型只能依赖于过去的信息进行预测。

1.2 Mask 的类型

Mask 主要分为两种类型:

  1. Padding Mask: 用于处理变长序列,屏蔽 padding 部分对注意力计算的影响。
  2. Causal Mask: 用于自回归模型,屏蔽未来位置的信息,防止模型偷看未来。

2. LLaMA 中的 Masked Attention

LLaMA 模型主要关注自回归的生成任务,所以使用的是 Causal Mask

2.1 LLaMA 的实现逻辑

LLaMA 使用标准的多头自注意力机制(Multi-Head Self-Attention, MHA),并在计算注意力权重时应用 Causal Mask。具体流程如下:

  1. 线性变换: 将输入序列映射为查询(Query)、键(Key)和值(Value)向量。
  2. 计算注意力分数: 计算查询向量和键向量的点积,并进行缩放。
  3. 应用 Mask: 使用 Causal Mask 屏蔽未来位置的注意力分数。
  4. 计算注意力权重: 对屏蔽后的注意力分数进行 Softmax 归一化。
  5. 计算加权值向量: 使用注意力权重对值向量进行加权求和。

2.2 LLaMA 源码示例 (PyTorch)

以下是 LLaMA 模型中 Masked Attention 的核心代码(简化版):

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class LlamaAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # 线性变换
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        # 线性变换
        q = self.Wq(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        k = self.Wk(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.Wv(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # 应用 Mask
        if mask is not None:
          scores = scores.masked_fill(mask==0, float('-inf'))

        # 计算注意力权重
        attn_weights = F.softmax(scores, dim=-1)

        # 计算加权值向量
        attn_output = torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # 输出线性变换
        output = self.Wo(attn_output)
        return output


def generate_causal_mask(seq_len):
    mask = torch.ones((seq_len, seq_len), dtype=torch.bool).triu(diagonal=1)
    return mask.bool()

# 示例
import math
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2

attention_layer = LlamaAttention(d_model, num_heads)
input_tensor = torch.randn(batch_size, seq_len, d_model)
causal_mask = generate_causal_mask(seq_len)

output = attention_layer(input_tensor, mask=causal_mask)
print("Output Shape:", output.shape) # 输出: torch.Size([2, 10, 512])

代码解释:

  1. LlamaAttention 类:
    • 初始化线性变换层 Wq, Wk, Wv, 和 Wo
    • forward 方法中,首先对输入进行线性变换,并进行多头分割。
    • 计算注意力分数,将 qk 进行点积运算,并进行缩放。
    • 应用 Causal Mask。
    • 使用 Softmax 对分数进行归一化,得到注意力权重。
    • 使用注意力权重对 v 进行加权求和。
  2. generate_causal_mask 函数:
    • 生成一个下三角矩阵,其中对角线以上的位置为 True(即需要 Mask 的位置)。
    • 将 Mask 返回为布尔类型,方便后续使用 masked_fill 函数进行填充。
  3. 示例:
    • 使用随机的输入张量,构造 causal_mask。
    • 调用注意力层,得到输出。

2.3 Mask 应用细节

在代码中,我们使用了 scores.masked_fill(mask == 0, float('-inf')) 来应用 Mask。masked_fill 函数会将 mask 中为 False (也就是需要mask的位置) 的位置填充为 -inf。在 Softmax 计算时, -inf 将会被转换为 0,从而有效地屏蔽了未来的信息。

3. 与其他大模型 Masked Attention 方案的对比

3.1 GPT 系列模型

GPT 系列模型也使用 Causal Mask,其实现方式与 LLaMA 类似。主要区别在于:

  • 实现方式: GPT 系列模型通常使用 torch.triu() 函数来生成上三角 Mask,然后使用 masked_fill 函数填充。
  • 结构: GPT 模型主要使用单向 Transformer 结构,而 LLaMA 模型使用双向 Transformer 结构(encoder-decoder 结构)。

3.2 BERT 系列模型

BERT 系列模型主要用于理解任务,使用了双向注意力机制。BERT 使用两种 Mask:

  • Padding Mask: 用于处理变长序列,屏蔽 padding 部分的影响。
  • Attention Mask (随机mask): 在预训练阶段,随机 mask 输入序列中的一部分词,让模型预测被屏蔽的词。

3.3 对比总结

模型 Mask 类型 Mask 实现方式 适用场景
LLaMA Causal Mask masked_fill 自回归生成
GPT 系列 Causal Mask torch.triu() + masked_fill 自回归生成
BERT 系列 Padding Mask & Attention Mask masked_fill 理解任务

4. 训练与推理时的 Mask

4.1 训练时

在训练阶段,我们会为每个输入序列都生成相应的 Causal Mask。Mask 的形状取决于输入序列的长度,确保模型只能看到当前位置之前的输入。

4.2 推理时

在推理阶段(生成文本时),我们需要动态更新 Mask。每生成一个新词,我们都会追加到当前序列,并为新的序列生成相应的 Causal Mask。LLaMA 模型为了提升推理效率,做了很多优化,比如KV Cache,增量式的更新mask,加速推理。

5. 总结

本文深入分析了 LLaMA 模型中 Masked Attention 的实现逻辑,并对比了其他类型大模型的 Masked Attention 方案。通过了解 Mask 的原理和具体实现,我们能更好地理解自回归模型的工作方式。希望本文能帮助你更好地理解大模型中的注意力机制!

6. 参考资料

相关推荐
冬天给予的预感43 分钟前
DAY 54 Inception网络及其思考
网络·python·深度学习
说私域1 小时前
互联网生态下赢家群体的崛起与“开源AI智能名片链动2+1模式S2B2C商城小程序“的赋能效应
人工智能·小程序·开源
钢铁男儿1 小时前
PyQt5高级界而控件(容器:装载更多的控件QDockWidget)
数据库·python·qt
董厂长4 小时前
langchain :记忆组件混淆概念澄清 & 创建Conversational ReAct后显示指定 记忆组件
人工智能·深度学习·langchain·llm
亿牛云爬虫专家5 小时前
Kubernetes下的分布式采集系统设计与实战:趋势监测失效引发的架构进化
分布式·python·架构·kubernetes·爬虫代理·监测·采集
墨风如雪7 小时前
告别“面目全非”!腾讯混元3D变身“建模艺术家”,建模效率直接起飞!
aigc
G皮T8 小时前
【人工智能】ChatGPT、DeepSeek-R1、DeepSeek-V3 辨析
人工智能·chatgpt·llm·大语言模型·deepseek·deepseek-v3·deepseek-r1
九年义务漏网鲨鱼8 小时前
【大模型学习 | MINIGPT-4原理】
人工智能·深度学习·学习·语言模型·多模态
元宇宙时间8 小时前
Playfun即将开启大型Web3线上活动,打造沉浸式GameFi体验生态
人工智能·去中心化·区块链
开发者工具分享8 小时前
文本音频违规识别工具排行榜(12选)
人工智能·音视频