从零开始构建大型语言模型——实现注意力机制

本章内容:

  • 使用注意力机制的原因
  • 基本的自注意力框架,逐步深入到增强的自注意力机制
  • 允许LLMs逐个生成词元的因果注意力模块
  • 通过dropout随机屏蔽部分注意力权重以减少过拟合
  • 将多个因果注意力模块堆叠为多头注意力模块

到目前为止,你已经了解了如何通过将文本拆分为单词和子词词元来准备LLM的输入文本,并将其编码为向量表示(嵌入)。现在,我们将介绍LLM架构中的一个重要部分------注意力机制,如图3.1所示。我们将主要独立地研究注意力机制,并在机制层面深入探讨。然后,我们将编写围绕自注意力机制的LLM其他部分的代码,以观察其实际效果,并构建一个用于生成文本的模型。

我们将实现四种不同的注意力机制变体,如图3.2所示。这些不同的注意力变体是逐步构建的,目标是最终实现一个紧凑且高效的多头注意力机制实现,然后可以将其嵌入到我们将在下一章编写的LLM架构中。

处理长序列建模的问题

在深入探讨LLM核心的自注意力机制之前,让我们先考虑一下在没有注意力机制的传统架构中遇到的问题。假设我们想开发一个将文本从一种语言翻译为另一种语言的翻译模型。如图3.3所示,由于源语言和目标语言中的语法结构不同,我们无法简单地逐字逐句进行翻译。

为了解决这个问题,通常使用包含两个子模块的深度神经网络,即编码器和解码器。编码器的任务是首先读取并处理整个文本,而解码器随后生成翻译后的文本。

在Transformer模型出现之前,循环神经网络(RNN)是用于语言翻译的最流行的编码器-解码器架构。RNN是一种神经网络,它将前一步的输出作为当前步骤的输入,因此非常适合处理像文本这样的序列数据。如果你不熟悉RNN,也不必担心,理解它的详细工作原理并不是理解本讨论的必要条件;这里我们主要关注编码器-解码器结构的一般概念。

在编码器-解码器RNN中,输入文本按顺序输入编码器,编码器逐步处理输入。在每一步,编码器会更新其隐藏状态(隐藏层的内部值),尝试在最终的隐藏状态中捕捉输入句子的完整意义,如图3.4所示。然后,解码器使用这个最终的隐藏状态开始逐字生成翻译句子。解码器也会在每一步更新其隐藏状态,这个状态应携带下一步预测单词所需的上下文信息。

虽然我们不需要深入了解编码器-解码器RNN的内部工作原理,但关键的思想是,编码器部分将整个输入文本处理为一个隐藏状态(记忆单元),然后解码器使用这个隐藏状态生成输出。你可以将这个隐藏状态看作是一个嵌入向量,这个概念我们在第2章中讨论过。

编码器-解码器RNN的一个主要限制是,在解码阶段,RNN无法直接访问编码器中的早期隐藏状态。因此,它只能依赖当前的隐藏状态,这个状态包含所有相关的信息。这可能会导致上下文丢失,尤其是在复杂句子中,依赖关系可能跨越较长距离。

幸运的是,构建LLM并不需要深入理解RNN。只需记住,编码器-解码器RNN的这一缺点促使了注意力机制的设计。

通过注意力机制捕捉数据依赖关系

虽然RNN在翻译短句时表现良好,但在处理较长文本时效果不佳,因为它无法直接访问输入中的前面部分。这个方法的一个主要缺点是,RNN必须在一个隐藏状态中记住整个编码的输入,然后再将其传递给解码器(如图3.4所示)。

因此,研究人员在2014年为RNN开发了Bahdanau注意力机制(以论文的第一作者命名;详情见附录B),这种机制修改了编码器-解码器RNN,使得解码器在每个解码步骤都可以有选择地访问输入序列的不同部分,如图3.5所示。

有趣的是,仅仅三年后,研究人员发现构建自然语言处理的深度神经网络并不需要RNN架构,并提出了最初的Transformer架构(第1章讨论过),其中包括受Bahdanau注意力机制启发的自注意力机制。

自注意力是一种机制,它允许输入序列中的每个位置在计算序列表示时考虑所有其他位置的相关性,或"关注"同一序列中的所有其他位置。自注意力是基于Transformer架构的现代LLM(如GPT系列)的关键组件。

本章将重点介绍如何编写和理解GPT类模型中使用的自注意力机制,如图3.6所示。在下一章中,我们将编写LLM的其他部分代码。

通过自注意力机制关注输入的不同部分

现在我们将介绍自注意力机制的内部工作原理,并从头开始学习如何编写相关代码。自注意力是基于Transformer架构的每个LLM的核心基础。这个主题可能需要高度集中注意力(双关意图),但一旦掌握了它的基础,你将攻克本书和LLM实现中最具挑战性的部分之一。

自注意力中的"自我"

在自注意力中,"自我"指的是该机制能够通过关联单个输入序列中的不同位置来计算注意力权重。它评估并学习输入本身各个部分之间的关系和依赖性,例如句子中的单词或图像中的像素。

这与传统的注意力机制不同,传统注意力机制的重点在于两个不同序列的元素之间的关系,例如在序列到序列模型中,注意力可能存在于输入序列和输出序列之间,如图3.5所示的例子。

由于自注意力机制可能看起来复杂,尤其是当你第一次接触它时,我们将从一个简化版本开始进行讲解。接着,我们将实现LLM中使用的带有可训练权重的自注意力机制。

没有可训练权重的简单自注意力机制

我们首先来实现一个没有任何可训练权重的简化自注意力机制,如图3.7所示。这样做的目的是在引入可训练权重之前,先阐明自注意力机制中的一些关键概念。

