【llm对话系统】大模型源码分析之 LLaMA 模型的 Masked Attention

在大型语言模型(LLM)中,注意力机制(Attention Mechanism)是核心组成部分。然而,在自回归(autoregressive)模型中,例如 LLaMA,我们需要对注意力进行屏蔽(Masking),以防止模型"偷看"未来的信息。本文将深入探讨 LLaMA 模型中 Masked Attention 的实现逻辑,并对比其他类型大模型中常用的 Masked Attention 方案。

1. 什么是 Masked Attention

1.1 为什么需要 Mask

在自回归模型中,模型的目标是根据已有的输入序列预测下一个词。在训练阶段,模型会接收整个输入序列,但在预测某个位置的词时,模型不应该看到该位置之后的信息。这就是 Masked Attention 的作用:它会屏蔽未来词对当前词的影响,确保模型只能依赖于过去的信息进行预测。

1.2 Mask 的类型

Mask 主要分为两种类型:

  1. Padding Mask: 用于处理变长序列,屏蔽 padding 部分对注意力计算的影响。
  2. Causal Mask: 用于自回归模型,屏蔽未来位置的信息,防止模型偷看未来。

2. LLaMA 中的 Masked Attention

LLaMA 模型主要关注自回归的生成任务,所以使用的是 Causal Mask

2.1 LLaMA 的实现逻辑

LLaMA 使用标准的多头自注意力机制(Multi-Head Self-Attention, MHA),并在计算注意力权重时应用 Causal Mask。具体流程如下:

  1. 线性变换: 将输入序列映射为查询(Query)、键(Key)和值(Value)向量。
  2. 计算注意力分数: 计算查询向量和键向量的点积,并进行缩放。
  3. 应用 Mask: 使用 Causal Mask 屏蔽未来位置的注意力分数。
  4. 计算注意力权重: 对屏蔽后的注意力分数进行 Softmax 归一化。
  5. 计算加权值向量: 使用注意力权重对值向量进行加权求和。

2.2 LLaMA 源码示例 (PyTorch)

