自注意力机制(Self-Attention Mechanism)简单学习一

拒绝死记硬背!一文彻底打通 Transformer 注意力机制的数据流转全链路

在深度学习与自然语言处理(NLP)的江湖里,Transformer 绝对是当之无愧的"一代宗师"。但很多同学在面试或实际手写代码时,面对 TokenID → Embedding → QKV → Attention 这一长串流程,往往只能背出公式,却很难在脑海中建立起清晰的矩阵维度变化图景。

今天,我们就抛开晦涩的数学推导,以数据流转的视角,一步步硬核拆解 Transformer 的核心计算过程。建议配合代码和矩阵维度变化一起食用,效果更佳!


1. 离散到连续:TokenID → Embedding (X)

模型是不认识文字的,它只认识数字。

  • TokenID :首先,分词器(Tokenizer)将文本切分,并在词表(Vocabulary)中查找对应的整数索引。假设词表大小为 vocab_size,输入序列长度为 seq_len,我们得到一个形状为 [seq_len] 的整数张量。
  • Embedding (X) :模型内部维护着一个巨大的可学习矩阵 Embedding Matrix,形状为 [vocab_size, d_model]。TokenID 本质上就是在这个矩阵中做"查表"操作(One-hot 编码与矩阵相乘)。

维度变化[seq_len][seq_len, d_model]

此时,离散的符号变成了稠密的高维向量,语义相近的词在空间中的距离也会更近。


2. 身份的裂变:线性投影 (Q, K, V)

这是自注意力机制(Self-Attention)的起点。为了让每个 Token 能够"表达需求"并"提供信息",我们需要将输入向量 XX 分别通过三个独立的线性层(全连接层)进行投影:

  • Query (Q) :"我需要什么信息?"(查询向量)
  • Key (K) :"我能提供什么信息?"(键向量)
  • Value (V) :"我实际携带的信息内容。"(值向量)

这三个矩阵由可学习的权重 WQ,WK,WVWQ​,WK​,WV​ 生成。

维度变化

  • Q=X⋅WQQ=X⋅WQ ️ [seq_len, d_k]
  • K=X⋅WKK=X⋅WK ️ [seq_len, d_k]
  • V=X⋅WVV=X⋅WV ️ [seq_len, d_v]

(注:在标准 Transformer 中,通常 dk=dv=dmodeldk​=dv​=dmodel​ )

完整分步解析:线性投影到底在算什么?

为了让你彻底看透线性投影的本质,我们设定一个极简场景:

  • 输入文本:2 个 token( t1,t2t1,t2 ),序列长度 N=2N=2
  • 原始嵌入维度: dmodel=4dmodel=4 (每个 token 的基础语义向量是 4 维)
  • 投影输出维度: dk=2dk=2 (Q、K 向量统一压缩到 2 维)

前置数据准备

输入嵌入矩阵 X2×4X2×4​ :

X=2101102−1X=21​10​02​1−1​

投影权重 WQ,WK,WVWQ​,WK​,WV​ 形状均为 4×24×2 (为简化理解,示例中设成相同,真实模型中它们是完全独立的参数):

WQ=WK=WV=101−10110WQ​=WK​=WV​=​1101​0−110​​

第一步:什么是线性投影?

线性投影本质就是矩阵乘法:用一组可训练权重,把原始高维嵌入向量 XX ,映射到全新维度的专用向量空间,拆分出三种分工向量: Q=X⋅WQ,K=X⋅WK,V=X⋅WVQ=X⋅WQ​,K=X⋅WK​,V=X⋅WV​ 。

  • WQWQ :把原始语义转换成查询向量 Q,代表「当前 token 想要找什么信息」
  • WKWK :把原始语义转换成键向量 K,代表「每个 token 能提供什么信息」
  • WVWV :把原始语义转换成值向量 V,代表「每个 token 真正携带的内容」

第二步:维度匹配规则(为什么输出是 2×22×2 ?)

根据矩阵乘法约束 Am×n⋅Bn×p=Cm×pAm×n​⋅Bn×p​=Cm×p​ :

  • XX 的形状是 2×42×4 (2 行 token,4 列原始特征)
  • WQWQ 的形状是 4×24×2 (4 行匹配输入维度,2 列定义输出维度 dkdk )
  • 输出 QQ 的形状必然是 2×22×2 (2 行 token,2 列投影后特征)

第三步:手动计算投影过程

取 t1t1​ 嵌入行向量 2,1,0,12,1,0,1 乘 WQWQ​ :

  • Q 第 1 维:2×1+1×0+0×1+1×(−1)=12×1+1×0+0×1+1×(−1)=1
  • Q 第 2 维:2×0+1×1+0×1+1×0=12×0+1×1+0×1+1×0=1
    得到 Q1=1,1Q1=1,1