图3.7显示了一个输入序列,记作x,由T个元素组成,表示为x(1)到x(T)。这个序列通常表示已转换为词元嵌入的文本,例如一个句子。

例如,考虑输入文本"Your journey starts with one step." 在这种情况下,序列中的每个元素(如x(1))对应一个d维嵌入向量,表示一个特定的词元,比如"Your"。图3.7显示了这些输入向量作为三维嵌入。

在自注意力机制中,我们的目标是为输入序列中的每个元素x(i)计算上下文向量z(i)。上下文向量可以解释为一个增强的嵌入向量。

为了说明这个概念,我们重点关注第二个输入元素x(2)(对应词元"journey")的嵌入向量及其对应的上下文向量z(2),如图3.7底部所示。这个增强的上下文向量z(2)是一个包含x(2)及所有其他输入元素x(1)到x(T)信息的嵌入。

上下文向量在自注意力机制中起着至关重要的作用。它们的目的是通过融合输入序列(如一个句子)中所有其他元素的信息,来创建每个元素的增强表示(图3.7)。这对LLMs至关重要,因为它们需要理解句子中词与词之间的关系和相关性。稍后,我们将添加可训练的权重,帮助LLM学习构建这些上下文向量,使它们与生成下一个词元相关。但首先,让我们实现一个简化的自注意力机制,一步步计算这些权重和相应的上下文向量。

考虑以下句子,它已经嵌入为三维向量(见第2章)。我选择了较小的嵌入维度以确保它可以在页面上展示而不换行:

ini 复制代码
import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

实现自注意力机制的第一步是计算中间值w,即所谓的注意力分数,如图3.8所示。由于空间限制,图中显示了截断版的输入张量值;例如,0.87被截断为0.8。在这个截断版本中,"journey"和"starts"的嵌入可能由于随机原因看起来相似。

图3.8展示了如何计算查询词元与每个输入词元之间的中间注意力分数。我们通过计算查询词元 x(2)x(2)x(2) 与每个其他输入词元的点积来确定这些分数:

ini 复制代码
query = inputs[1]                            #1
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)
  • #1 第二个输入词元作为查询词元。

计算出的注意力分数为:

scss 复制代码
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

理解点积

点积本质上是一种逐元素相乘并求和的简便方法,示例如下:

ini 复制代码
res = 0.
for idx, element in enumerate(inputs[0]):
    res += inputs[0][idx] * query[idx]
print(res)
print(torch.dot(inputs[0], query))

输出结果确认了逐元素相乘的和与点积相同:

scss 复制代码
tensor(0.9544)
tensor(0.9544)

除了将点积视为将两个向量组合为一个标量值的数学工具之外,点积也是衡量相似性的一种方式,因为它量化了两个向量之间的对齐程度:点积越大,表示向量之间的对齐或相似度越高。在自注意力机制的上下文中,点积决定了序列中的每个元素在多大程度上"关注"其他元素:点积越大,表示两个元素之间的相似性和注意力分数越高。

下一步,如图3.9所示,我们对之前计算的每个注意力分数进行归一化处理。归一化的主要目的是使注意力权重的总和为1。这个归一化步骤是一种有助于解释和保持LLM训练稳定性的惯例。以下是实现该归一化步骤的简单方法:

python 复制代码
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

输出结果为:

css 复制代码
Attention weights: tensor([0.1444, 0.2261, 0.2232, 0.1276, 0.1069, 0.1718])
Sum: tensor(1.)

输出显示,注意力权重的总和现在为1:

css 复制代码
Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)

在实际操作中,使用 softmax 函数进行归一化更为常见和可取。这种方法能够更好地处理极端值,并在训练过程中提供更有利的梯度特性。以下是一个简单的 softmax 函数实现,用于归一化注意力分数:

scss 复制代码
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

输出显示,softmax 函数也实现了目标,并使注意力权重的总和为1:

css 复制代码
Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)

此外,softmax 函数确保注意力权重始终为正数。这使得输出可以被解释为概率或相对重要性,权重越高表示重要性越大。

注意,这个简单的 softmax 实现(softmax_naive)可能在处理极大或极小的输入值时遇到数值不稳定性问题(如溢出和下溢)。因此,实际操作中建议使用PyTorch的 softmax 实现,该实现经过广泛优化以提高性能:

python 复制代码
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

这次结果与我们之前的 softmax_naive 函数相同:

css 复制代码
Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)

现在我们已经计算了归一化的注意力权重,接下来我们进行最后一步,如图3.10所示:通过将嵌入的输入词元 x(i)x(i)x(i) 与相应的注意力权重相乘并对得到的向量求和来计算上下文向量 z(2)z(2)z(2)。因此,上下文向量 z(2)z(2)z(2) 是所有输入向量的加权和,即通过将每个输入向量乘以其对应的注意力权重得到:

ini 复制代码
query = inputs[1]         #1
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i
print(context_vec_2)
  • #1 第二个输入词元作为查询词元。

该计算的结果为:

scss 复制代码
tensor([0.4419, 0.6515, 0.5683])

接下来,我们将对计算上下文向量的过程进行泛化,以便同时计算所有的上下文向量。

为所有输入词元计算注意力权重

到目前为止,我们已经为输入2计算了注意力权重和上下文向量,如图3.11中高亮的行所示。现在,让我们扩展这个计算,计算所有输入的注意力权重和上下文向量。

我们将遵循之前的三个步骤(参见图3.12),只是对代码进行了一些修改,以计算所有的上下文向量,而不仅仅是第二个 z(2)z(2)z(2):

scss 复制代码
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

