从Token预测说起
大语言模型的核心工作原理非常简单:给定前面的0到n个Token,预测第n+1个Token是什么。
举个例子:
- 输入:"今天天气"
- 模型需要预测下一个词可能是:"很好"、"不错"、"真热" 等
但问题来了:当我们要预测下一个Token时,前面的每个Token对当前预测的重要程度是不同的。
比如在句子 "我昨天在北京吃了烤鸭,今天在上海吃了_" 中:
- 要预测最后一个词时,"上海"这个Token显然比"昨天"、"北京"更重要
- 因为我们需要根据"上海"来推测当地的特色美食
注意力机制就是用来解决这个问题的:让模型自动学习,在预测当前Token时,应该把"注意力"放在前面哪些Token上。
注意力机制的核心思想
注意力机制的核心可以用一句话概括:
为每个Token计算一个权重,表示它对当前预测的重要程度,然后加权求和
具体来说,分为三个步骤:
- 计算相关性:当前Token与历史每个Token的相关程度
- 归一化权重:把相关性转换为概率分布(加起来等于1)
- 加权求和:用这些权重对历史Token的信息进行加权平均
QKV矩阵:注意力机制的三个核心角色
为了实现上述思想,注意力机制引入了三个矩阵:Q(Query)、K(Key)、V(Value)
可以用"图书馆查书"来类比理解:
- Q(Query,查询):你想查的内容,比如"我想找关于上海美食的信息"
- K(Key,键):每本书的目录/索引,用来匹配你的查询
- V(Value,值):每本书的实际内容
从Embedding到QKV
每个Token首先被转换为一个Embedding向量(通常是一个高维向量,比如768维或4096维)。假设:
- 输入序列长度为
n(有n个Token) - 每个Token的Embedding维度为
d_model(比如768)
那么输入可以表示为一个矩阵:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X ∈ R n × d model X \in \mathbb{R}^{n \times d_{\text{model}}} </math>X∈Rn×dmodel
其中每一行代表一个Token的Embedding向量。
d_model 到底是什么?
d_model 就是Token向量的维度。更具体地说:
- 输入层 :每个Token通过Embedding层转换为一个
d_model维的向量 - 中间层 :这个维度会贯穿整个Transformer的所有层,每一层的输入和输出都保持
d_model维 - 输出层 :最后一层的输出也是每个Token一个
d_model维的向量
举个例子:
输入文本:"猫 吃 鱼"(3个Token)
<math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 768 d_{\text{model}} = 768 </math>dmodel=768
经过Embedding后:
- Token "猫": <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.1 , 0.2 , ... , 0.5 ] [0.1, 0.2, \ldots, 0.5] </math>[0.1,0.2,...,0.5] ← 768维向量
- Token "吃": <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.3 , 0.1 , ... , 0.7 ] [0.3, 0.1, \ldots, 0.7] </math>[0.3,0.1,...,0.7] ← 768维向量
- Token "鱼": <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.2 , 0.4 , ... , 0.3 ] [0.2, 0.4, \ldots, 0.3] </math>[0.2,0.4,...,0.3] ← 768维向量
整体表示为矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> X ∈ R 3 × 768 X \in \mathbb{R}^{3 \times 768} </math>X∈R3×768
d_model 在实际模型中的取值:
不同规模的模型,d_model 差异很大:
| 模型规模 | d_model | 典型模型 |
|---|---|---|
| 小型模型 | 768 | BERT-Base, GPT-2 Small |
| 中型模型 | 1024-1280 | BERT-Large, GPT-2 Medium |
| 大型模型 | 1600-5120 | GPT-3 (1.3B-13B), LLaMA-7B/13B |
| 超大型模型 | 6656-12288 | LLaMA-65B, GPT-3 (175B) |
选择 <math xmlns="http://www.w3.org/1998/Math/MathML"> d model d_{\text{model}} </math>dmodel 的原则:
- 必须是64的倍数 :方便GPU计算优化( <math xmlns="http://www.w3.org/1998/Math/MathML"> 768 = 64 × 12 768=64\times12 </math>768=64×12, <math xmlns="http://www.w3.org/1998/Math/MathML"> 1024 = 64 × 16 1024=64\times16 </math>1024=64×16)
- 能被注意力头数整除 :比如 <math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 768 d_{\text{model}}=768 </math>dmodel=768, <math xmlns="http://www.w3.org/1998/Math/MathML"> h = 12 h=12 </math>h=12,每个头维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> = 64 =64 </math>=64
- 维度越大,表达能力越强:但参数量和计算量会显著增加
接下来,我们通过三个可学习的权重矩阵,将X分别转换为Q、K、V:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q = X ⋅ W Q 其中 W Q ∈ R d model × d k K = X ⋅ W K 其中 W K ∈ R d model × d k V = X ⋅ W V 其中 W V ∈ R d model × d v \begin{aligned} Q &= X \cdot W_Q \quad \text{其中 } W_Q \in \mathbb{R}^{d_{\text{model}} \times d_k} \\ K &= X \cdot W_K \quad \text{其中 } W_K \in \mathbb{R}^{d_{\text{model}} \times d_k} \\ V &= X \cdot W_V \quad \text{其中 } W_V \in \mathbb{R}^{d_{\text{model}} \times d_v} \end{aligned} </math>QKV=X⋅WQ其中 WQ∈Rdmodel×dk=X⋅WK其中 WK∈Rdmodel×dk=X⋅WV其中 WV∈Rdmodel×dv
参数解释:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q , W K , W V W_Q, W_K, W_V </math>WQ,WK,WV 是三个权重矩阵,在训练过程中学习得到
- <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 是Q和K的维度(通常等于 <math xmlns="http://www.w3.org/1998/Math/MathML"> d model / h d_{\text{model}} / h </math>dmodel/h,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h是多头注意力的头数)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> d v d_v </math>dv 是V的维度(通常也等于 <math xmlns="http://www.w3.org/1998/Math/MathML"> d model / h d_{\text{model}} / h </math>dmodel/h)
- 在单头注意力中,通常 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = d v = d model d_k = d_v = d_{\text{model}} </math>dk=dv=dmodel
变换后得到:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q ∈ R n × d k (n个查询向量) K ∈ R n × d k (n个键向量) V ∈ R n × d v (n个值向量) \begin{aligned} Q &\in \mathbb{R}^{n \times d_k} \quad \text{(n个查询向量)} \\ K &\in \mathbb{R}^{n \times d_k} \quad \text{(n个键向量)} \\ V &\in \mathbb{R}^{n \times d_v} \quad \text{(n个值向量)} \end{aligned} </math>QKV∈Rn×dk(n个查询向量)∈Rn×dk(n个键向量)∈Rn×dv(n个值向量)
注意力计算的完整公式
第一步:计算注意力分数(Attention Scores)
用Q和K的点积来衡量相关性:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Scores = Q ⋅ K T \text{Scores} = Q \cdot K^T </math>Scores=Q⋅KT
参数解释:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> Q ⋅ K T Q \cdot K^T </math>Q⋅KT 是矩阵乘法,结果维度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( n × n ) (n \times n) </math>(n×n)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> Scores [ i , j ] \text{Scores}[i, j] </math>Scores[i,j] 表示第i个Token(查询)与第j个Token(键)的相关性
- 点积越大,说明两个向量越相似,相关性越高
第二步:缩放(Scaling)
为了防止点积结果过大导致梯度消失,除以一个缩放因子:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Scores scaled = Scores d k \text{Scores}_{\text{scaled}} = \frac{\text{Scores}}{\sqrt{d_k}} </math>Scoresscaled=dk Scores
参数解释:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 是缩放因子
- 为什么要除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ?因为当维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk很大时,点积的方差会变大,导致softmax后梯度很小
- 这个缩放操作可以让点积的方差稳定在1左右
第三步:应用Softmax归一化
将分数转换为概率分布:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention_Weights = softmax ( Scores scaled ) \text{Attention\Weights} = \text{softmax}(\text{Scores}{\text{scaled}}) </math>Attention_Weights=softmax(Scoresscaled)
参数解释:
- Softmax函数: <math xmlns="http://www.w3.org/1998/Math/MathML"> softmax ( x i ) = exp ( x i ) ∑ j exp ( x j ) \text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} </math>softmax(xi)=∑jexp(xj)exp(xi)
- 作用:把实数分数转换为0-1之间的概率,且所有概率加起来等于1
- <math xmlns="http://www.w3.org/1998/Math/MathML"> Attention_Weights [ i , j ] \text{Attention\_Weights}[i, j] </math>Attention_Weights[i,j] 表示第i个Token应该给第j个Token分配多少注意力权重
注意 :在实际应用中(如GPT),这里还会加一个掩码(Mask),防止模型看到未来的信息:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention_Weights = softmax ( Scores scaled + Mask ) \text{Attention\Weights} = \text{softmax}(\text{Scores}{\text{scaled}} + \text{Mask}) </math>Attention_Weights=softmax(Scoresscaled+Mask)
其中Mask会把未来位置的分数设为负无穷,使得softmax后权重为0。
第四步:加权求和
用注意力权重对V进行加权求和:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Output = Attention_Weights ⋅ V \text{Output} = \text{Attention\_Weights} \cdot V </math>Output=Attention_Weights⋅V
参数解释:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> Attention_Weights ∈ R n × n \text{Attention\_Weights} \in \mathbb{R}^{n \times n} </math>Attention_Weights∈Rn×n
- <math xmlns="http://www.w3.org/1998/Math/MathML"> V ∈ R n × d v V \in \mathbb{R}^{n \times d_v} </math>V∈Rn×dv
- <math xmlns="http://www.w3.org/1998/Math/MathML"> Output ∈ R n × d v \text{Output} \in \mathbb{R}^{n \times d_v} </math>Output∈Rn×dv
- <math xmlns="http://www.w3.org/1998/Math/MathML"> Output [ i ] \text{Output}[i] </math>Output[i] 是第i个Token的输出表示,它是所有Token的V向量的加权平均
- 权重就是该Token对其他Token的注意力分数
完整公式(Scaled Dot-Product Attention)
将上述步骤合并,得到注意力机制的标准公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention ( Q , K , V ) = softmax ( Q ⋅ K T d k ) ⋅ V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \cdot V </math>Attention(Q,K,V)=softmax(dk Q⋅KT)⋅V
这就是Transformer论文中最著名的公式!
具体例子:理解注意力权重
假设我们有一个简单的句子:"猫 吃 鱼",共3个Token。
第一步:生成QKV
lua
输入 X:
Token 1: [0.1, 0.2, 0.3, 0.4] # "猫"的Embedding
Token 2: [0.5, 0.6, 0.7, 0.8] # "吃"的Embedding
Token 3: [0.2, 0.3, 0.1, 0.5] # "鱼"的Embedding
通过 W_Q, W_K, W_V 变换后得到 Q, K, V(这里简化为2维)
Q: [[q1_1, q1_2], [q2_1, q2_2], [q3_1, q3_2]]
K: [[k1_1, k1_2], [k2_1, k2_2], [k3_1, k3_2]]
V: [[v1_1, v1_2], [v2_1, v2_2], [v3_1, v3_2]]
第二步:计算注意力分数
ini
Scores = Q · K^T
得到一个 3×3 的矩阵:
Token1 Token2 Token3
Token1: [s1_1 s1_2 s1_3] # Token1与所有Token的相关性
Token2: [s2_1 s2_2 s2_3] # Token2与所有Token的相关性
Token3: [s3_1 s3_2 s3_3] # Token3与所有Token的相关性
第三步:Softmax归一化
ini
对每一行应用Softmax,得到注意力权重:
Token1 Token2 Token3
Token1: [0.2 0.3 0.5] # Token1应该关注各Token的权重(和为1)
Token2: [0.1 0.6 0.3] # Token2应该关注各Token的权重(和为1)
Token3: [0.3 0.5 0.2] # Token3应该关注各Token的权重(和为1)
比如Token3("鱼")对Token2("吃")的注意力权重是0.5,说明在理解"鱼"时,"吃"这个词很重要。
第四步:加权求和
ini
Output = Attention_Weights · V
每个Token的输出是所有Token的V向量的加权平均
从单头到多头注意力机制(Multi-Head Attention)
为什么需要多头?
单头注意力只能学习一种"注意力模式"。但实际上,理解一个Token可能需要关注多个不同方面:
以句子 "The animal didn't cross the street because it was too tired" 为例:
- 语义关系头:it 指向 animal(代词指代)
- 句法关系头:cross 指向 street(动宾关系)
- 因果关系头:because 关联前后两个分句
单个注意力头无法同时捕捉这些不同类型的关系,因此需要多个注意力头并行工作。
多头注意力的实现
多头注意力的核心思想:使用h个独立的注意力头,每个头学习不同的表示子空间
步骤1:线性投影到多个子空间
将输入X投影到h组不同的Q、K、V:
对于第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i个头( <math xmlns="http://www.w3.org/1998/Math/MathML"> i = 1 , 2 , ... , h i = 1, 2, \ldots, h </math>i=1,2,...,h):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q i = X ⋅ W i Q 其中 W i Q ∈ R d model × d k K i = X ⋅ W i K 其中 W i K ∈ R d model × d k V i = X ⋅ W i V 其中 W i V ∈ R d model × d v \begin{aligned} Q_i &= X \cdot W_i^Q \quad \text{其中 } W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k} \\ K_i &= X \cdot W_i^K \quad \text{其中 } W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k} \\ V_i &= X \cdot W_i^V \quad \text{其中 } W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v} \end{aligned} </math>QiKiVi=X⋅WiQ其中 WiQ∈Rdmodel×dk=X⋅WiK其中 WiK∈Rdmodel×dk=X⋅WiV其中 WiV∈Rdmodel×dv
参数解释:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 是注意力头的数量(比如8或16)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = d v = d model / h d_k = d_v = d_{\text{model}} / h </math>dk=dv=dmodel/h(每个头的维度是总维度的1/h)
- 每个头有自己独立的权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W i Q , W i K , W i V W_i^Q, W_i^K, W_i^V </math>WiQ,WiK,WiV
- 比如 <math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 768 , h = 12 d_{\text{model}}=768, h=12 </math>dmodel=768,h=12,则每个头的维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = 64 d_k=64 </math>dk=64
步骤2:并行计算h个注意力头
对每个头独立计算注意力:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> head i = Attention ( Q i , K i , V i ) = softmax ( Q i ⋅ K i T d k ) ⋅ V i \text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i \cdot K_i^T}{\sqrt{d_k}}\right) \cdot V_i </math>headi=Attention(Qi,Ki,Vi)=softmax(dk Qi⋅KiT)⋅Vi
参数解释:
- 每个 <math xmlns="http://www.w3.org/1998/Math/MathML"> head i ∈ R n × d v \text{head}_i \in \mathbb{R}^{n \times d_v} </math>headi∈Rn×dv
- <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h个头完全并行计算,互不干扰
- 每个头可以学习关注输入的不同方面
步骤3:拼接并线性变换
将h个头的输出拼接起来,再通过一个线性层:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , ... , head h ) ⋅ W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) \cdot W_O </math>MultiHead(Q,K,V)=Concat(head1,head2,...,headh)⋅WO
参数解释:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> Concat \text{Concat} </math>Concat 将 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h个头在最后一维拼接: <math xmlns="http://www.w3.org/1998/Math/MathML"> R n × d v × h → R n × h ⋅ d v \mathbb{R}^{n \times d_v} \times h \rightarrow \mathbb{R}^{n \times h \cdot d_v} </math>Rn×dv×h→Rn×h⋅dv
- 由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> d v = d model / h d_v = d_{\text{model}} / h </math>dv=dmodel/h,所以拼接后维度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> R n × d model \mathbb{R}^{n \times d_{\text{model}}} </math>Rn×dmodel
- <math xmlns="http://www.w3.org/1998/Math/MathML"> W O ∈ R d model × d model W_O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}} </math>WO∈Rdmodel×dmodel 是输出权重矩阵
- 最终输出维度: <math xmlns="http://www.w3.org/1998/Math/MathML"> R n × d model \mathbb{R}^{n \times d_{\text{model}}} </math>Rn×dmodel,与输入维度一致
完整的多头注意力公式
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> MultiHead ( Q , K , V ) = Concat ( head 1 , ... , head h ) ⋅ W O 其中 head i = Attention ( X ⋅ W i Q , X ⋅ W i K , X ⋅ W i V ) \begin{aligned} \text{MultiHead}(Q, K, V) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \cdot W_O \\ \text{其中 } \text{head}_i &= \text{Attention}(X \cdot W_i^Q, X \cdot W_i^K, X \cdot W_i^V) \end{aligned} </math>MultiHead(Q,K,V)其中 headi=Concat(head1,...,headh)⋅WO=Attention(X⋅WiQ,X⋅WiK,X⋅WiV)
参数总结
假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 768 d_{\text{model}} = 768 </math>dmodel=768, <math xmlns="http://www.w3.org/1998/Math/MathML"> h = 12 h = 12 </math>h=12:
| 参数 | 形状 | 数量 | 说明 |
|---|---|---|---|
| <math xmlns="http://www.w3.org/1998/Math/MathML"> W i Q W_i^Q </math>WiQ | <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 768 , 64 ) (768, 64) </math>(768,64) | 12个 | 每个头的Query权重矩阵 |
| <math xmlns="http://www.w3.org/1998/Math/MathML"> W i K W_i^K </math>WiK | <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 768 , 64 ) (768, 64) </math>(768,64) | 12个 | 每个头的Key权重矩阵 |
| <math xmlns="http://www.w3.org/1998/Math/MathML"> W i V W_i^V </math>WiV | <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 768 , 64 ) (768, 64) </math>(768,64) | 12个 | 每个头的Value权重矩阵 |
| <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W_O </math>WO | <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 768 , 768 ) (768, 768) </math>(768,768) | 1个 | 输出权重矩阵 |
总参数量 <math xmlns="http://www.w3.org/1998/Math/MathML"> = 12 × ( 768 × 64 + 768 × 64 + 768 × 64 ) + 768 × 768 ≈ 2.36 M = 12 \times (768\times64 + 768\times64 + 768\times64) + 768\times768 \approx 2.36M </math>=12×(768×64+768×64+768×64)+768×768≈2.36M
多头注意力的优势
- 捕捉多种关系:不同的头可以学习不同类型的依赖关系(语义、句法、位置等)
- 增强表达能力:多个子空间的表示比单一空间更丰富
- 参数效率 :虽然有多个头,但每个头的维度变小( <math xmlns="http://www.w3.org/1998/Math/MathML"> d model / h d_{\text{model}}/h </math>dmodel/h),总参数量与单头相当
- 并行计算 : <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h个头可以完全并行,提高计算效率
可视化理解
想象你在读一篇文章,理解当前词时:
- 单头注意力:只能用一种方式去理解上下文
- 多头注意力 :可以同时从多个角度理解
- 头1:关注句法结构(主谓宾)
- 头2:关注语义关系(同义、反义)
- 头3:关注长距离依赖(代词指代)
- 头4:关注局部搭配(固定词组)
- ...
最后把这些不同角度的理解综合起来,形成对当前词更全面的表示。
小结
- 注意力机制的本质:加权求和,权重由相关性决定
- QKV的作用 :
- Q(Query):当前位置的查询向量
- K(Key):用于匹配查询的键向量
- V(Value):实际要加权求和的内容向量
- 核心公式 : <math xmlns="http://www.w3.org/1998/Math/MathML"> Attention ( Q , K , V ) = softmax ( Q K T d k ) ⋅ V \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \cdot V </math>Attention(Q,K,V)=softmax(dk QKT)⋅V
- 多头注意力:并行运行多个注意力头,每个头学习不同的表示子空间,最后拼接融合
通过注意力机制,模型能够动态地决定在预测下一个Token时,应该更多地"关注"历史序列中的哪些Token,从而实现更准确的预测。