取 t2t2​ 嵌入行向量 1,0,2,−11,0,2,−1 乘 WQWQ​ :

  • Q 第 1 维:1×1+0×0+2×1+(−1)×(−1)=41×1+0×0+2×1+(−1)×(−1)=4
  • Q 第 2 维:1×0+0×1+2×1+(−1)×0=21×0+0×1+2×1+(−1)×0=2
    得到 Q2=4,2Q2=4,2

最终投影结果:

Q=K=V=1412Q=K=V=11​42​

第四步:线性投影的核心作用

  • 分工解耦:原始嵌入向量同时包含语义、位置、语法信息,通过三组权重拆分出「查询、匹配、内容」三种独立表征;
  • 维度压缩 / 变换:自由控制 QK 向量维度 dkdk ,多头注意力里会把大 dmodeldmodel 切分成多个小 dkdk ;
  • 赋予模型拟合能力: WQ/WK/WVWQ/WK/WV 是可训练参数,模型会自动学习怎么映射才能精准计算 token 间关联。

核心提示:教学示例 vs 真实模型

这两组矩阵( XX 嵌入、 WW 权重)是人为自定义的教学演示数据,不是模型训练出来的真实参数,目的是降低理解门槛,全程能手动笔算、一眼看懂维度与运算逻辑。

为什么初始化 XX 为 2×42×4 ?

  1. 控制超参: N=2N=2 序列极短,后续算 QKQK 点积得到的相似度矩阵只有 2×22×2 ,数字少、计算量小; dmodel=4dmodel=4 既能完整演示线性变换,又不会出现超大维度导致手算繁琐。
  2. 数值搭配:混合正数、0、负数,模拟真实嵌入特征;数字小,相乘后输出全是整数,避开小数,新手笔算无压力。
  3. 可视化区分:一行 = 一个 token 的完整语义向量,帮你建立「矩阵行代表 token」的思维习惯。

为什么初始化 WW 为 4×24×2 ?

  1. 维度严格匹配:输入维度 4 → 权重矩阵必须是 4 行;输出维度 2 → 权重矩阵必须是 2 列。
  2. 极简系数:大量 0、1、-1 减少乘法计算量,包含 -1 模拟真实权重可正向/反向抑制特征。
  3. 降低负担:真实模型中 WQ,WK,WVWQ,WK,WV 是完全独立的参数,但教学时若三套矩阵不同,需重复算 3 遍。示例强制三者相等,只需算一次线性变换,读者吃透一次矩阵乘法即可。

3. 全局碰撞:QK 点积打分

有了 Q 和 K,接下来就是计算序列中任意两个 Token 之间的相关性

我们将 QQ 和 KK 的转置进行矩阵乘法:Scores = Q @ K.T

这个操作的本质是计算每一对 (Q, K) 的点积。点积越大,说明这两个 Token 在语义或语法上的关联度越高。

维度变化[seq_len, d_k] @ [d_k, seq_len][seq_len, seq_len]

此时,我们得到了一个注意力分数矩阵,它记录了序列内部的全局依赖关系。

完整分步解析:QK 点积打分的矩阵乘法原理

第一步:数学公式与转置的意义

相似度分数矩阵: Score=Q⋅K⊤Score=Q⋅K⊤

K⊤K⊤ 是 K 矩阵的转置:把 K 的行列互换,方便做批量点积。

K=1412⇒K⊤=1142K=11​42​⇒K⊤=14​12​

第二步:矩阵乘法含义(最关键)

QN×dk⋅Kdk×N⊤=ScoreN×NQN×dk​​⋅Kdk​×N⊤​=ScoreN×N​

输出方阵 N×NN×N ,每一个元素 Scorei,jScorei,j​ 含义:第 ii 个 token 的查询向量 QiQi​ ,和第 jj 个 token 的键向量 KjKj​ 的点积相似度。数值越大,代表两个 token 语义关联性越强。

第三步:完整计算打分矩阵

Score=14121142=1×1+4×41×1+4×21×1+2×41×1+2×2=17995Score=11​42​14​12​=1×1+4×41×1+2×4​1×1+4×21×1+2×2​=179​95​

逐个元素解读:

  • Score1,1=17Score1,1=17 : t1t1 和自己的关联分数
  • Score1,2=9Score1,2=9 : t1t1 和 t2t2 的关联分数
  • Score2,1=9Score2,1=9 : t2t2 和 t1t1 的关联分数
  • Score2,2=5Score2,2=5 : t2t2 和自己的关联分数

第四步:单点手动拆解(直观理解点积)