这个代码计算了每个输入之间的注意力分数,形成了一个 6×6 的矩阵,其中每个元素表示输入序列中两个词元之间的相似性。

得到的注意力分数如下:

css 复制代码
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

张量中的每个元素表示每对输入之间的注意力分数,如图3.11所示。注意该图中的值是经过归一化处理的,因此与上面的未归一化注意力分数不同。我们稍后会处理归一化。

在计算前面的注意力分数张量时,我们使用了Python中的for循环。然而,for循环通常比较慢,我们可以使用矩阵乘法来达到相同的效果:

ini 复制代码
attn_scores = inputs @ inputs.T
print(attn_scores)

我们可以直观地确认结果与之前相同:

css 复制代码
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 06654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

在图3.12的第二步中,我们对每一行进行归一化,使每行的值总和为1:

scss 复制代码
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

这将返回如下的注意力权重张量,与图3.10中的值相匹配:

css 复制代码
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

在使用PyTorch时,像torch.softmax这样的函数中的dim参数指定了函数将在输入张量的哪个维度上计算。通过将dim=-1设置为最后一个维度,我们让softmax函数沿着attn_scores张量的最后一维进行归一化。如果attn_scores是一个二维张量(例如形状为 [行, 列]),它将沿着列进行归一化,使每行的值(列维度上的求和)总和为1。

我们可以验证每行的总和确实为1:

python 复制代码
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 sum:", row_2_sum)
print("All row sums:", attn_weights.sum(dim=-1))

结果是:

sql 复制代码
Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

在图3.12的第三步也是最后一步中,我们使用这些注意力权重通过矩阵乘法计算所有上下文向量:

ini 复制代码
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

在输出的张量中,每行包含一个三维上下文向量:

css 复制代码
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

我们可以通过比较第二行与我们在第3.3.1节中计算的上下文向量 z(2)z(2)z(2) 来再次验证代码的正确性:

bash 复制代码
print("Previous 2nd context vector:", context_vec_2)

结果显示,之前计算的context_vec_2与上面张量的第二行完全匹配:

arduino 复制代码
Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])

至此,我们完成了简单自注意力机制的代码实现。接下来,我们将添加可训练权重,使LLM能够从数据中学习,并在特定任务上提高性能。

实现具有可训练权重的自注意力机制

我们的下一步是实现原始Transformer架构、GPT模型以及大多数其他流行LLM中使用的自注意力机制。这种自注意力机制也被称为缩放点积注意力。图3.13展示了这种自注意力机制如何嵌入到实现LLM的更广泛背景中。

如图3.13所示,带有可训练权重的自注意力机制基于之前的概念:我们希望计算上下文向量,作为针对特定输入元素的输入向量的加权和。你将会发现,这与我们之前编写的基本自注意力机制相比,只有细微的区别。

最显著的区别是引入了在模型训练过程中更新的权重矩阵。这些可训练的权重矩阵至关重要,因为模型(特别是模型中的注意力模块)需要通过它们学习生成"良好"的上下文向量。(我们将在第5章训练LLM。)

我们将在两个小节中处理这个自注意力机制。首先,我们将像之前一样逐步编写代码。其次,我们将把代码组织成一个紧凑的Python类,便于导入LLM架构中。

逐步计算注意力权重

我们将逐步实现自注意力机制,并引入三个可训练的权重矩阵 WqW_qWq​、WkW_kWk​ 和 WvW_vWv​。这三个矩阵分别用于将嵌入的输入词元 x(i)x(i)x(i) 映射到查询向量(query)、键向量(key)和值向量(value),如图3.14所示。

之前,我们在计算简化的注意力权重以得出上下文向量 z(2)z(2)z(2) 时,将第二个输入元素 x(2)x(2)x(2) 作为查询向量。随后,我们将其推广以计算六个单词组成的输入句子 "Your journey starts with one step" 的所有上下文向量 z(1)...z(T)z(1) \dots z(T)z(1)...z(T)。

同样地,为了便于说明,我们从计算一个上下文向量 z(2)z(2)z(2) 开始。之后,我们将修改代码以计算所有上下文向量。

首先,定义几个变量:

ini 复制代码
x_2 = inputs[1]     #1
d_in = inputs.shape[1]      #2
d_out = 2         #3
  • #1 第二个输入元素
  • #2 输入嵌入维度,d=3d = 3d=3
  • #3 输出嵌入维度,dout=2d_{out} = 2dout=2

请注意,在类似GPT的模型中,输入和输出的维度通常是相同的,但为了更好地理解计算过程,这里我们使用不同的输入维度 din=3d_{in} = 3din​=3 和输出维度 dout=2d_{out} = 2dout​=2。

接下来,我们初始化三个权重矩阵 WqW_qWq​、WkW_kWk​ 和 WvW_vWv​,如图3.14所示:

ini 复制代码
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

我们将 requires_grad=False 设置为关闭梯度计算以简化输出,但如果我们在模型训练中使用这些权重矩阵,我们会将 requires_grad=True 以便在训练过程中更新这些矩阵。

接下来,我们计算查询向量、键向量和值向量:

ini 复制代码
query_2 = x_2 @ W_query 
key_2 = x_2 @ W_key 
value_2 = x_2 @ W_value
print(query_2)

查询向量的输出结果是一个二维向量,因为我们通过 doutd_{out}dout​ 将相应的权重矩阵的列数设置为2:

scss 复制代码
tensor([0.4306, 1.4551])

权重参数与注意力权重

在权重矩阵 WWW 中,"权重"是"权重参数"的缩写,指的是在训练过程中优化的神经网络值。这与注意力权重不同。正如我们之前看到的,注意力权重决定了上下文向量在多大程度上依赖于输入的不同部分(即网络在多大程度上关注输入的不同部分)。

