Transformer训练与生成背后的数学基础

Transformer的训练与生成能力,本质是以概率论为核心的序列建模框架 ,结合神经网络通用逼近定理可学习的注意力核函数基于梯度下降的经验风险最小化构建的完整数学体系。以下从核心数学假设、训练目标的支撑公式、核心公式的严谨推导三个维度展开。

一、Transformer训练的核心数学假设

Transformer的所有训练逻辑都建立在4个可证明的数学前提之上,是其能够拟合数据、生成信息的底层支撑:

  1. 自回归序列建模的链式法则假设 :任意长度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T的token序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> X = ( x 1 , x 2 , ... , x T ) X=(x_1,x_2,\dots,x_T) </math>X=(x1,x2,...,xT),其联合概率分布可通过概率论链式法则,拆解为每个位置token基于前文上下文的条件概率乘积,彻底规避RNN类模型的串行计算缺陷,是生成式Transformer的底层数学逻辑。
  2. 神经网络通用逼近定理:包含足够多隐藏单元的前馈网络(Transformer的FFN层)与多头注意力模块,能够以任意精度逼近紧集上的任意连续序列到序列映射函数,证明了Transformer的拟合能力上限。
  3. 自注意力的核平滑建模假设:自注意力本质是可学习的内积核加权求和,能够建模序列中任意两个位置的全局依赖关系,突破了RNN类模型的长距离依赖瓶颈,其数学本质是通过相似度权重实现序列信息的最优聚合。
  4. 经验风险最小化与最大似然估计的一致性假设:Transformer的训练目标等价于在训练集上最大化数据的似然概率,即最小化模型分布与真实数据分布的差异,这是监督训练的核心数学准则。

二、支撑Transformer训练目标的核心数学模型与公式

Transformer的训练目标,本质是找到最优模型参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ ,让模型分布尽可能逼近真实数据的分布,核心由以下公式体系支撑:

1. 序列建模的概率分解(生成逻辑的核心)

对于训练集中任意序列样本 <math xmlns="http://www.w3.org/1998/Math/MathML"> X ( i ) = ( x 1 ( i ) , x 2 ( i ) , ... , x T i ( i ) ) X^{(i)}=(x_1^{(i)},x_2^{(i)},\dots,x_{T_i}^{(i)}) </math>X(i)=(x1(i),x2(i),...,xTi(i)),依据概率论链式法则,其联合概率可拆解为自回归条件概率乘积形式,这也是Transformer生成文本的核心逻辑:

<math xmlns="http://www.w3.org/1998/Math/MathML"> P θ ( X ( i ) ) = ∏ t = 1 T i P θ ( x t ( i ) ∣ x 1 : t − 1 ( i ) ) P_\theta(X^{(i)}) = \prod_{t=1}^{T_i} P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) </math>Pθ(X(i))=∏t=1TiPθ(xt(i)∣x1:t−1(i))

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> P θ ( x t ( i ) ∣ x 1 : t − 1 ( i ) ) P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) </math>Pθ(xt(i)∣x1:t−1(i)):Transformer建模的、给定前 <math xmlns="http://www.w3.org/1998/Math/MathML"> t − 1 t-1 </math>t−1个token时,第 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t个位置token的条件概率;
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> x 1 : t − 1 x_{1:t-1} </math>x1:t−1:前 <math xmlns="http://www.w3.org/1998/Math/MathML"> t − t- </math>t−个token组成的上下文序列, <math xmlns="http://www.w3.org/1998/Math/MathML"> t = t= </math>t=时为空序列。

2. 训练目标:最大似然估计(MLE)

最大似然估计是Transformer监督训练的核心数学准则,目标是找到最优参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ∗ \theta^* </math>θ∗,让模型生成整个训练集样本的概率最大化:

<math xmlns="http://www.w3.org/1998/Math/MathML"> θ ∗ = arg ⁡ max ⁡ θ L ( θ ; D ) \theta^* = \arg\max_\theta \mathcal{L}(\theta; \mathcal{D}) </math>θ∗=argmaxθL(θ;D)

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> D = { X ( 1 ) , X ( 2 ) , ... , X ( N ) } \mathcal{D}=\{X^{(1)},X^{(2)},\dots,X^{(N)}\} </math>D={X(1),X(2),...,X(N)}代表完整训练集,对数似然函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> L ( θ ; D ) \mathcal{L}(\theta; \mathcal{D}) </math>L(θ;D)具体表达式为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> L ( θ ; D ) = ∑ i = 1 N log ⁡ P θ ( X ( i ) ) = ∑ i = 1 N ∑ t = 1 T i log ⁡ P θ ( x t ( i ) ∣ x 1 : t − 1 ( i ) ) \mathcal{L}(\theta; \mathcal{D}) = \sum_{i=1}^N \log P_\theta(X^{(i)}) = \sum_{i=1}^N \sum_{t=1}^{T_i} \log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) </math>L(θ;D)=∑i=1NlogPθ(X(i))=∑i=1N∑t=1TilogPθ(xt(i)∣x1:t−1(i))

取对数是为了避免概率连乘导致的数值下溢,且对数函数为单调递增函数,最大化对数似然与最大化原始似然完全等价。

3. 损失函数:交叉熵损失(与MLE完全等价)

实际训练中,我们通过最小化损失函数实现参数优化,而最小化交叉熵损失等价于最大化对数似然

对于单个token预测任务,设真实标签为one-hot向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> y t y_t </math>yt(仅真实token位置为1,其余位置全为0),模型输出概率分布为 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ t = P θ ( ⋅ ∣ x 1 : t − 1 ) \hat{y}t=P\theta(\cdot \mid x_{1:t-1}) </math>y^t=Pθ(⋅∣x1:t−1),单位置交叉熵损失公式为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> L C E ( y t , y ^ t ) = − ∑ v = 1 V y t ( v ) ⋅ log ⁡ y ^ t ( v ) \mathcal{L}_{CE}(y_t, \hat{y}t) = -\sum{v=1}^V y_t(v) \cdot \log \hat{y}_t(v) </math>LCE(yt,y^t)=−∑v=1Vyt(v)⋅logy^t(v)

其中$$$$为词表大小。由于y_是one-hot向量,公式可简化为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> L C E ( y t , y ^ t ) = − log ⁡ P θ ( x t ∣ x 1 : t − 1 ) \mathcal{L}{CE}(y_t, \hat{y}t) = -\log P\theta(x_t \mid x{1:t-1}) </math>LCE(yt,y^t)=−logPθ(xt∣x1:t−1)

整个训练集的总损失为所有样本所有位置的交叉熵之和:

<math xmlns="http://www.w3.org/1998/Math/MathML"> L t o t a l ( θ ; D ) = ∑ i = 1 N ∑ t = 1 T i L C E ( y t ( i ) , y ^ t ( i ) ) = − L ( θ ; D ) \mathcal{L}{total}(\theta; \mathcal{D}) = \sum{i=1}^N \sum_{t=1}^{T_i} \mathcal{L}_{CE}(y_t^{(i)}, \hat{y}_t^{(i)}) = -\mathcal{L}(\theta; \mathcal{D}) </math>Ltotal(θ;D)=∑i=1N∑t=1TiLCE(yt(i),y^t(i))=−L(θ;D)

至此,训练目标从"最大化对数似然"转化为"最小化交叉熵损失",这是Transformer训练的核心优化目标。

4. 核心组件:缩放点积注意力的数学公式