以下是 LLaMA 模型中 Masked Attention 的核心代码(简化版):

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class LlamaAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # 线性变换
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        # 线性变换
        q = self.Wq(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        k = self.Wk(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.Wv(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # 应用 Mask
        if mask is not None:
          scores = scores.masked_fill(mask==0, float('-inf'))

        # 计算注意力权重
        attn_weights = F.softmax(scores, dim=-1)

        # 计算加权值向量
        attn_output = torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # 输出线性变换
        output = self.Wo(attn_output)
        return output


def generate_causal_mask(seq_len):
    mask = torch.ones((seq_len, seq_len), dtype=torch.bool).triu(diagonal=1)
    return mask.bool()

# 示例
import math
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2

attention_layer = LlamaAttention(d_model, num_heads)
input_tensor = torch.randn(batch_size, seq_len, d_model)
causal_mask = generate_causal_mask(seq_len)

output = attention_layer(input_tensor, mask=causal_mask)
print("Output Shape:", output.shape) # 输出: torch.Size([2, 10, 512])

代码解释:

  1. LlamaAttention 类:
    • 初始化线性变换层 Wq, Wk, Wv, 和 Wo
    • forward 方法中,首先对输入进行线性变换,并进行多头分割。
    • 计算注意力分数,将 qk 进行点积运算,并进行缩放。
    • 应用 Causal Mask。
    • 使用 Softmax 对分数进行归一化,得到注意力权重。
    • 使用注意力权重对 v 进行加权求和。
  2. generate_causal_mask 函数:
    • 生成一个下三角矩阵,其中对角线以上的位置为 True(即需要 Mask 的位置)。
    • 将 Mask 返回为布尔类型,方便后续使用 masked_fill 函数进行填充。
  3. 示例:
    • 使用随机的输入张量,构造 causal_mask。
    • 调用注意力层,得到输出。

2.3 Mask 应用细节

在代码中,我们使用了 scores.masked_fill(mask == 0, float('-inf')) 来应用 Mask。masked_fill 函数会将 mask 中为 False (也就是需要mask的位置) 的位置填充为 -inf。在 Softmax 计算时, -inf 将会被转换为 0,从而有效地屏蔽了未来的信息。

3. 与其他大模型 Masked Attention 方案的对比

3.1 GPT 系列模型

GPT 系列模型也使用 Causal Mask,其实现方式与 LLaMA 类似。主要区别在于:

  • 实现方式: GPT 系列模型通常使用 torch.triu() 函数来生成上三角 Mask,然后使用 masked_fill 函数填充。
  • 结构: GPT 模型主要使用单向 Transformer 结构,而 LLaMA 模型使用双向 Transformer 结构(encoder-decoder 结构)。

3.2 BERT 系列模型

BERT 系列模型主要用于理解任务,使用了双向注意力机制。BERT 使用两种 Mask:

  • Padding Mask: 用于处理变长序列,屏蔽 padding 部分的影响。
  • Attention Mask (随机mask): 在预训练阶段,随机 mask 输入序列中的一部分词,让模型预测被屏蔽的词。

3.3 对比总结

模型 Mask 类型 Mask 实现方式 适用场景
LLaMA Causal Mask masked_fill 自回归生成
GPT 系列 Causal Mask torch.triu() + masked_fill 自回归生成
BERT 系列 Padding Mask & Attention Mask masked_fill 理解任务

4. 训练与推理时的 Mask

4.1 训练时

在训练阶段,我们会为每个输入序列都生成相应的 Causal Mask。Mask 的形状取决于输入序列的长度,确保模型只能看到当前位置之前的输入。

4.2 推理时

在推理阶段(生成文本时),我们需要动态更新 Mask。每生成一个新词,我们都会追加到当前序列,并为新的序列生成相应的 Causal Mask。LLaMA 模型为了提升推理效率,做了很多优化,比如KV Cache,增量式的更新mask,加速推理。

5. 总结

本文深入分析了 LLaMA 模型中 Masked Attention 的实现逻辑,并对比了其他类型大模型的 Masked Attention 方案。通过了解 Mask 的原理和具体实现,我们能更好地理解自回归模型的工作方式。希望本文能帮助你更好地理解大模型中的注意力机制!

6. 参考资料

相关推荐
忆~遂愿1 小时前
3大关键点教你用Java和Spring Boot快速构建微服务架构:从零开发到高效服务注册与发现的逆袭之路
java·人工智能·spring boot·深度学习·机器学习·spring cloud·eureka
纠结哥_Shrek1 小时前
pytorch逻辑回归实现垃圾邮件检测
人工智能·pytorch·逻辑回归
辞落山1 小时前
自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测
人工智能·pytorch·逻辑回归
eybk2 小时前
Qpython+Flask监控添加发送语音中文信息功能
后端·python·flask
天宇琪云2 小时前
关于opencv环境搭建问题:由于找不到opencv_worldXXX.dll,无法执行代码,重新安装程序可能会解决此问题
人工智能·opencv·计算机视觉
大模型之路2 小时前
大模型(LLM)工程师实战之路(含学习路线图、书籍、课程等免费资料推荐)
人工智能·大模型·llm
weixin_307779133 小时前
Spark Streaming的背压机制的原理与实现代码及分析
大数据·python·spark
deephub3 小时前
十大主流联邦学习框架:技术特性、架构分析与对比研究
人工智能·python·深度学习·机器学习·联邦学习
英国翰思教育4 小时前
留学毕业论文如何利用不同问题设计问卷
人工智能·深度学习·学习·算法·学习方法·论文笔记
西猫雷婶4 小时前
python学opencv|读取图像(四十七)使用cv2.bitwise_not()函数实现图像按位取反运算
开发语言·python·opencv