总之,权重参数是定义网络连接的基本学习系数,而注意力权重则是动态的、特定于上下文的值。

尽管我们当前的目标只是计算一个上下文向量 z(2)z(2)z(2),但我们仍然需要所有输入元素的键向量和值向量,因为它们参与计算与查询 q(2)q(2)q(2) 相关的注意力权重(见图3.14)。

我们可以通过矩阵乘法获得所有键和值:

perl 复制代码
keys = inputs @ W_key 
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

从输出中可以看出,我们成功地将六个输入词元从三维投影到了二维嵌入空间:

css 复制代码
keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])

第二步是计算注意力分数,如图3.15所示。

首先,让我们计算注意力分数 ω22\omega_{22}ω22​:

ini 复制代码
keys_2 = keys[1]             #1
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)
  • #1 请记住,Python的索引从0开始。

未归一化的注意力分数结果为:

scss 复制代码
tensor(1.8524)

同样,我们可以通过矩阵乘法将此计算推广到所有注意力分数:

ini 复制代码
attn_scores_2 = query_2 @ keys.T       #1
print(attn_scores_2)
  • #1 针对给定查询的所有注意力分数。

可以看到,作为快速检查,输出中的第二个元素与我们之前计算的 attn_score_22\text{attn_score_22}attn_score_22 匹配:

scss 复制代码
tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

现在,我们希望将注意力分数转换为注意力权重,如图3.16所示。我们通过缩放注意力分数并使用 softmax 函数来计算注意力权重。然而,这次我们通过将注意力分数除以键的嵌入维度的平方根进行缩放(取平方根在数学上与指数为0.5相同):

ini 复制代码
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

得到的注意力权重为:

scss 复制代码
tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])

缩放点积注意力的原理

通过嵌入维度大小进行归一化的原因是为了提高训练性能,避免小梯度的出现。例如,在GPT类的LLM中,嵌入维度通常大于1000,缩放后的大点积在反向传播过程中可能导致非常小的梯度,这是由于对它们应用的softmax函数所致。随着点积的增加,softmax函数的表现越来越像一个阶跃函数,导致梯度接近于零。这些小梯度可能会极大地减缓学习速度,或导致训练停滞。

通过嵌入维度的平方根进行缩放是自注意力机制被称为缩放点积注意力的原因。

现在,最后一步是计算上下文向量,如图3.17所示。

类似于我们在计算上下文向量时对输入向量进行加权求和(见第3.3节),我们现在通过对值向量进行加权求和来计算上下文向量。在这里,注意力权重作为加权因子,用于衡量每个值向量的相对重要性。和之前一样,我们可以使用矩阵乘法一步获得输出:

ini 复制代码
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

得到的向量内容如下:

scss 复制代码
tensor([0.3061, 0.8210])

到目前为止,我们只计算了一个上下文向量 z(2)。接下来,我们将对代码进行泛化,以计算输入序列中的所有上下文向量 z(1) 到 z(T)。

为什么使用查询、键和值?

在注意力机制的上下文中,"键"、"查询"和"值"这三个术语借鉴自信息检索和数据库领域,其中类似的概念用于存储、搜索和检索信息。

查询类似于数据库中的搜索查询。它代表模型当前关注或试图理解的项目(例如句子中的一个单词或词元)。查询用于探测输入序列的其他部分,以确定对它们应该给予多少注意。

键就像数据库中用于索引和搜索的键。在注意力机制中,输入序列中的每个项目(例如句子中的每个单词)都有一个关联的键。这些键用于与查询进行匹配。

在这种情况下,值类似于数据库中键值对中的值。它代表输入项目的实际内容或表示。一旦模型确定哪些键(因此哪些输入部分)与查询(当前关注的项目)最相关,它就会检索相应的值。

实现紧凑的自注意力Python类

到目前为止,我们经历了许多步骤来计算自注意力输出。这主要是为了说明目的,使我们能够逐步进行。在实际应用中,考虑到下一章中的LLM实现,将这些代码组织成一个Python类是很有帮助的,如下所示。

代码清单3.1:紧凑的自注意力类

ini 复制代码
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

在这段PyTorch代码中,SelfAttention_v1 是一个派生自 nn.Module 的类,nn.Module 是PyTorch模型的基本构建块,提供了模型层创建和管理所需的功能。

__init__ 方法初始化了查询、键和值的可训练权重矩阵(W_queryW_keyW_value),每个矩阵将输入维度 d_in 转换为输出维度 d_out

在前向传播过程中,使用 forward 方法,我们通过将查询与键相乘来计算注意力分数(attn_scores),然后使用 softmax 对这些分数进行归一化。最后,通过将这些归一化的注意力分数与值加权,创建上下文向量。

我们可以如下使用这个类:

scss 复制代码
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

由于 inputs 包含六个嵌入向量,因此这将产生一个存储六个上下文向量的矩阵:

ini 复制代码
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

作为快速检查,请注意第二行([0.3061, 0.8210])与上一节中的 context_vec_2 的内容相匹配。图3.18总结了我们刚刚实现的自注意力机制。

自注意力机制涉及可训练的权重矩阵 WqW_qWq​、WkW_kWk​ 和 WvW_vWv​。这些矩阵分别将输入数据转换为查询、键和值,这些都是注意力机制的重要组成部分。随着模型在训练过程中接触到更多数据,它会调整这些可训练的权重,正如我们将在后面的章节中看到的。