自注意力是Transformer的核心结构,其前向计算的数学公式为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> Attention ( Q , K , V ) = Softmax ( Q K ⊤ d k ) V \text{Attention}(Q,K,V) = \text{Softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V </math>Attention(Q,K,V)=Softmax(dk QK⊤)V

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> Q = X W Q ,    K = X W K ,    V = X W V Q=XW_Q,\; K=XW_K,\; V=XW_V </math>Q=XWQ,K=XWK,V=XWV:输入序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X经过可学习投影矩阵变换得到的查询、键、值矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q 、 W K 、 W V W_Q、W_K、W_V </math>WQ、WK、WV是模型核心训练参数;
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk :缩放因子,避免内积值过大导致Softmax梯度消失;
  • Softmax:将注意力得分转化为和为1的权重,实现对值矩阵的加权求和。

5. 优化的数学基础:梯度下降与反向传播

Transformer的参数更新基于梯度下降法,核心是通过反向传播的链式法则,计算损失函数对每个参数的梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ θ L t o t a l \nabla_\theta \mathcal{L}_{total} </math>∇θLtotal,并沿梯度反方向更新参数:

<math xmlns="http://www.w3.org/1998/Math/MathML"> θ t + 1 = θ t − η ⋅ ∇ θ L t o t a l ( θ t ) \theta_{t+1} = \theta_t - \eta \cdot \nabla_\theta \mathcal{L}_{total}(\theta_t) </math>θt+1=θt−η⋅∇θLtotal(θt)

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η为学习率,实际训练中通常使用AdamW优化器实现自适应梯度更新。

三、核心公式的详细推导

推导1:交叉熵损失与最大似然估计的等价性推导

这是Transformer训练目标最核心的数学证明,完整推导如下:

前置定义与前提

  1. 训练集 <math xmlns="http://www.w3.org/1998/Math/MathML"> D = { X ( 1 ) , X ( 2 ) , ... , X ( N ) } \mathcal{D}=\{X^{(1)},X^{(2)},\dots,X^{(N)}\} </math>D={X(1),X(2),...,X(N)}为独立同分布(i.i.d.)样本,每个样本服从真实数据分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> P data ( X ) P_{\text{data}}(X) </math>Pdata(X);
  2. 模型分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> P θ ( X ) P_\theta(X) </math>Pθ(X)由Transformer参数化,训练核心目标是缩小 <math xmlns="http://www.w3.org/1998/Math/MathML"> P θ ( X ) P_\theta(X) </math>Pθ(X)与 <math xmlns="http://www.w3.org/1998/Math/MathML"> P data ( X ) P_{\text{data}}(X) </math>Pdata(X)的分布差异;

步骤1:写出似然函数

  1. 对数函数为严格单调递增函数,因此 <math xmlns="http://www.w3.org/1998/Math/MathML"> arg ⁡ max ⁡ θ L ( θ ) = arg ⁡ max ⁡ θ log ⁡ L ( θ ) \arg\max_\theta \mathcal{L}(\theta) = \arg\max_\theta \log \mathcal{L}(\theta) </math>argmaxθL(θ)=argmaxθlogL(θ),最大化似然等价于最大化对数似然。

由于样本独立同分布,模型生成整个训练集的联合概率(似然函数)为各样本概率的乘积:

<math xmlns="http://www.w3.org/1998/Math/MathML"> L ( θ ; D ) = P θ ( D ; θ ) = ∏ i = 1 N P θ ( X ( i ) ; θ ) \mathcal{L}(\theta; \mathcal{D}) = P_\theta(\mathcal{D}; \theta) = \prod_{i=1}^N P_\theta(X^{(i)}; \theta) </math>L(θ;D)=Pθ(D;θ)=∏i=1NPθ(X(i);θ)

步骤2:转化为对数似然函数

对似然函数取自然对数,将乘积转化为求和,避免数值下溢:

<math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ L ( θ ; D ) = log ⁡ ( ∏ i = 1 N P θ ( X ( i ) ; θ ) ) = ∑ i = 1 N log ⁡ P θ ( X ( i ) ; θ ) \log \mathcal{L}(\theta; \mathcal{D}) = \log \left( \prod_{i=1}^N P_\theta(X^{(i)}; \theta) \right) = \sum_{i=1}^N \log P_\theta(X^{(i)}; \theta) </math>logL(θ;D)=log(∏i=1NPθ(X(i);θ))=∑i=1NlogPθ(X(i);θ)

MLE的目标转化为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> θ MLE ∗ = arg ⁡ max ⁡ θ ∑ i = 1 N log ⁡ P θ ( X ( i ) ; θ ) \theta^*{\text{MLE}} = \arg\max\theta \sum_{i=1}^N \log P_\theta(X^{(i)}; \theta) </math>θMLE∗=argmaxθ∑i=1NlogPθ(X(i);θ)

步骤3:代入自回归概率分解

根据概率论链式法则,序列的联合概率可分解为条件概率的乘积,取对数后得到:

<math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ P θ ( X ( i ) ) = log ⁡ ( ∏ t = 1 T i P θ ( x t ( i ) ∣ x 1 : t − 1 ( i ) ) ) = ∑ t = 1 T i log ⁡ P θ ( x t ( i ) ∣ x 1 : t − 1 ( i ) ) \log P_\theta(X^{(i)}) = \log \left( \prod_{t=1}^{T_i} P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) \right) = \sum_{t=1}^{T_i} \log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) </math>logPθ(X(i))=log(∏t=1TiPθ(xt(i)∣x1:t−1(i)))=∑t=1TilogPθ(xt(i)∣x1:t−1(i))

