在大型语言模型(LLM)中,注意力机制(Attention Mechanism)是核心组成部分。然而,在自回归(autoregressive)模型中,例如 LLaMA,我们需要对注意力进行屏蔽(Masking),以防止模型"偷看"未来的信息。本文将深入探讨 LLaMA 模型中 Masked Attention 的实现逻辑,并对比其他类型大模型中常用的 Masked Attention 方案。
1. 什么是 Masked Attention
1.1 为什么需要 Mask
在自回归模型中,模型的目标是根据已有的输入序列预测下一个词。在训练阶段,模型会接收整个输入序列,但在预测某个位置的词时,模型不应该看到该位置之后的信息。这就是 Masked Attention 的作用:它会屏蔽未来词对当前词的影响,确保模型只能依赖于过去的信息进行预测。
1.2 Mask 的类型
Mask 主要分为两种类型:
- Padding Mask: 用于处理变长序列,屏蔽 padding 部分对注意力计算的影响。
- Causal Mask: 用于自回归模型,屏蔽未来位置的信息,防止模型偷看未来。
2. LLaMA 中的 Masked Attention
LLaMA 模型主要关注自回归的生成任务,所以使用的是 Causal Mask。
2.1 LLaMA 的实现逻辑
LLaMA 使用标准的多头自注意力机制(Multi-Head Self-Attention, MHA),并在计算注意力权重时应用 Causal Mask。具体流程如下:
- 线性变换: 将输入序列映射为查询(Query)、键(Key)和值(Value)向量。
- 计算注意力分数: 计算查询向量和键向量的点积,并进行缩放。
- 应用 Mask: 使用 Causal Mask 屏蔽未来位置的注意力分数。
- 计算注意力权重: 对屏蔽后的注意力分数进行 Softmax 归一化。
- 计算加权值向量: 使用注意力权重对值向量进行加权求和。
2.2 LLaMA 源码示例 (PyTorch)
以下是 LLaMA 模型中 Masked Attention 的核心代码(简化版):
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class LlamaAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
# 线性变换
self.Wq = nn.Linear(d_model, d_model)
self.Wk = nn.Linear(d_model, d_model)
self.Wv = nn.Linear(d_model, d_model)
self.Wo = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.size()
# 线性变换
q = self.Wq(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)
k = self.Wk(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = self.Wv(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 计算注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# 应用 Mask
if mask is not None:
scores = scores.masked_fill(mask==0, float('-inf'))
# 计算注意力权重
attn_weights = F.softmax(scores, dim=-1)
# 计算加权值向量
attn_output = torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
# 输出线性变换
output = self.Wo(attn_output)
return output
def generate_causal_mask(seq_len):
mask = torch.ones((seq_len, seq_len), dtype=torch.bool).triu(diagonal=1)
return mask.bool()
# 示例
import math
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2
attention_layer = LlamaAttention(d_model, num_heads)
input_tensor = torch.randn(batch_size, seq_len, d_model)
causal_mask = generate_causal_mask(seq_len)
output = attention_layer(input_tensor, mask=causal_mask)
print("Output Shape:", output.shape) # 输出: torch.Size([2, 10, 512])
代码解释:
LlamaAttention
类:- 初始化线性变换层
Wq
,Wk
,Wv
, 和Wo
。 - 在
forward
方法中,首先对输入进行线性变换,并进行多头分割。 - 计算注意力分数,将
q
和k
进行点积运算,并进行缩放。 - 应用 Causal Mask。
- 使用 Softmax 对分数进行归一化,得到注意力权重。
- 使用注意力权重对
v
进行加权求和。
- 初始化线性变换层
generate_causal_mask
函数:- 生成一个下三角矩阵,其中对角线以上的位置为
True
(即需要 Mask 的位置)。 - 将 Mask 返回为布尔类型,方便后续使用
masked_fill
函数进行填充。
- 生成一个下三角矩阵,其中对角线以上的位置为
- 示例:
- 使用随机的输入张量,构造 causal_mask。
- 调用注意力层,得到输出。
2.3 Mask 应用细节
在代码中,我们使用了 scores.masked_fill(mask == 0, float('-inf'))
来应用 Mask。masked_fill
函数会将 mask
中为 False
(也就是需要mask的位置) 的位置填充为 -inf
。在 Softmax 计算时, -inf
将会被转换为 0,从而有效地屏蔽了未来的信息。
3. 与其他大模型 Masked Attention 方案的对比
3.1 GPT 系列模型
GPT 系列模型也使用 Causal Mask,其实现方式与 LLaMA 类似。主要区别在于:
- 实现方式: GPT 系列模型通常使用
torch.triu()
函数来生成上三角 Mask,然后使用masked_fill
函数填充。 - 结构: GPT 模型主要使用单向 Transformer 结构,而 LLaMA 模型使用双向 Transformer 结构(encoder-decoder 结构)。
3.2 BERT 系列模型
BERT 系列模型主要用于理解任务,使用了双向注意力机制。BERT 使用两种 Mask:
- Padding Mask: 用于处理变长序列,屏蔽 padding 部分的影响。
- Attention Mask (随机mask): 在预训练阶段,随机 mask 输入序列中的一部分词,让模型预测被屏蔽的词。
3.3 对比总结
模型 | Mask 类型 | Mask 实现方式 | 适用场景 |
---|---|---|---|
LLaMA | Causal Mask | masked_fill |
自回归生成 |
GPT 系列 | Causal Mask | torch.triu() + masked_fill |
自回归生成 |
BERT 系列 | Padding Mask & Attention Mask | masked_fill |
理解任务 |
4. 训练与推理时的 Mask
4.1 训练时
在训练阶段,我们会为每个输入序列都生成相应的 Causal Mask。Mask 的形状取决于输入序列的长度,确保模型只能看到当前位置之前的输入。
4.2 推理时
在推理阶段(生成文本时),我们需要动态更新 Mask。每生成一个新词,我们都会追加到当前序列,并为新的序列生成相应的 Causal Mask。LLaMA 模型为了提升推理效率,做了很多优化,比如KV Cache,增量式的更新mask,加速推理。
5. 总结
本文深入分析了 LLaMA 模型中 Masked Attention 的实现逻辑,并对比了其他类型大模型的 Masked Attention 方案。通过了解 Mask 的原理和具体实现,我们能更好地理解自回归模型的工作方式。希望本文能帮助你更好地理解大模型中的注意力机制!