我们可以进一步改进 SelfAttention_v1 的实现,利用 PyTorch 的 nn.Linear 层,这样在禁用偏置单元时可以有效地执行矩阵乘法。此外,使用 nn.Linear 而不是手动实现 nn.Parameter(torch.rand(...)) 的一个显著优势是,nn.Linear 具有优化的权重初始化方案,有助于更稳定和有效的模型训练。

代码清单3.2:使用PyTorch的线性层的自注意力类

ini 复制代码
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

你可以像使用 SelfAttention_v1 一样使用 SelfAttention_v2

scss 复制代码
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

输出为:

ini 复制代码
tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

请注意,SelfAttention_v1SelfAttention_v2 的输出不同,因为它们使用了不同的权重矩阵初始值,nn.Linear 使用了更复杂的权重初始化方案。

练习 3.1 比较 SelfAttention_v1 和 SelfAttention_v2

请注意,SelfAttention_v2 中的 nn.Linear 使用了不同的权重初始化方案,而 SelfAttention_v1 中使用的是 nn.Parameter(torch.rand(d_in, d_out)),这导致两个机制产生不同的结果。为了检查这两个实现(SelfAttention_v1SelfAttention_v2)在其他方面是否相似,我们可以将 SelfAttention_v2 对象的权重矩阵转移到 SelfAttention_v1,使得两个对象都产生相同的结果。

你的任务是正确地将 SelfAttention_v2 实例的权重分配给 SelfAttention_v1 实例。为此,你需要理解这两个版本中权重之间的关系。(提示:nn.Linear 以转置形式存储权重矩阵。)在分配之后,你应该观察到两个实例产生相同的输出。

接下来,我们将对自注意力机制进行增强,重点是引入因果和多头元素。因果方面涉及修改注意力机制,以防止模型访问序列中的未来信息,这对于语言建模等任务至关重要,因为每个单词的预测应仅依赖于之前的单词。

多头组件则涉及将注意力机制拆分为多个"头"。每个头学习数据的不同方面,使得模型能够同时关注来自不同表示子空间的不同位置的信息。这提高了模型在复杂任务中的性能。

用因果注意力隐藏未来单词

对于许多LLM任务,你希望自注意力机制在预测序列中的下一个单词时,仅考虑当前位之前的词元。因果注意力,也称为屏蔽注意力,是自注意力的一种特殊形式。它限制模型在计算注意力分数时,只考虑序列中当前和之前的输入。这与标准自注意力机制不同,后者允许一次访问整个输入序列。

现在,我们将修改标准自注意力机制以创建因果注意力机制,这对于在后续章节中开发LLM至关重要。为了在类似GPT的LLM中实现这一点,对于每个处理的词元,我们会屏蔽当前词元之后的未来词元,如图3.19所示。我们屏蔽对角线以上的注意力权重,并对未屏蔽的注意力权重进行归一化,以确保每一行的注意力权重之和为1。稍后,我们将以代码的形式实现这一屏蔽和归一化过程。

应用因果注意力掩码

我们的下一步是用代码实现因果注意力掩码。为了实现应用因果注意力掩码以获取掩蔽的注意力权重的步骤,如图3.20所示,让我们使用上一节中的注意力分数和权重来编写因果注意力机制的代码。

在第一步中,我们使用 softmax 函数计算注意力权重,正如之前所做的:

ini 复制代码
queries = sa_v2.W_query(inputs)     #1
keys = sa_v2.W_key(inputs) 
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

这将产生以下注意力权重:

ini 复制代码
tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

我们可以使用 PyTorch 的 tril 函数来创建一个掩码,其中对角线以上的值为零:

ini 复制代码
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

得到的掩码为:

css 复制代码
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

现在,我们可以将这个掩码与注意力权重相乘,以将对角线以上的值置为零:

ini 复制代码
masked_simple = attn_weights * mask_simple
print(masked_simple)

可以看到,对角线以上的元素成功被置为零:

ini 复制代码
tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

第三步是重新归一化注意力权重,使每一行的和为1。我们可以通过将每一行的每个元素除以该行的总和来实现:

ini 复制代码
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

结果是一个注意力权重矩阵,其中对角线以上的注意力权重被置为零,且每行的和为1:

ini 复制代码
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)

信息泄漏

当我们应用掩码并重新归一化注意力权重时,初看起来似乎未来的词元信息(我们打算屏蔽的部分)仍然可能影响当前词元,因为它们的值是软max计算的一部分。然而,关键的见解是,当我们在掩蔽后重新归一化注意力权重时,我们实际上是在对一个更小的子集重新计算软max(因为被屏蔽的位置不对软max值产生贡献)。

软max的数学优雅在于,尽管最初包括所有位置在分母中,但在掩蔽和重新归一化后,屏蔽位置的影响被消除了------它们在任何意义上都不对软max得分产生贡献。

简单来说,在掩蔽和重新归一化之后,注意力权重的分布就像是最初仅在未屏蔽的位置之间计算的一样。这确保了没有来自未来(或其他被屏蔽)词元的信息泄漏,正如我们所期望的那样。

虽然我们此时可以结束因果注意力的实现,但我们仍然可以进行改进。让我们利用软max函数的一个数学特性,更高效地实现掩蔽注意力权重的计算,以更少的步骤完成,如图3.21所示。

软max函数将其输入转换为概率分布。当某一行中存在负无穷值(-∞)时,软max函数将其视为零概率。(从数学上讲,这是因为 e^(-∞) 接近于 0。)

我们可以通过创建一个对角线以上为 1 的掩码,并将这些 1 替换为负无穷值(-inf),以更高效地实现这一掩蔽"技巧":

ini 复制代码
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

这将产生如下掩码:

ini 复制代码
tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)

现在,我们只需要对这些掩蔽后的结果应用软max函数,就完成了:

scss 复制代码
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