代入对数似然函数,得到:

<math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ L ( θ ; D ) = ∑ i = 1 N ∑ t = 1 T i log ⁡ P θ ( x t ( i ) ∣ x 1 : t − 1 ( i ) ) \log \mathcal{L}(\theta; \mathcal{D}) = \sum_{i=1}^N \sum_{t=1}^{T_i} \log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) </math>logL(θ;D)=∑i=1N∑t=1TilogPθ(xt(i)∣x1:t−1(i))

步骤4:转化为最小化负对数似然

机器学习中习惯最小化损失函数,因此将最大化对数似然转化为最小化负对数似然(NLL):

<math xmlns="http://www.w3.org/1998/Math/MathML"> θ ∗ = arg ⁡ min ⁡ θ − ∑ i = 1 N ∑ t = 1 T i log ⁡ P θ ( x t ( i ) ∣ x 1 : t − 1 ( i ) ) ⏟ 负对数似然NLL \theta^* = \arg\min_\theta \underbrace{ -\sum_{i=1}^N \sum_{t=1}^{T_i} \log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) }_{\text{负对数似然NLL}} </math>θ∗=argminθ负对数似然NLL −i=1∑Nt=1∑TilogPθ(xt(i)∣x1:t−1(i))

步骤5:证明NLL与交叉熵损失完全等价

首先给出离散分布的交叉熵定义:对于真实分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> 和模型分布 和模型分布 </math>和模型分布,交叉熵为

<math xmlns="http://www.w3.org/1998/Math/MathML"> H ( P , Q ) = − ∑ x P ( x ) log ⁡ Q ( x ) H(P,Q) = -\sum_{x} P(x) \log Q(x) </math>H(P,Q)=−∑xP(x)logQ(x)