以 Score1,2Score1,2​ 举例: Q1=1,1,K2=4,2Q1​=1,1,K2​=4,2

点积 = 1×4+1×2=61×4+1×2=6

本质就是两个向量对应维度相乘再求和,衡量向量在高维空间的夹角:

  • 向量方向越接近,点积结果越大;
  • 方向垂直,点积 = 0;方向相反,点积为负数。

第五步:QK 点积打分的核心作用

  • 全局两两匹配:不需要循环遍历,矩阵乘法一次性算出序列内所有 token 之间的关联强度;
  • 无位置偏见(自注意力原生特性):不管两个 token 相隔多远,都会计算关联,解决 RNN 长距离依赖缺陷;
  • 输出分数矩阵是后续缩放、softmax、加权 V 的输入基础

4. 防止梯度消失:缩放机制 (Scale)

核心细节:在点积之后,必须除以 dkdk​​ !

为什么必须有「缩放」?原始点积分数存在致命缺陷

1. 问题根源:Q、K 向量维度 dkdk​ 越大,点积越容易爆炸

向量点积公式:两个 dkdk​ 维向量 q=q1,q2...qdk,k=k1,k2...kdkq=q1​,q2​...qdk​​,k=k1​,k2​...kdk​​ ,其点积为 q⋅k=q1k1+q2k2+...+qdkkdkq⋅k=q1​k1​+q2​k2​+...+qdk​​kdk​​ 。

假设向量内每个值都服从标准正态分布 N(0,1)N(0,1) ,单个项 qikiqi​ki​ 的期望为 0,方差为 1。那么总和点积的方差就等于 dkdk​ ,标准差为 dkdk​​ 。

当 dkdk​ 很大时(比如 BERT 单头维度 dk=64dk​=64 , 64=864​=8 ;更大模型 dk=128dk​=128 , 128≈11.3128​≈11.3 ),点积结果会变得极大,数值区间跨度巨大。

2. 数值爆炸带来的训练灾难:Softmax 梯度消失

Softmax 公式: softmax(xi)=exi∑jexjsoftmax(xi​)=∑j​exj​exi​​

指数函数 exex 对输入极其敏感:如果输入分数很大(比如 20、30), e20e20 是天文数字,其他小数值的指数会被完全淹没。一行内只有最大值对应的权重接近 1,其余全部趋近 0。反向传播时梯度几乎为 0,权重无法更新,模型根本训不动。

3. 缩放操作:除以 dkdk​​ ,拉平数值分布

缩放公式: ScaledScore=QK⊤dkScaledScore=dk​​QK⊤​

作用:把点积的标准差重新拉回 1,不管 dkdk​ 是 64/128/256,分数稳定在小范围,指数不会爆炸。

用我们之前的 2-token 例子直观对比

我们示例中 dk=2dk​=2 , 2≈1.41422​≈1.4142 。

原始分数矩阵:

Score=17995Score=179​95​

缩放后:

ScaledScore=17/1.41429/1.41429/1.41425/1.414212.026.366.363.54ScaledScore=17/1.41429/1.4142​9/1.41425/1.4142​12.026.36​6.363.54​

如果不缩放, e17e17 会直接让第二行第一个权重无限趋近 0,模型完全忽略 t1t1​ 对 t2t2​ 的影响。缩放后,指数压力大幅缓解。


5. 概率化:Softmax 归一化 (A)

缩放只是稳定数值,不能直接用来乘 V 做信息融合,因为分数可正可负、无固定区间,无法直观代表"关注度占比";且每行所有 token 关注度没有约束,无法实现"分配有限注意力资源"。Softmax 归一就是解决这两个问题,逐行独立归一(一行对应一个 token ii ,该行所有值是 ii 对全部 token jj 的关注度)。

计算规则(拿示例第一行 12.02,6.3612.02,6.36 演算)
  1. 对每个分数取自然指数 exex ,消除负数,所有值变成正数:
    e12.02≈166067,e6.36≈578e12.02≈166067,e6.36≈578
  2. 求该行所有指数值总和做分母:
    166067+578=166645166067+578=166645
  3. 每个指数值 ÷ 总和,得到归一权重:
    a11=166067/166645≈0.997a11=166067/166645≈0.997
    a12=578/166645≈0.003a12=578/166645≈0.003
Softmax 归一后两大核心特性
  • 所有权重 ∈(0,1)∈(0,1) :可以理解成概率、关注度比例;
  • 单行所有权重之和严格 = 1:对 token t1t1 来说,它分配给所有 token 的注意力总和是 100%,只是权重大小不同。本例中 t1t1 把 99.7% 注意力给到自己,仅 0.3% 留给 t2t2 。