从输出中可以看到,每一行的值都加和为 1,因此不需要进一步的归一化:

ini 复制代码
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

我们现在可以使用修改后的注意力权重通过 context_vec = attn_weights @ values 来计算上下文向量,正如在第 3.4 节中所述。然而,我们首先将讨论另一个对因果注意力机制的小调整,这在训练 LLM 时有助于减少过拟合。

使用 dropout 掩蔽额外的注意力权重

在深度学习中,dropout 是一种技术,它在训练过程中随机忽略被选择的隐藏层单元,有效地"丢弃"它们。这种方法有助于防止过拟合,确保模型不会过于依赖任何特定的隐藏层单元。需要强调的是,dropout 仅在训练期间使用,之后会被禁用。

在 transformer 架构中,包括像 GPT 这样的模型,dropout 通常在两个特定的时机应用:在计算注意力权重后或在将注意力权重应用于值向量后。这里我们将在计算注意力权重之后应用 dropout 掩码,如图 3.22 所示,因为这种变体在实践中更为常见。

在以下代码示例中,我们使用 50% 的 dropout 率,这意味着会掩蔽一半的注意力权重。(当我们在后面的章节中训练 GPT 模型时,将使用较低的 dropout 率,例如 0.1 或 0.2。)我们首先对一个由 1 组成的 6 × 6 张量应用 PyTorch 的 dropout 实现,以简单示例为主:

ini 复制代码
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)    #1
example = torch.ones(6, 6)      #2
print(dropout(example))
#1 我们选择 50% 的 dropout 率。
#2 这里创建一个由 1 组成的矩阵。

如我们所见,约一半的值被置为零:

css 复制代码
tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

当将 dropout 应用于一个注意力权重矩阵时,50% 的速率意味着矩阵中的一半元素会被随机设为零。为了补偿有效元素的减少,矩阵中剩余元素的值会按比例放大 1/0.5 = 2。这种缩放对于保持注意力权重的整体平衡至关重要,确保在训练和推理阶段注意力机制的平均影响保持一致。

现在,让我们将 dropout 应用到注意力权重矩阵本身:

scss 复制代码
torch.manual_seed(123)
print(dropout(attn_weights))

结果的注意力权重矩阵现在有额外的元素被置为零,剩余的 1 被重新缩放:

ini 复制代码
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)

请注意,结果的 dropout 输出可能因操作系统而异;您可以在 PyTorch 问题追踪器 上阅读更多关于此不一致的信息。

在了解了因果注意力和 dropout 掩蔽之后,我们现在可以开发一个简洁的 Python 类。该类旨在有效应用这两种技术。

实现紧凑的因果注意力类

我们将把因果注意力和 dropout 的修改合并到我们在第 3.4 节中开发的 SelfAttention Python 类中。这个类将作为开发多头注意力的模板,这是我们将实现的最终注意力类。

但在开始之前,让我们确保代码能够处理由多个输入组成的批次,以便 CausalAttention 类能够支持我们在第 2 章中实现的数据加载器生成的批量输出。

为简单起见,为了模拟这样的批量输入,我们重复输入文本示例:

ini 复制代码
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)                #1
#1 两个输入,每个输入有六个 token;每个 token 的嵌入维度为 3。

这将生成一个三维张量,包含两个输入文本,每个文本有六个 token,每个 token 是一个三维嵌入向量:

css 复制代码
torch.Size([2, 6, 3])

下面的 CausalAttention 类与我们之前实现的 SelfAttention 类类似,不同之处在于我们添加了 dropout 和因果掩蔽组件。

Listing 3.3 紧凑的因果注意力类

ini 复制代码
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)            #1
        self.register_buffer(
           'mask',
           torch.triu(torch.ones(context_length, context_length),
           diagonal=1)
        )             #2

    def forward(self, x):
        b, num_tokens, d_in = x.shape                   #3
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)   
        attn_scores.masked_fill_(                    #4
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) 
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

在这个 PyTorch 代码中,CausalAttention 是一个派生自 nn.Module 的类,这是构建 PyTorch 模型的基本构建块,提供了创建和管理模型层所需的功能。

__init__ 方法初始化了查询、键和值的可训练权重矩阵(W_query、W_key 和 W_value),每个矩阵将输入维度 d_in 转换为输出维度 d_out。

在前向传递中,使用 forward 方法,我们通过计算查询和键的乘积来计算注意力分数(attn_scores),然后使用 softmax 进行归一化。最后,通过将值与这些归一化的注意力权重加权,我们创建了上下文向量。

我们可以像之前的 SelfAttention 一样使用这个类:

scss 复制代码
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

结果的上下文向量是一个三维张量,其中每个 token 现在由一个二维嵌入表示:

css 复制代码
context_vecs.shape: torch.Size([2, 6, 2])

图 3.23 总结了到目前为止我们所取得的成就。我们专注于神经网络中因果注意力的概念和实现。接下来,我们将扩展这个概念并实现一个多头注意力模块,该模块可以并行实现多个因果注意力机制。

将单头注意力扩展到多头注意力

我们最后一步将是将之前实现的 causal attention 类扩展为多个头部。这也称为多头注意力。

"多头"一词指的是将注意力机制划分为多个"头",每个头独立操作。在这个上下文中,单个 causal attention 模块可以视为单头注意力,其中只有一组注意力权重在顺序处理中输入。

我们将从 causal attention 扩展到多头注意力。首先,我们将通过堆叠多个 CausalAttention 模块直观地构建一个多头注意力模块。然后,我们将以更复杂但计算效率更高的方式实现相同的多头注意力模块。

堆叠多个单头注意力层