其物理意义是用模型分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> 编码真实分布 编码真实分布 </math>编码真实分布的样本所需的平均比特数, <math xmlns="http://www.w3.org/1998/Math/MathML"> H ( P , Q H(P,Q </math>H(P,Q越小, <math xmlns="http://www.w3.org/1998/Math/MathML"> 越接近 越接近 </math>越接近。

对于序列中第$$$$个位置的预测任务:

  • 真实分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> P data P_{\text{data}} </math>Pdata是one-hot分布:仅在真实token <math xmlns="http://www.w3.org/1998/Math/MathML"> x t ( i ) x_t^{(i)} </math>xt(i)处取值为1,其余位置为0;
  • 模型分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q为 <math xmlns="http://www.w3.org/1998/Math/MathML"> P θ ( ⋅ ∣ x 1 : t − 1 ( i ) ) P_\theta(\cdot \mid x_{1:t-1}^{(i)}) </math>Pθ(⋅∣x1:t−1(i)),即Transformer输出的token概率分布。

代入交叉熵定义,该位置的交叉熵为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> H ( P data , P θ ) = − ∑ v = 1 V P data ( v ∣ x 1 : t − 1 ( i ) ) ⋅ log ⁡ P θ ( v ∣ x 1 : t − 1 ( i ) ) H(P_{\text{data}}, P_\theta) = -\sum_{v=1}^V P_{\text{data}}(v \mid x_{1:t-1}^{(i)}) \cdot \log P_\theta(v \mid x_{1:t-1}^{(i)}) </math>H(Pdata,Pθ)=−∑v=1VPdata(v∣x1:t−1(i))⋅logPθ(v∣x1:t−1(i))

由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> P d a t a P_{data} </math>Pdata是one-hot分布,仅 <math xmlns="http://www.w3.org/1998/Math/MathML"> v = x t ( i ) v=x_t^{(i)} </math>v=xt(i)时项非零,因此公式简化为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> H ( P data , P θ ) = − log ⁡ P θ ( x t ( i ) ∣ x 1 : t − 1 ( i ) ) H(P_{\text{data}}, P_\theta) = -\log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) </math>H(Pdata,Pθ)=−logPθ(xt(i)∣x1:t−1(i))

这恰好就是该位置的负对数似然。

步骤6:总损失的等价性结论

整个训练集的总交叉熵损失,就是所有样本所有位置的交叉熵之和:

<math xmlns="http://www.w3.org/1998/Math/MathML"> L C E = ∑ i = 1 N ∑ t = 1 T i H ( P data ( i , t ) , P θ ( i , t ) ) = − ∑ i = 1 N ∑ t = 1 T i log ⁡ P θ ( x t ( i ) ∣ x 1 : t − 1 ( i ) ) \mathcal{L}{CE} = \sum{i=1}^N \sum_{t=1}^{T_i} H(P_{\text{data}}^{(i,t)}, P_\theta^{(i,t)}) = -\sum_{i=1}^N \sum_{t=1}^{T_i} \log P_\theta(x_t^{(i)} \mid x_{1:t-1}^{(i)}) </math>LCE=∑i=1N∑t=1TiH(Pdata(i,t),Pθ(i,t))=−∑i=1N∑t=1TilogPθ(xt(i)∣x1:t−1(i))

该式与负对数似然完全一致。

最终结论:Transformer训练中最小化交叉熵损失,等价于在训练集上执行最大似然估计,核心目标是让模型分布无限逼近真实数据的分布,这就是Transformer能够学习数据规律、生成符合语义信息的核心数学支撑。


推导2:缩放点积注意力的反向传播梯度推导

自注意力是Transformer的核心组件,其参数训练依赖反向传播的梯度计算,完整推导如下:

前置定义

简化符号便于推导,设缩放点积注意力输入为 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q , K , V ∈ R T × d Q,K,V \in \mathbb{R}^{T \times d} </math>Q,K,V∈RT×d,统一令 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = d v = d d_k=d_v=d </math>dk=dv=d,前向计算核心流程如下:

  1. 注意力得分矩阵: <math xmlns="http://www.w3.org/1998/Math/MathML"> S = Q K ⊤ d ∈ R T × T S = \frac{Q K^\top}{\sqrt{d}} \in \mathbb{R}^{T \times T} </math>S=d QK⊤∈RT×T
  2. 注意力权重矩阵: <math xmlns="http://www.w3.org/1998/Math/MathML"> A = Softmax ( S ) ∈ R T × T A = \text{Softmax}(S) \in \mathbb{R}^{T \times T} </math>A=Softmax(S)∈RT×T(按行做Softmax归一化,每行元素和为1)

反向传播核心目标:已知损失对输出矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> O O </math>O的梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d O = ∇ O L ∈ R T × d dO = \nabla_O \mathcal{L} \in \mathbb{R}^{T \times d} </math>dO=∇OL∈RT×d,逐层推导损失对查询矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q、键矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K、值矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V的梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d Q 、 d K 、 d V dQ、dK、dV </math>dQ、dK、dV,用于后续参数更新。

  1. 注意力输出矩阵: <math xmlns="http://www.w3.org/1998/Math/MathML"> O = A ⋅ V ∈ R T × d O = A \cdot V \in \mathbb{R}^{T \times d} </math>O=A⋅V∈RT×d

步骤1:计算对 <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v的梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d v dv </math>dv

由 <math xmlns="http://www.w3.org/1998/Math/MathML"> O = A O = A </math>O=A,根据矩阵求导法则,对 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A求偏导:

<math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ V j , k = ∑ i = 1 T ∂ L ∂ O i , k ⋅ ∂ O i , k ∂ V j , k = ∑ i = 1 T d O i , k ⋅ A i , j \frac{\partial \mathcal{L}}{\partial V_{j,k}} = \sum_{i=1}^T \frac{\partial \mathcal{L}}{\partial O_{i,k}} \cdot \frac{\partial O_{i,k}}{\partial V_{j,k}} = \sum_{i=1}^T dO_{i,k} \cdot A_{i,j} </math>∂Vj,k∂L=∑i=1T∂Oi,k∂L⋅∂Vj,k∂Oi,k=∑i=1TdOi,k⋅Ai,j

写成矩阵形式为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> d V = A ⊤ ⋅ d O dV = A^\top \cdot dO </math>dV=A⊤⋅dO

步骤2:计算对注意力权重A的梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d A dA </math>dA

同样由 <math xmlns="http://www.w3.org/1998/Math/MathML"> O = A O = A </math>O=A,对$$$$的元素求偏导:

<math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ A i , j = ∑ k = 1 d ∂ L ∂ O i , k ⋅ ∂ O i , k ∂ A i , j = ∑ k = 1 d d O i , k ⋅ V j , k \frac{\partial \mathcal{L}}{\partial A_{i,j}} = \sum_{k=1}^d \frac{\partial \mathcal{L}}{\partial O_{i,k}} \cdot \frac{\partial O_{i,k}}{\partial A_{i,j}} = \sum_{k=1}^d dO_{i,k} \cdot V_{j,k} </math>∂Ai,j∂L=∑k=1d∂Oi,k∂L⋅∂Ai,j∂Oi,k=∑k=1ddOi,k⋅Vj,k

写成矩阵形式为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> d A = d O ⋅ V ⊤ dA = dO \cdot V^\top </math>dA=dO⋅V⊤

步骤3:计算对得分矩阵S的梯度dS

<math xmlns="http://www.w3.org/1998/Math/MathML"> A = Softmax ( S ) A = \text{Softmax}(S) </math>A=Softmax(S),先推导Softmax函数的单元素偏导规则。对于任意行向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> s ∈ R T s \in \mathbb{R}^T </math>s∈RT,Softmax归一化后输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> a i = e s i ∑ k = 1 T e s k = e s i Z a_i = \frac{e^{s_i}}{\sum_{k=1}^T e^{s_k}} = \frac{e^{s_i}}{Z} </math>ai=∑k=1Teskesi=Zesi,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> Z = ∑ k = 1 T e s k Z=\sum_{k=1}^T e^{s_k} </math>Z=∑k=1Tesk为该行的归一化常数,保证每行权重和为1。

对Softmax输出求导,分两种核心情况推导单元素偏微分:

  1. 同位置求导(i=j) : <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ a i ∂ s i = e s i ⋅ Z − e s i ⋅ e s i Z 2 = a i ( 1 − a i ) \frac{\partial a_i}{\partial s_i} = \frac{e^{s_i} \cdot Z - e^{s_i} \cdot e^{s_i}}{Z^2} = a_i (1 - a_i) </math>∂si∂ai=Z2esi⋅Z−esi⋅esi=ai(1−ai)
  2. 异位置求导(i≠j) : <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ a i ∂ s j = − e s i ⋅ e s j Z 2 = − a i a j \frac{\partial a_i}{\partial s_j} = \frac{ - e^{s_i} \cdot e^{s_j} }{Z^2} = - a_i a_j </math>∂sj∂ai=Z2−esi⋅esj=−aiaj

结合反向传播链式法则,损失对得分矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> S S </math>S的梯度,需要通过损失对注意力权重 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A的梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d A dA </math>dA递推得到,完整矩阵形式的梯度公式为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> d S = A ⊙ ( d A − row_sum ( d A ⊙ A ) ) dS = A \odot \left( dA - \text{row\_sum}(dA \odot A) \right) </math>dS=A⊙(dA−row_sum(dA⊙A))

公式符号说明: <math xmlns="http://www.w3.org/1998/Math/MathML"> ⊙ \odot </math>⊙代表哈达玛积(矩阵对应元素逐点相乘), <math xmlns="http://www.w3.org/1998/Math/MathML"> row_sum \text{row\_sum} </math>row_sum代表对矩阵每一行单独求和,再将结果广播至该行所有列,保持矩阵维度不变,这一步是为了适配Softmax行归一化的梯度特性,避免梯度计算偏差。

步骤4:计算查询矩阵Q与键矩阵K的梯度

得分矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> S S </math>S由查询矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q和键矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K通过缩放内积得到,核心公式为 <math xmlns="http://www.w3.org/1998/Math/MathML"> S = Q K ⊤ d S = \frac{Q K^\top}{\sqrt{d}} </math>S=d QK⊤,基于矩阵微分链式法则,分别推导对 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q和 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K的梯度:

  1. 对查询矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q求梯度:缩放因子保持不变,直接关联得分矩阵梯度与键矩阵转置,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> d Q = 1 d ⋅ d S ⋅ K dQ = \frac{1}{\sqrt{d}} \cdot dS \cdot K </math>dQ=d 1⋅dS⋅K
  2. 对键矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K求梯度:需要先对得分矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> S S </math>S转置,再关联梯度与查询矩阵,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> d K = 1 d ⋅ d S ⊤ ⋅ Q dK = \frac{1}{\sqrt{d}} \cdot dS^\top \cdot Q </math>dK=d 1⋅dS⊤⋅Q

缩放点积注意力反向传播最终梯度汇总

整合所有梯度推导结果,得到自注意力层完整的反向传播梯度公式,所有公式均适配标准LaTeX渲染规则,无复杂嵌套语法,确保正常显示:

<math xmlns="http://www.w3.org/1998/Math/MathML"> d V = A ⊤ ⋅ d O dV = A^\top \cdot dO </math>dV=A⊤⋅dO

<math xmlns="http://www.w3.org/1998/Math/MathML"> d A = d O ⋅ V ⊤ dA = dO \cdot V^\top </math>dA=dO⋅V⊤

<math xmlns="http://www.w3.org/1998/Math/MathML"> d S = A ⊙ ( d A − row_sum ( d A ⊙ A ) ) dS = A \odot \left( dA - \text{row\_sum}(dA \odot A) \right) </math>dS=A⊙(dA−row_sum(dA⊙A))

<math xmlns="http://www.w3.org/1998/Math/MathML"> d Q = 1 d ⋅ d S ⋅ K dQ = \frac{1}{\sqrt{d}} \cdot dS \cdot K </math>dQ=d 1⋅dS⋅K

<math xmlns="http://www.w3.org/1998/Math/MathML"> d K = 1 d ⋅ d S ⊤ ⋅ Q dK = \frac{1}{\sqrt{d}} \cdot dS^\top \cdot Q </math>dK=d 1⋅dS⊤⋅Q

这套完整的梯度推导流程,是Transformer自注意力模块参数更新的核心数学依据,结合反向传播链式法则,可将梯度逐层回传至模型所有可学习参数(投影矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q , W K , W V W_Q, W_K, W_V </math>WQ,WK,WV、前馈网络权重等),最终通过梯度下降完成模型训练,让模型逐步学习序列数据的内在分布规律。

相关推荐
CoovallyAIHub2 小时前
MSSP | 不停机不贴标监测旋转风机叶片:无人机+YOLOv5+DeepSORT,2MW 风机现场测试频率误差<2%
人工智能·架构
marteker2 小时前
Pinterest发布AI广告“增效秘籍”:全自动工具可降低超10%点击成本
人工智能·搜索引擎
喵叔哟2 小时前
29_内容生产质量网关Skill:草稿生成+事实校验+发布前检查
网络·人工智能
不开大的凯20772 小时前
B 端 AI 新图景:阿里悟空的战略价值与爱智能 ATOA 的行业实践
人工智能
lay_liu2 小时前
Spring Boot 自动配置
java·spring boot·后端
程序员cxuan2 小时前
说点掏心窝子的话
后端·程序员
写Cpp的小黑黑2 小时前
WebSocket 连通性测试方法
后端
开心就好20252 小时前
Windows 上传 IPA 到 App Store 的步骤讲解
后端·ios