归一矩阵 A 的意义

A=0.9970.0030.0030.997A=0.9970.003​0.0030.997​

Ai,jAi,j​ = token ii 在融合自身新向量时,从 token jj 的 V 向量中提取多少信息。后续输出计算: Output=A⋅VOutput=A⋅V ,用归一后的权重对 V 加权求和,完成全局信息融合。

缩放 vs 归一:职能完全区分(别混淆)

表格

模块 操作 输入 核心目的 有无可训练参数
缩放机制 ÷dk÷dk​​ QK 点积原始分数 约束数值范围,防止 Softmax 指数爆炸、梯度消失 无,纯固定数学运算
Softmax 归一 指数 + 按行求和归一 缩放后的稳定分数 将分数转为总和为 1 的注意力分配权重,用于加权 Value 无,纯固定数学运算
补充高频面试细节
  1. 为什么不用其他归一(比如 Min-Max),非要 Softmax?
    Min-Max 只会把数值压缩到 0~1,不能保证每行和为 1,无法做加权求和;Softmax 天然具备"相对竞争"特性:一个 token 关注度升高,其他自动降低,完美模拟"注意力资源有限"的逻辑。
  2. 缩放只作用在 QK 打分阶段,和 V 无关
    V 全程不参与相似度计算,缩放只针对匹配分数,不改变 Value 携带的语义内容。
  3. 完整串联顺序(不能颠倒)
    QK⊤QK⊤ 原始分数 → 缩放 → Softmax 归一得到权重矩阵 A → A×VA×V 。颠倒顺序(先 Softmax 再缩放)会完全失效:先指数放大数值,再缩放于事无补,梯度消失问题依旧存在。

6. 信息提取:A 加权 V

这是整个流程的"收网"阶段。

我们将注意力权重矩阵 AA 与值矩阵 VV 进行矩阵乘法:Output = A @ V

这个操作的物理意义是:根据注意力权重,对序列中所有 Token 的 Value 进行加权求和。如果某个 Token 的注意力权重极高,它的 Value 信息就会被大量融合到当前位置的输出中。

维度变化[seq_len, seq_len] @ [seq_len, d_v][seq_len, d_v]


7. 完美闭环:注意力输出

经过上述步骤,我们得到了形状为 [seq_len, d_v] 的张量。

在标准的 Transformer 架构中,通常 dv=dmodeldv​=dmodel​ ,因此输出维度完美还原为 [seq_len, d_model]。这个输出不仅包含了 Token 自身的原始语义,还融合了整个序列的上下文信息(比如代词"它"在这里融合了前面提到的"猫"的语义)。

随后,这个输出会被送入残差连接(Add & Norm)和前馈神经网络(FFN),开启下一层的特征提取。


深度进阶:真实模型的权重到底在学什么?

回到我们开头的教学示例,示例中令 WQ=WK=WVWQ​=WK​=WV​ 且使用 1、0、-1 等整数,仅仅是为了手写计算的简化工具,无实际模型意义。在真实的 Transformer 中,情况要复杂且精妙得多:

1. 为什么 WQ,WK,WVWQ​,WK​,WV​ 必须是三套完全独立的参数?

Q、K、V 承担着完全不同的语义任务,需要各自专属的特征映射规则,绝不能共用一套变换逻辑:

  • WQWQ​ (查询视角) :表达"我这个词,想从全句里找哪些相关信息"。例如句子"小猫追蝴蝶"中,token「追」的 Q 向量会被训练成偏向寻找「施动者(小猫)、对象(蝴蝶)」的特征。
  • WKWK​ (被检索视角) :表达"我这个词,能提供什么信息给其他 token 匹配"。例如 token「小猫」的 K 向量,会编码出「动作发起主体」特征,方便其他词的 Q 匹配到它。
  • WVWV​ (内容存储视角) :不参与相似度计算,只负责存储原始内容信息。就算「小猫」和「追」匹配度很高,V 向量只负责承载"小猫"本身的实体语义
相关推荐
Lihua奏3 小时前
# 从 printf 到屏幕:程序和硬件之间发生了什么
架构
shepherd1114 小时前
吞吐量提升 10 倍:高并发大批量数据处理任务的架构演进与性能调优
java·后端·架构
禅思院6 小时前
前端部署“三层漏斗”完全指南:从CI/CD到自动回滚的工程化实战【开题】
前端·架构·前端框架
Patrick_Wilson1 天前
幂等到底是什么?从前端视角讲透 SQL、HTTP 与 POST 接口的幂等设计
前端·后端·架构
禅思院1 天前
Vite vs Webpack 深度对比:从启动原理到生产构建,一篇就够了
前端·架构·前端框架