在实际操作中,实现多头注意力涉及创建多个自注意力机制的实例(见图3.18),每个实例都有其自己的权重,然后将它们的输出组合在一起。使用多个自注意力机制的实例可能会消耗大量计算资源,但对于变压器模型(如基于变压器的 LLM)所知的复杂模式识别至关重要。

图3.24 说明了多头注意力模块的结构,该模块由多个单头注意力模块组成,如图3.18 所示,这些模块一个叠一个地堆叠在一起。

正如之前提到的,多头注意力的主要思想是使用不同的、学习到的线性变换多次(并行)运行注意力机制------这些变换是将输入数据(如注意力机制中的查询、键和值向量)乘以权重矩阵的结果。在代码中,我们可以通过实现一个简单的 MultiHeadAttentionWrapper 类来实现这一点,该类堆叠多个之前实现的 CausalAttention 模块的实例。

列表 3.4 多头注意力的包装类实现

ruby 复制代码
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(
                 d_in, d_out, context_length, dropout, qkv_bias
             ) 
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

例如,如果我们使用这个 MultiHeadAttentionWrapper 类,并设置两个注意力头(通过 num_heads=2)以及 CausalAttention 输出维度 d_out=2,我们将得到一个四维的上下文向量(d_out*num_heads=4),如图3.25所示。

为了进一步说明这个概念,我们可以像之前的 CausalAttention 类一样使用 MultiHeadAttentionWrapper 类:

ini 复制代码
torch.manual_seed(123)
context_length = batch.shape[1]  # 这是令牌的数量
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

这将得到以下张量,表示上下文向量:

ini 复制代码
tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)

结果的 context_vecs 张量的第一维是 2,因为我们有两个输入文本(输入文本被重复,这就是为什么它们的上下文向量完全相同)。第二维指的是每个输入中的 6 个令牌。第三维指的是每个令牌的四维嵌入。

练习 3.2 返回二维嵌入向量

更改 MultiHeadAttentionWrapper(..., num_heads=2) 调用的输入参数,使输出的上下文向量为二维而不是四维,同时保持 num_heads=2 的设置。提示:你不需要修改类的实现,只需更改其他一个输入参数即可。

到目前为止,我们已经实现了一个将多个单头注意力模块结合在一起的 MultiHeadAttentionWrapper。然而,这些模块在 forward 方法中是通过 [head(x) for head in self.heads] 顺序处理的。我们可以通过并行处理这些头来改进这个实现。一种实现方式是通过矩阵乘法同时计算所有注意力头的输出。

实现带权重拆分的多头注意力

到目前为止,我们已经创建了一个 MultiHeadAttentionWrapper 来通过堆叠多个单头注意力模块来实现多头注意力。这是通过实例化和组合多个 CausalAttention 对象完成的。

我们可以将这两个类(MultiHeadAttentionWrapperCausalAttention)的概念合并为一个单一的 MultiHeadAttention 类。此外,除了将 MultiHeadAttentionWrapperCausalAttention 代码合并外,我们还将进行一些其他修改,以更高效地实现多头注意力。

MultiHeadAttentionWrapper 中,多头通过创建一个 CausalAttention 对象的列表(self.heads)来实现,每个对象代表一个单独的注意力头。CausalAttention 类独立地执行注意力机制,来自每个头的结果被串联起来。相比之下,下面的 MultiHeadAttention 类将多头功能集成在一个类中。它通过重塑投影的查询、键和值张量来将输入拆分为多个头,然后在计算注意力后组合这些头的结果。

让我们先看看 MultiHeadAttention 类,然后再进一步讨论。

Listing 3.5 高效的多头注意力类

ini 复制代码
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads    #1
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)    #2
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)         #3
        queries = self.W_query(x)    #3
        values = self.W_value(x)     #3

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)       #4
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)  
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)                   

        keys = keys.transpose(1, 2)          #5
        queries = queries.transpose(1, 2)    #5
        values = values.transpose(1, 2)      #5

        attn_scores = queries @ keys.transpose(2, 3)   #6
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]    #7

        attn_scores.masked_fill_(mask_bool, -torch.inf)     #8

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)   #9
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)  #10
        context_vec = self.out_proj(context_vec)    #11
        return context_vec

#1 将投影维度减少到所需的输出维度 #2 使用线性层组合头输出 #3 张量形状:(b, num_tokens, d_out) #4 通过添加 num_heads 维度隐式拆分矩阵。然后展开最后一维:(b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)。 #5 从形状 (b, num_tokens, num_heads, head_dim) 转置为 (b, num_heads, num_tokens, head_dim) #6 为每个头计算点积 #7 使掩码截断到令牌数量 #8 使用掩码填充注意力得分 #9 张量形状:(b, num_tokens, n_heads, head_dim) #10 组合头,其中 self.d_out = self.num_heads * self.head_dim #11 添加可选的线性投影

尽管在 MultiHeadAttention 类中的张量重塑(.view)和转置(.transpose)看起来在数学上很复杂,但 MultiHeadAttention 类实现的概念与之前的 MultiHeadAttentionWrapper 相同。

从整体上看,在之前的 MultiHeadAttentionWrapper 中,我们堆叠了多个单头注意力层,并将它们组合成一个多头注意力层。MultiHeadAttention 类采用了一种集成的方法。它从一个多头层开始,然后在内部将该层拆分为单个注意力头,如图 3.26 所示。

通过使用 PyTorch 的 .view.transpose 方法,对张量进行重塑和转置操作,我们实现了对查询、键和值张量的拆分。输入首先经过转换(通过查询、键和值的线性层),然后被重塑以表示多个头。

