拒绝死记硬背!一文彻底打通 Transformer 注意力机制的数据流转全链路
在深度学习与自然语言处理(NLP)的江湖里,Transformer 绝对是当之无愧的"一代宗师"。但很多同学在面试或实际手写代码时,面对 TokenID → Embedding → QKV → Attention 这一长串流程,往往只能背出公式,却很难在脑海中建立起清晰的矩阵维度变化图景。
今天,我们就抛开晦涩的数学推导,以数据流转的视角,一步步硬核拆解 Transformer 的核心计算过程。建议配合代码和矩阵维度变化一起食用,效果更佳!
1. 离散到连续:TokenID → Embedding (X)
模型是不认识文字的,它只认识数字。
- TokenID :首先,分词器(Tokenizer)将文本切分,并在词表(Vocabulary)中查找对应的整数索引。假设词表大小为
vocab_size,输入序列长度为seq_len,我们得到一个形状为[seq_len]的整数张量。 - Embedding (X) :模型内部维护着一个巨大的可学习矩阵
Embedding Matrix,形状为[vocab_size, d_model]。TokenID 本质上就是在这个矩阵中做"查表"操作(One-hot 编码与矩阵相乘)。
维度变化 :[seq_len] ️ [seq_len, d_model]
此时,离散的符号变成了稠密的高维向量,语义相近的词在空间中的距离也会更近。
2. 身份的裂变:线性投影 (Q, K, V)
这是自注意力机制(Self-Attention)的起点。为了让每个 Token 能够"表达需求"并"提供信息",我们需要将输入向量 XX 分别通过三个独立的线性层(全连接层)进行投影:
- Query (Q) :"我需要什么信息?"(查询向量)
- Key (K) :"我能提供什么信息?"(键向量)
- Value (V) :"我实际携带的信息内容。"(值向量)
这三个矩阵由可学习的权重 WQ,WK,WVWQ,WK,WV 生成。
维度变化:
- Q=X⋅WQQ=X⋅WQ ️
[seq_len, d_k] - K=X⋅WKK=X⋅WK ️
[seq_len, d_k] - V=X⋅WVV=X⋅WV ️
[seq_len, d_v]
(注:在标准 Transformer 中,通常 dk=dv=dmodeldk=dv=dmodel )
完整分步解析:线性投影到底在算什么?
为了让你彻底看透线性投影的本质,我们设定一个极简场景:
- 输入文本:2 个 token( t1,t2t1,t2 ),序列长度 N=2N=2
- 原始嵌入维度: dmodel=4dmodel=4 (每个 token 的基础语义向量是 4 维)
- 投影输出维度: dk=2dk=2 (Q、K 向量统一压缩到 2 维)
前置数据准备 :
输入嵌入矩阵 X2×4X2×4 :
X=2101102−1X=2110021−1
投影权重 WQ,WK,WVWQ,WK,WV 形状均为 4×24×2 (为简化理解,示例中设成相同,真实模型中它们是完全独立的参数):
WQ=WK=WV=101−10110WQ=WK=WV=11010−110
第一步:什么是线性投影?
线性投影本质就是矩阵乘法:用一组可训练权重,把原始高维嵌入向量 XX ,映射到全新维度的专用向量空间,拆分出三种分工向量: Q=X⋅WQ,K=X⋅WK,V=X⋅WVQ=X⋅WQ,K=X⋅WK,V=X⋅WV 。
- WQWQ :把原始语义转换成查询向量 Q,代表「当前 token 想要找什么信息」
- WKWK :把原始语义转换成键向量 K,代表「每个 token 能提供什么信息」
- WVWV :把原始语义转换成值向量 V,代表「每个 token 真正携带的内容」
第二步:维度匹配规则(为什么输出是 2×22×2 ?)
根据矩阵乘法约束 Am×n⋅Bn×p=Cm×pAm×n⋅Bn×p=Cm×p :
- XX 的形状是 2×42×4 (2 行 token,4 列原始特征)
- WQWQ 的形状是 4×24×2 (4 行匹配输入维度,2 列定义输出维度 dkdk )
- 输出 QQ 的形状必然是 2×22×2 (2 行 token,2 列投影后特征)
第三步:手动计算投影过程
取 t1t1 嵌入行向量 2,1,0,12,1,0,1 乘 WQWQ :
- Q 第 1 维:2×1+1×0+0×1+1×(−1)=12×1+1×0+0×1+1×(−1)=1
- Q 第 2 维:2×0+1×1+0×1+1×0=12×0+1×1+0×1+1×0=1
得到 Q1=1,1Q1=1,1
取 t2t2 嵌入行向量 1,0,2,−11,0,2,−1 乘 WQWQ :
- Q 第 1 维:1×1+0×0+2×1+(−1)×(−1)=41×1+0×0+2×1+(−1)×(−1)=4
- Q 第 2 维:1×0+0×1+2×1+(−1)×0=21×0+0×1+2×1+(−1)×0=2
得到 Q2=4,2Q2=4,2
最终投影结果:
Q=K=V=1412Q=K=V=1142
第四步:线性投影的核心作用
- 分工解耦:原始嵌入向量同时包含语义、位置、语法信息,通过三组权重拆分出「查询、匹配、内容」三种独立表征;
- 维度压缩 / 变换:自由控制 QK 向量维度 dkdk ,多头注意力里会把大 dmodeldmodel 切分成多个小 dkdk ;
- 赋予模型拟合能力: WQ/WK/WVWQ/WK/WV 是可训练参数,模型会自动学习怎么映射才能精准计算 token 间关联。
核心提示:教学示例 vs 真实模型
这两组矩阵( XX 嵌入、 WW 权重)是人为自定义的教学演示数据,不是模型训练出来的真实参数,目的是降低理解门槛,全程能手动笔算、一眼看懂维度与运算逻辑。
为什么初始化 XX 为 2×42×4 ?
- 控制超参: N=2N=2 序列极短,后续算 QKQK 点积得到的相似度矩阵只有 2×22×2 ,数字少、计算量小; dmodel=4dmodel=4 既能完整演示线性变换,又不会出现超大维度导致手算繁琐。
- 数值搭配:混合正数、0、负数,模拟真实嵌入特征;数字小,相乘后输出全是整数,避开小数,新手笔算无压力。
- 可视化区分:一行 = 一个 token 的完整语义向量,帮你建立「矩阵行代表 token」的思维习惯。
为什么初始化 WW 为 4×24×2 ?
- 维度严格匹配:输入维度 4 → 权重矩阵必须是 4 行;输出维度 2 → 权重矩阵必须是 2 列。
- 极简系数:大量 0、1、-1 减少乘法计算量,包含 -1 模拟真实权重可正向/反向抑制特征。
- 降低负担:真实模型中 WQ,WK,WVWQ,WK,WV 是完全独立的参数,但教学时若三套矩阵不同,需重复算 3 遍。示例强制三者相等,只需算一次线性变换,读者吃透一次矩阵乘法即可。
3. 全局碰撞:QK 点积打分
有了 Q 和 K,接下来就是计算序列中任意两个 Token 之间的相关性。
我们将 QQ 和 KK 的转置进行矩阵乘法:Scores = Q @ K.T。
这个操作的本质是计算每一对 (Q, K) 的点积。点积越大,说明这两个 Token 在语义或语法上的关联度越高。
维度变化 :[seq_len, d_k] @ [d_k, seq_len] ️ [seq_len, seq_len]
此时,我们得到了一个注意力分数矩阵,它记录了序列内部的全局依赖关系。
完整分步解析:QK 点积打分的矩阵乘法原理
第一步:数学公式与转置的意义
相似度分数矩阵: Score=Q⋅K⊤Score=Q⋅K⊤
K⊤K⊤ 是 K 矩阵的转置:把 K 的行列互换,方便做批量点积。
K=1412⇒K⊤=1142K=1142⇒K⊤=1412
第二步:矩阵乘法含义(最关键)
QN×dk⋅Kdk×N⊤=ScoreN×NQN×dk⋅Kdk×N⊤=ScoreN×N
输出方阵 N×NN×N ,每一个元素 Scorei,jScorei,j 含义:第 ii 个 token 的查询向量 QiQi ,和第 jj 个 token 的键向量 KjKj 的点积相似度。数值越大,代表两个 token 语义关联性越强。
第三步:完整计算打分矩阵
Score=1412⋅1142=1×1+4×41×1+4×21×1+2×41×1+2×2=17995Score=1142⋅1412=1×1+4×41×1+2×41×1+4×21×1+2×2=17995
逐个元素解读:
- Score1,1=17Score1,1=17 : t1t1 和自己的关联分数
- Score1,2=9Score1,2=9 : t1t1 和 t2t2 的关联分数
- Score2,1=9Score2,1=9 : t2t2 和 t1t1 的关联分数
- Score2,2=5Score2,2=5 : t2t2 和自己的关联分数
第四步:单点手动拆解(直观理解点积)
以 Score1,2Score1,2 举例: Q1=1,1,K2=4,2Q1=1,1,K2=4,2
点积 = 1×4+1×2=61×4+1×2=6
本质就是两个向量对应维度相乘再求和,衡量向量在高维空间的夹角:
- 向量方向越接近,点积结果越大;
- 方向垂直,点积 = 0;方向相反,点积为负数。
第五步:QK 点积打分的核心作用
- 全局两两匹配:不需要循环遍历,矩阵乘法一次性算出序列内所有 token 之间的关联强度;
- 无位置偏见(自注意力原生特性):不管两个 token 相隔多远,都会计算关联,解决 RNN 长距离依赖缺陷;
- 输出分数矩阵是后续缩放、softmax、加权 V 的输入基础。
4. 防止梯度消失:缩放机制 (Scale)
核心细节:在点积之后,必须除以 dkdk !
为什么必须有「缩放」?原始点积分数存在致命缺陷
1. 问题根源:Q、K 向量维度 dkdk 越大,点积越容易爆炸
向量点积公式:两个 dkdk 维向量 q=q1,q2...qdk,k=k1,k2...kdkq=q1,q2...qdk,k=k1,k2...kdk ,其点积为 q⋅k=q1k1+q2k2+...+qdkkdkq⋅k=q1k1+q2k2+...+qdkkdk 。
假设向量内每个值都服从标准正态分布 N(0,1)N(0,1) ,单个项 qikiqiki 的期望为 0,方差为 1。那么总和点积的方差就等于 dkdk ,标准差为 dkdk 。
当 dkdk 很大时(比如 BERT 单头维度 dk=64dk=64 , 64=864=8 ;更大模型 dk=128dk=128 , 128≈11.3128≈11.3 ),点积结果会变得极大,数值区间跨度巨大。
2. 数值爆炸带来的训练灾难:Softmax 梯度消失
Softmax 公式: softmax(xi)=exi∑jexjsoftmax(xi)=∑jexjexi
指数函数 exex 对输入极其敏感:如果输入分数很大(比如 20、30), e20e20 是天文数字,其他小数值的指数会被完全淹没。一行内只有最大值对应的权重接近 1,其余全部趋近 0。反向传播时梯度几乎为 0,权重无法更新,模型根本训不动。
3. 缩放操作:除以 dkdk ,拉平数值分布
缩放公式: ScaledScore=QK⊤dkScaledScore=dkQK⊤
作用:把点积的标准差重新拉回 1,不管 dkdk 是 64/128/256,分数稳定在小范围,指数不会爆炸。
用我们之前的 2-token 例子直观对比
我们示例中 dk=2dk=2 , 2≈1.41422≈1.4142 。
原始分数矩阵:
Score=17995Score=17995
缩放后:
ScaledScore=17/1.41429/1.41429/1.41425/1.4142≈12.026.366.363.54ScaledScore=17/1.41429/1.41429/1.41425/1.4142≈12.026.366.363.54
如果不缩放, e17e17 会直接让第二行第一个权重无限趋近 0,模型完全忽略 t1t1 对 t2t2 的影响。缩放后,指数压力大幅缓解。
5. 概率化:Softmax 归一化 (A)
缩放只是稳定数值,不能直接用来乘 V 做信息融合,因为分数可正可负、无固定区间,无法直观代表"关注度占比";且每行所有 token 关注度没有约束,无法实现"分配有限注意力资源"。Softmax 归一就是解决这两个问题,逐行独立归一(一行对应一个 token ii ,该行所有值是 ii 对全部 token jj 的关注度)。
计算规则(拿示例第一行 12.02,6.3612.02,6.36 演算)
- 对每个分数取自然指数 exex ,消除负数,所有值变成正数:
e12.02≈166067,e6.36≈578e12.02≈166067,e6.36≈578 - 求该行所有指数值总和做分母:
166067+578=166645166067+578=166645 - 每个指数值 ÷ 总和,得到归一权重:
a11=166067/166645≈0.997a11=166067/166645≈0.997
a12=578/166645≈0.003a12=578/166645≈0.003
Softmax 归一后两大核心特性
- 所有权重 ∈(0,1)∈(0,1) :可以理解成概率、关注度比例;
- 单行所有权重之和严格 = 1:对 token t1t1 来说,它分配给所有 token 的注意力总和是 100%,只是权重大小不同。本例中 t1t1 把 99.7% 注意力给到自己,仅 0.3% 留给 t2t2 。
归一矩阵 A 的意义
A=0.9970.0030.0030.997A=0.9970.0030.0030.997
Ai,jAi,j = token ii 在融合自身新向量时,从 token jj 的 V 向量中提取多少信息。后续输出计算: Output=A⋅VOutput=A⋅V ,用归一后的权重对 V 加权求和,完成全局信息融合。
缩放 vs 归一:职能完全区分(别混淆)
表格
| 模块 | 操作 | 输入 | 核心目的 | 有无可训练参数 |
|---|---|---|---|---|
| 缩放机制 | ÷dk÷dk | QK 点积原始分数 | 约束数值范围,防止 Softmax 指数爆炸、梯度消失 | 无,纯固定数学运算 |
| Softmax 归一 | 指数 + 按行求和归一 | 缩放后的稳定分数 | 将分数转为总和为 1 的注意力分配权重,用于加权 Value | 无,纯固定数学运算 |
补充高频面试细节
- 为什么不用其他归一(比如 Min-Max),非要 Softmax?
Min-Max 只会把数值压缩到 0~1,不能保证每行和为 1,无法做加权求和;Softmax 天然具备"相对竞争"特性:一个 token 关注度升高,其他自动降低,完美模拟"注意力资源有限"的逻辑。 - 缩放只作用在 QK 打分阶段,和 V 无关
V 全程不参与相似度计算,缩放只针对匹配分数,不改变 Value 携带的语义内容。 - 完整串联顺序(不能颠倒)
QK⊤QK⊤ 原始分数 → 缩放 → Softmax 归一得到权重矩阵 A → A×VA×V 。颠倒顺序(先 Softmax 再缩放)会完全失效:先指数放大数值,再缩放于事无补,梯度消失问题依旧存在。
6. 信息提取:A 加权 V
这是整个流程的"收网"阶段。
我们将注意力权重矩阵 AA 与值矩阵 VV 进行矩阵乘法:Output = A @ V。
这个操作的物理意义是:根据注意力权重,对序列中所有 Token 的 Value 进行加权求和。如果某个 Token 的注意力权重极高,它的 Value 信息就会被大量融合到当前位置的输出中。
维度变化 :[seq_len, seq_len] @ [seq_len, d_v] ️ [seq_len, d_v]
7. 完美闭环:注意力输出
经过上述步骤,我们得到了形状为 [seq_len, d_v] 的张量。
在标准的 Transformer 架构中,通常 dv=dmodeldv=dmodel ,因此输出维度完美还原为 [seq_len, d_model]。这个输出不仅包含了 Token 自身的原始语义,还融合了整个序列的上下文信息(比如代词"它"在这里融合了前面提到的"猫"的语义)。
随后,这个输出会被送入残差连接(Add & Norm)和前馈神经网络(FFN),开启下一层的特征提取。
深度进阶:真实模型的权重到底在学什么?
回到我们开头的教学示例,示例中令 WQ=WK=WVWQ=WK=WV 且使用 1、0、-1 等整数,仅仅是为了手写计算的简化工具,无实际模型意义。在真实的 Transformer 中,情况要复杂且精妙得多:
1. 为什么 WQ,WK,WVWQ,WK,WV 必须是三套完全独立的参数?
Q、K、V 承担着完全不同的语义任务,需要各自专属的特征映射规则,绝不能共用一套变换逻辑:
- WQWQ (查询视角) :表达"我这个词,想从全句里找哪些相关信息"。例如句子"小猫追蝴蝶"中,token「追」的 Q 向量会被训练成偏向寻找「施动者(小猫)、对象(蝴蝶)」的特征。
- WKWK (被检索视角) :表达"我这个词,能提供什么信息给其他 token 匹配"。例如 token「小猫」的 K 向量,会编码出「动作发起主体」特征,方便其他词的 Q 匹配到它。
- WVWV (内容存储视角) :不参与相似度计算,只负责存储原始内容信息。就算「小猫」和「追」匹配度很高,V 向量只负责承载"小猫"本身的实体语义