关键操作是将 d_out 维度拆分为 num_headshead_dim,其中 head_dim = d_out / num_heads。然后使用 .view 方法实现这种拆分:一个维度为 (b, num_tokens, d_out) 的张量被重塑为维度 (b, num_tokens, num_heads, head_dim)

然后对张量进行转置,将 num_heads 维度放在 num_tokens 维度之前,结果形状为 (b, num_heads, num_tokens, head_dim)。这种转置对于正确对齐不同头部的查询、键和值,以及高效地执行批量矩阵乘法至关重要。

为了说明这种批量矩阵乘法,假设我们有以下张量:

css 复制代码
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],    #1
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])
  • #1 该张量的形状为 (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)

现在,我们对该张量自身与其转置了最后两个维度(num_tokenshead_dim)的视图进行批量矩阵乘法:

scss 复制代码
print(a @ a.transpose(2, 3))

结果为:

lua 复制代码
tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])

在这种情况下,PyTorch 中的矩阵乘法实现处理四维输入张量,使得矩阵乘法在最后两个维度(num_tokenshead_dim)之间进行,然后对各个头部重复此操作。

例如,上述操作可以成为分别为每个头部计算矩阵乘法的更紧凑方式:

ini 复制代码
first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)

second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)

结果与我们使用批量矩阵乘法 print(a @ a.transpose(2, 3)) 获得的结果完全相同:

lua 复制代码
First head:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])

Second head:
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])

继续讨论 MultiHeadAttention,在计算注意力权重和上下文向量之后,来自所有头部的上下文向量被转置回形状 (b, num_tokens, num_heads, head_dim)。然后将这些向量重塑(展平)为形状 (b, num_tokens, d_out),有效地组合了所有头部的输出。

此外,我们在合并头部之后为 MultiHeadAttention 添加了一个输出投影层(self.out_proj),这在 CausalAttention 类中不存在。这个输出投影层并非严格必要的(更多细节见附录 B),但它在许多 LLM 架构中被广泛使用,这就是我在此添加它以完整性的原因。

尽管由于额外的张量重塑和转置,MultiHeadAttention 类看起来比 MultiHeadAttentionWrapper 更复杂,但它更高效。原因是我们只需要一次矩阵乘法来计算键,例如 keys = self.W_key(x)(查询和值同理)。在 MultiHeadAttentionWrapper 中,我们需要为每个注意力头重复此矩阵乘法,这在计算上是最昂贵的步骤之一。

MultiHeadAttention 类的使用方式与我们之前实现的 SelfAttentionCausalAttention 类类似:

scss 复制代码
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

结果显示输出维度直接由 d_out 参数控制:

ini 复制代码
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])

我们现在已经实现了将在实现和训练 LLM 时使用的 MultiHeadAttention 类。请注意,虽然代码是完全功能性的,但我使用了相对较小的嵌入大小和注意力头数量,以保持输出的可读性。

作为比较,最小的 GPT-2 模型(1.17 亿参数)有 12 个注意力头,且上下文向量嵌入大小为 768。最大的 GPT-2 模型(15 亿参数)有 25 个注意力头,且上下文向量嵌入大小为 1600。在 GPT 模型中,令牌输入的嵌入大小和上下文嵌入是相同的(d_in = d_out)。

总结

  • 注意力机制将输入元素转化为增强的上下文向量表示,这些表示包含了所有输入的信息。

  • 自注意力机制通过对输入的加权求和计算上下文向量表示。

  • 在简化的注意力机制中,注意力权重通过点积计算得出。

  • 点积是一种简洁的方法,用于逐元素相乘两个向量,然后求和。

  • 矩阵乘法虽然不是严格必要的,但通过替换嵌套的 for 循环,能够更高效、紧凑地实现计算。

  • 在 LLM 使用的自注意力机制中,也称为缩放点积注意力,我们引入了可训练的权重矩阵,以计算输入的中间变换:查询、值和键。

  • 在处理从左到右读取和生成文本的 LLM 时,我们添加了因果注意力掩码,以防止 LLM 访问未来的标记。

  • 除了因果注意力掩码以零化注意力权重外,我们还可以添加 dropout 掩码,以减少 LLM 中的过拟合。

  • 基于变压器的 LLM 中的注意力模块涉及多个因果注意力实例,这被称为多头注意力。

  • 我们可以通过堆叠多个因果注意力模块的实例来创建多头注意力模块。

  • 创建多头注意力模块的更高效方法涉及批量矩阵乘法。

相关推荐
m0_7431064643 分钟前
【论文笔记】MV-DUSt3R+:两秒重建一个3D场景
论文阅读·深度学习·计算机视觉·3d·几何学
m0_743106461 小时前
【论文笔记】TranSplat:深度refine的camera-required可泛化稀疏方法
论文阅读·深度学习·计算机视觉·3d·几何学
AI浩4 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
IE066 小时前
深度学习系列75:sql大模型工具vanna
深度学习
不惑_6 小时前
深度学习 · 手撕 DeepLearning4J ,用Java实现手写数字识别 (附UI效果展示)
java·深度学习·ui
CM莫问7 小时前
python实战(十五)——中文手写体数字图像CNN分类
人工智能·python·深度学习·算法·cnn·图像分类·手写体识别
余炜yw7 小时前
深入探讨激活函数在神经网络中的应用
人工智能·深度学习·机器学习
ARM+FPGA+AI工业主板定制专家8 小时前
基于RK3576/RK3588+FPGA+AI深度学习的轨道异物检测技术研究
人工智能·深度学习
池央9 小时前
StyleGAN - 基于样式的生成对抗网络
人工智能·神经网络·生成对抗网络
小猪咪piggy9 小时前
【深度学习入门】深度学习知识点总结
人工智能·深度学习