04-残差连接与Pre-LN:让大模型的深度网络成为可能

深度网络的困境

在前面的章节中,我们学习了注意力机制、位置编码和MLP层。现在让我们把它们组合成一个完整的Transformer层:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 步骤1:多头注意力 X 1 = MultiHeadAttention ( X ) 步骤2:MLP前馈网络 X 2 = MLP ( X 1 ) \begin{aligned} &\text{步骤1:多头注意力} \\ &X_1 = \text{MultiHeadAttention}(X) \\ \\ &\text{步骤2:MLP前馈网络} \\ &X_2 = \text{MLP}(X_1) \end{aligned} </math>步骤1:多头注意力X1=MultiHeadAttention(X)步骤2:MLP前馈网络X2=MLP(X1)

问题来了:如果我们要堆叠很多层(比如GPT-3有96层),会发生什么?

梯度消失与梯度爆炸

深度神经网络的训练依赖反向传播:梯度从输出层一层层往回传,更新每一层的参数。

链式法则
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ W 1 = ∂ L ∂ X 96 ⋅ ∂ X 96 ∂ X 95 ⋅ ∂ X 95 ∂ X 94 ⋯ ∂ X 2 ∂ X 1 ⋅ ∂ X 1 ∂ W 1 \frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial X_{96}} \cdot \frac{\partial X_{96}}{\partial X_{95}} \cdot \frac{\partial X_{95}}{\partial X_{94}} \cdots \frac{\partial X_2}{\partial X_1} \cdot \frac{\partial X_1}{\partial W_1} </math>∂W1∂L=∂X96∂L⋅∂X95∂X96⋅∂X94∂X95⋯∂X1∂X2⋅∂W1∂X1

这是一个连乘!

梯度消失

如果每一层的梯度都小于1(比如0.9):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 0. 9 96 ≈ 0.00003 (几乎为0!) 0.9^{96} \approx 0.00003 \quad \text{(几乎为0!)} </math>0.996≈0.00003(几乎为0!)

  • 底层(靠近输入)的梯度变得极小
  • 参数几乎不更新
  • 模型无法有效学习

梯度爆炸

如果每一层的梯度都大于1(比如1.1):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 1. 1 96 ≈ 8533 (爆炸!) 1.1^{96} \approx 8533 \quad \text{(爆炸!)} </math>1.196≈8533(爆炸!)

  • 梯度变得极大
  • 参数更新幅度过大
  • 训练不稳定,模型发散

历史事实

在ResNet(2015)之前,训练超过20层的网络都很困难。直接堆叠更多层,效果反而变差!

问题的本质

信息流的退化

在深度网络中,信息需要经过很多层的变换才能传递:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X → f 1 ( X ) → f 2 ( f 1 ( X ) ) → f 3 ( f 2 ( f 1 ( X ) ) ) → ⋯ X \to f_1(X) \to f_2(f_1(X)) \to f_3(f_2(f_1(X))) \to \cdots </math>X→f1(X)→f2(f1(X))→f3(f2(f1(X)))→⋯

  • 每经过一层,信息都会被"扭曲"和"压缩"
  • 层数越深,原始信息越难保留
  • 梯度也面临同样的问题

直观类比

想象一个传话游戏:

  • 第1个人对第2个人说:"今天天气真好"
  • 第2个人理解后传给第3个人:"今天不错"
  • 第3个人传给第4个人:"挺好"
  • ...
  • 第96个人听到的可能是:"好"(信息几乎丢失!)

这就是为什么需要残差连接!

残差连接(Residual Connection)

核心思想

残差连接的想法非常简单:在变换的同时,保留原始信息的"高速通道"

没有残差连接
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X out = f ( X in ) X_{\text{out}} = f(X_{\text{in}}) </math>Xout=f(Xin)

信息必须经过函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f 的变换。

有残差连接
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X out = X in + f ( X in ) X_{\text{out}} = X_{\text{in}} + f(X_{\text{in}}) </math>Xout=Xin+f(Xin)

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> X in X_{\text{in}} </math>Xin:直接传递的原始信息(恒等映射,identity)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( X in ) f(X_{\text{in}}) </math>f(Xin):学习到的"残差"(residual,即修正/补充)
  • 两者相加:原始信息 + 修正

关键洞察

函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f 不需要学习完整的映射,只需要学习"差异 "或"修正"!

数学原理

1. 梯度流的改善

有残差连接时,反向传播的链式法则变为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ X out ∂ X in = ∂ ∂ X in [ X in + f ( X in ) ] = I + ∂ f ( X in ) ∂ X in \frac{\partial X_{\text{out}}}{\partial X_{\text{in}}} = \frac{\partial}{\partial X_{\text{in}}} \left[ X_{\text{in}} + f(X_{\text{in}}) \right] = I + \frac{\partial f(X_{\text{in}})}{\partial X_{\text{in}}} </math>∂Xin∂Xout=∂Xin∂[Xin+f(Xin)]=I+∂Xin∂f(Xin)

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> I I </math>I 是单位矩阵(恒等映射的梯度)。

关键 :即使 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ f ∂ X in \frac{\partial f}{\partial X_{\text{in}}} </math>∂Xin∂f 很小甚至为0,梯度仍然至少有 <math xmlns="http://www.w3.org/1998/Math/MathML"> I I </math>I(值为1)!
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ X in = ∂ L ∂ X out ⋅ ( I + ∂ f ∂ X in ) \frac{\partial L}{\partial X_{\text{in}}} = \frac{\partial L}{\partial X_{\text{out}}} \cdot \left( I + \frac{\partial f}{\partial X_{\text{in}}} \right) </math>∂Xin∂L=∂Xout∂L⋅(I+∂Xin∂f)

梯度可以直接通过恒等映射传递,不会消失!

2. 多层残差连接的累积效果

假设有 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 层,每层都有残差连接:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X 1 = X 0 + f 1 ( X 0 ) X 2 = X 1 + f 2 ( X 1 ) = X 0 + f 1 ( X 0 ) + f 2 ( X 1 ) X 3 = X 2 + f 3 ( X 2 ) = X 0 + f 1 ( X 0 ) + f 2 ( X 1 ) + f 3 ( X 2 ) ⋮ X n = X 0 + ∑ i = 1 n f i ( X i − 1 ) \begin{aligned} X_1 &= X_0 + f_1(X_0) \\ X_2 &= X_1 + f_2(X_1) = X_0 + f_1(X_0) + f_2(X_1) \\ X_3 &= X_2 + f_3(X_2) = X_0 + f_1(X_0) + f_2(X_1) + f_3(X_2) \\ &\vdots \\ X_n &= X_0 + \sum_{i=1}^{n} f_i(X_{i-1}) \end{aligned} </math>X1X2X3Xn=X0+f1(X0)=X1+f2(X1)=X0+f1(X0)+f2(X1)=X2+f3(X2)=X0+f1(X0)+f2(X1)+f3(X2)⋮=X0+i=1∑nfi(Xi−1)

发现 :最终输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> X n X_n </math>Xn 包含原始输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> X 0 X_0 </math>X0 加上所有层的"修正"累积!

梯度传播
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ X n ∂ X 0 = I + ∂ ∂ X 0 ∑ i = 1 n f i ( X i − 1 ) \frac{\partial X_n}{\partial X_0} = I + \frac{\partial}{\partial X_0} \sum_{i=1}^{n} f_i(X_{i-1}) </math>∂X0∂Xn=I+∂X0∂i=1∑nfi(Xi−1)

  • 始终有恒等项 <math xmlns="http://www.w3.org/1998/Math/MathML"> I I </math>I
  • 梯度可以直接从第 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 层传到第 0 层
  • 不会因为层数增加而消失

3. 直观理解:多条路径

残差连接创造了指数级的路径

对于3层网络:

  • 没有残差:1条路径( <math xmlns="http://www.w3.org/1998/Math/MathML"> f 3 ∘ f 2 ∘ f 1 f_3 \circ f_2 \circ f_1 </math>f3∘f2∘f1)
  • 有残差:8条路径!

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X 3 = X 0 + f 1 + f 2 + f 3 + f 1 ∘ f 2 + f 1 ∘ f 3 + f 2 ∘ f 3 + f 1 ∘ f 2 ∘ f 3 \begin{aligned} X_3 &= X_0 + f_1 + f_2 + f_3 \\ &\quad + f_1 \circ f_2 + f_1 \circ f_3 + f_2 \circ f_3 \\ &\quad + f_1 \circ f_2 \circ f_3 \end{aligned} </math>X3=X0+f1+f2+f3+f1∘f2+f1∘f3+f2∘f3+f1∘f2∘f3

每一层可以选择"使用"或"跳过",形成多条并行路径,梯度可以通过任意路径流动。

Transformer中的残差连接

在Transformer的每一层中,残差连接被应用在两个地方

1. 注意力子层

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X 1 = X + MultiHeadAttention ( X ) X_1 = X + \text{MultiHeadAttention}(X) </math>X1=X+MultiHeadAttention(X)

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X:子层的输入
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> MultiHeadAttention ( X ) \text{MultiHeadAttention}(X) </math>MultiHeadAttention(X):注意力的输出(学习到的修正)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> X 1 X_1 </math>X1:相加后的输出

2. MLP子层

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X 2 = X 1 + MLP ( X 1 ) X_2 = X_1 + \text{MLP}(X_1) </math>X2=X1+MLP(X1)

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> X 1 X_1 </math>X1:MLP子层的输入
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> MLP ( X 1 ) \text{MLP}(X_1) </math>MLP(X1):前馈网络的输出(学习到的修正)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> X 2 X_2 </math>X2:相加后的输出

完整的一层Transformer
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 步骤1:注意力+残差 X 1 = X + MultiHeadAttention ( X ) 步骤2:MLP+残差 X 2 = X 1 + MLP ( X 1 ) \begin{aligned} \text{步骤1:注意力+残差} \quad &X_1 = X + \text{MultiHeadAttention}(X) \\ \text{步骤2:MLP+残差} \quad &X_2 = X_1 + \text{MLP}(X_1) \end{aligned} </math>步骤1:注意力+残差步骤2:MLP+残差X1=X+MultiHeadAttention(X)X2=X1+MLP(X1)

残差连接的效果

实验证据(ResNet论文):

网络深度 无残差连接 有残差连接
18层 ✅ 能训练 ✅ 能训练
34层 ⚠️ 勉强训练 ✅ 能训练
50层 ❌ 难以训练 ✅ 能训练
101层 ❌ 无法训练 ✅ 能训练
152层 ❌ 无法训练 ✅ 能训练

Transformer的应用

  • GPT-3:96层
  • GPT-4:推测120+层
  • PaLM:118层

没有残差连接,这些深度模型不可能训练成功!

LayerNorm:稳定训练的另一块基石

残差连接解决了梯度流的问题,但还有一个问题:不同层、不同维度的激活值范围可能差异很大

为什么需要归一化?

问题示例

假设某一层的输出:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = [ 100 0.1 50 200 0.05 80 150 0.2 60 ] X = \begin{bmatrix} 100 & 0.1 & 50 \\ 200 & 0.05 & 80 \\ 150 & 0.2 & 60 \end{bmatrix} </math>X= 1002001500.10.050.2508060

  • 第1维:范围100-200(很大)
  • 第2维:范围0.05-0.2(很小)
  • 第3维:范围50-80(中等)

导致的问题

  1. 梯度不平衡

    • 大值维度的梯度很大
    • 小值维度的梯度很小
    • 参数更新不均衡
  2. 数值不稳定

    • softmax、sigmoid等函数对大数值敏感
    • 可能出现 <math xmlns="http://www.w3.org/1998/Math/MathML"> e 100 e^{100} </math>e100 导致溢出
  3. 学习效率低

    • 需要仔细调整学习率
    • 训练速度慢

LayerNorm的定义

Layer Normalization对每个样本的所有特征维度进行归一化:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> LayerNorm ( x ) = γ ⋅ x − μ σ 2 + ϵ + β \text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta </math>LayerNorm(x)=γ⋅σ2+ϵ x−μ+β

参数解释

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ R d model x \in \mathbb{R}^{d_{\text{model}}} </math>x∈Rdmodel:输入向量(单个Token的表示)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> μ = 1 d model ∑ i = 1 d model x i \mu = \frac{1}{d_{\text{model}}} \sum_{i=1}^{d_{\text{model}}} x_i </math>μ=dmodel1∑i=1dmodelxi:该向量的均值
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> σ 2 = 1 d model ∑ i = 1 d model ( x i − μ ) 2 \sigma^2 = \frac{1}{d_{\text{model}}} \sum_{i=1}^{d_{\text{model}}} (x_i - \mu)^2 </math>σ2=dmodel1∑i=1dmodel(xi−μ)2:该向量的方差
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ:防止除零的小常数(通常 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 0 − 5 10^{-5} </math>10−5 或 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 0 − 6 10^{-6} </math>10−6)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> γ , β ∈ R d model \gamma, \beta \in \mathbb{R}^{d_{\text{model}}} </math>γ,β∈Rdmodel:可学习的缩放和平移参数

步骤分解
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 步骤1:计算均值 μ = 1 d ∑ i = 1 d x i 步骤2:计算方差 σ 2 = 1 d ∑ i = 1 d ( x i − μ ) 2 步骤3:标准化 x ^ i = x i − μ σ 2 + ϵ 步骤4:缩放和平移 y i = γ i ⋅ x ^ i + β i \begin{aligned} \text{步骤1:计算均值} \quad &\mu = \frac{1}{d} \sum_{i=1}^{d} x_i \\ \\ \text{步骤2:计算方差} \quad &\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2 \\ \\ \text{步骤3:标准化} \quad &\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ \\ \text{步骤4:缩放和平移} \quad &y_i = \gamma_i \cdot \hat{x}_i + \beta_i \end{aligned} </math>步骤1:计算均值步骤2:计算方差步骤3:标准化步骤4:缩放和平移μ=d1i=1∑dxiσ2=d1i=1∑d(xi−μ)2x^i=σ2+ϵ xi−μyi=γi⋅x^i+βi

效果

  • 步骤3后: <math xmlns="http://www.w3.org/1998/Math/MathML"> x ^ \hat{x} </math>x^ 的均值为0,方差为1(强制标准化)
  • 步骤4:通过可学习的 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β,让模型自己决定最优的分布

<math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 是如何学习的?

<math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 是可学习的参数 ,和模型的权重矩阵(如 <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q W_Q </math>WQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W 1 W_1 </math>W1 等)完全一样,通过反向传播和梯度下降进行训练。

1. 初始化

在训练开始前, <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 需要初始化:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> γ = [ 1 , 1 , 1 , ... , 1 ] ∈ R d model (初始化为全1) β = [ 0 , 0 , 0 , ... , 0 ] ∈ R d model (初始化为全0) \begin{aligned} \gamma &= [1, 1, 1, \ldots, 1] \in \mathbb{R}^{d_{\text{model}}} \quad \text{(初始化为全1)} \\ \beta &= [0, 0, 0, \ldots, 0] \in \mathbb{R}^{d_{\text{model}}} \quad \text{(初始化为全0)} \end{aligned} </math>γβ=[1,1,1,...,1]∈Rdmodel(初始化为全1)=[0,0,0,...,0]∈Rdmodel(初始化为全0)

为什么这样初始化?

  • 这样初始状态下: <math xmlns="http://www.w3.org/1998/Math/MathML"> y = 1 ⋅ x ^ + 0 = x ^ y = 1 \cdot \hat{x} + 0 = \hat{x} </math>y=1⋅x^+0=x^
  • 相当于直接使用标准化后的结果(均值0,方差1)
  • 模型可以从这个"中性"状态开始学习最优分布

2. 前向传播

在前向传播中,LayerNorm计算输出:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y = γ ⊙ x ^ + β y = \gamma \odot \hat{x} + \beta </math>y=γ⊙x^+β

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ⊙ \odot </math>⊙ 表示逐元素乘法。

示例 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 4 d_{\text{model}}=4 </math>dmodel=4):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x ^ = [ 0.169 , − 1.183 , − 0.507 , 1.521 ] (标准化后) γ = [ 1.2 , 0.8 , 1.5 , 0.9 ] (训练学到的) β = [ 0.1 , − 0.2 , 0.3 , 0.0 ] (训练学到的) y 1 = 1.2 × 0.169 + 0.1 = 0.303 y 2 = 0.8 × ( − 1.183 ) + ( − 0.2 ) = − 1.146 y 3 = 1.5 × ( − 0.507 ) + 0.3 = − 0.461 y 4 = 0.9 × 1.521 + 0.0 = 1.369 y = [ 0.303 , − 1.146 , − 0.461 , 1.369 ] \begin{aligned} \hat{x} &= [0.169, -1.183, -0.507, 1.521] \quad \text{(标准化后)} \\ \gamma &= [1.2, 0.8, 1.5, 0.9] \quad \text{(训练学到的)} \\ \beta &= [0.1, -0.2, 0.3, 0.0] \quad \text{(训练学到的)} \\ \\ y_1 &= 1.2 \times 0.169 + 0.1 = 0.303 \\ y_2 &= 0.8 \times (-1.183) + (-0.2) = -1.146 \\ y_3 &= 1.5 \times (-0.507) + 0.3 = -0.461 \\ y_4 &= 0.9 \times 1.521 + 0.0 = 1.369 \\ \\ y &= [0.303, -1.146, -0.461, 1.369] \end{aligned} </math>x^γβy1y2y3y4y=[0.169,−1.183,−0.507,1.521](标准化后)=[1.2,0.8,1.5,0.9](训练学到的)=[0.1,−0.2,0.3,0.0](训练学到的)=1.2×0.169+0.1=0.303=0.8×(−1.183)+(−0.2)=−1.146=1.5×(−0.507)+0.3=−0.461=0.9×1.521+0.0=1.369=[0.303,−1.146,−0.461,1.369]

3. 反向传播

在反向传播时,损失函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 的梯度会传到 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ γ i = ∂ L ∂ y i ⋅ ∂ y i ∂ γ i = ∂ L ∂ y i ⋅ x ^ i ∂ L ∂ β i = ∂ L ∂ y i ⋅ ∂ y i ∂ β i = ∂ L ∂ y i ⋅ 1 \begin{aligned} \frac{\partial L}{\partial \gamma_i} &= \frac{\partial L}{\partial y_i} \cdot \frac{\partial y_i}{\partial \gamma_i} = \frac{\partial L}{\partial y_i} \cdot \hat{x}_i \\ \\ \frac{\partial L}{\partial \beta_i} &= \frac{\partial L}{\partial y_i} \cdot \frac{\partial y_i}{\partial \beta_i} = \frac{\partial L}{\partial y_i} \cdot 1 \end{aligned} </math>∂γi∂L∂βi∂L=∂yi∂L⋅∂γi∂yi=∂yi∂L⋅x^i=∂yi∂L⋅∂βi∂yi=∂yi∂L⋅1

直观理解

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> γ i \gamma_i </math>γi 的梯度 = 下游梯度 × 标准化后的值 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ^ i \hat{x}_i </math>x^i
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> β i \beta_i </math>βi 的梯度 = 下游梯度(直接传递)

4. 参数更新

使用优化器(如AdamW)更新参数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> γ ← γ − η ⋅ ∂ L ∂ γ β ← β − η ⋅ ∂ L ∂ β \begin{aligned} \gamma &\leftarrow \gamma - \eta \cdot \frac{\partial L}{\partial \gamma} \\ \beta &\leftarrow \beta - \eta \cdot \frac{\partial L}{\partial \beta} \end{aligned} </math>γβ←γ−η⋅∂γ∂L←β−η⋅∂β∂L

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η 是学习率。

和其他参数完全一样的训练过程!

python 复制代码
# PyTorch中的实现
class LayerNorm(nn.Module):
    def __init__(self, d_model=768):
        super().__init__()
        # 定义可学习参数
        self.gamma = nn.Parameter(torch.ones(d_model))   # 初始化为1
        self.beta = nn.Parameter(torch.zeros(d_model))   # 初始化为0
        self.eps = 1e-6

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        mean = x.mean(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        var = x.var(dim=-1, keepdim=True)    # (batch, seq_len, 1)

        # 标准化
        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        # 缩放和平移(gamma和beta会自动参与梯度更新)
        output = self.gamma * x_norm + self.beta

        return output

# 查看参数
ln = LayerNorm(d_model=768)
print(f"gamma是可学习参数: {ln.gamma.requires_grad}")  # True
print(f"beta是可学习参数: {ln.beta.requires_grad}")    # True
print(f"参数数量: {ln.gamma.numel() + ln.beta.numel()}")  # 768 + 768 = 1536

# 在训练时,optimizer会自动更新这些参数
optimizer = torch.optim.AdamW(ln.parameters(), lr=1e-3)

5. 为什么需要 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β?

标准化强制把数据变成均值0方差1,但这不一定是最优的分布

问题:某些层可能需要不同的均值和方差才能更好地学习。

解决方案 :通过可学习的 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β,让模型自己决定:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ:控制每个维度的"缩放"(方差)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β:控制每个维度的"偏移"(均值)

极端情况 :如果模型学到 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ i = σ i \gamma_i = \sigma_i </math>γi=σi, <math xmlns="http://www.w3.org/1998/Math/MathML"> β i = μ i \beta_i = \mu_i </math>βi=μi(原始的方差和均值),那就等于恢复了标准化之前的分布!
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y i = γ i ⋅ x ^ i + β i = σ i ⋅ x i − μ i σ i + μ i = x i y_i = \gamma_i \cdot \hat{x}_i + \beta_i = \sigma_i \cdot \frac{x_i - \mu_i}{\sigma_i} + \mu_i = x_i </math>yi=γi⋅x^i+βi=σi⋅σixi−μi+μi=xi

这给了模型自由度:可以保留归一化的好处,也可以根据需要调整分布。

6. 参数量占比

每个LayerNorm层的参数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 参数量 = d model × 2 = 768 × 2 = 1 , 536 参数 \text{参数量} = d_{\text{model}} \times 2 = 768 \times 2 = 1{,}536 \text{ 参数} </math>参数量=dmodel×2=768×2=1,536 参数

对比一个MLP层( <math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 768 d_{\text{model}}=768 </math>dmodel=768, <math xmlns="http://www.w3.org/1998/Math/MathML"> d ff = 3072 d_{\text{ff}}=3072 </math>dff=3072):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> MLP参数量 = 768 × 3072 + 3072 × 768 ≈ 4 , 700 , 000 参数 \text{MLP参数量} = 768 \times 3072 + 3072 \times 768 \approx 4{,}700{,}000 \text{ 参数} </math>MLP参数量=768×3072+3072×768≈4,700,000 参数

LayerNorm的参数量不到0.05%,几乎可以忽略不计!但它的作用却至关重要。

具体例子

假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 4 d_{\text{model}} = 4 </math>dmodel=4,某个Token的表示为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x = [ 100 , 0.1 , 50 , 200 ] x = [100, 0.1, 50, 200] </math>x=[100,0.1,50,200]

步骤1:计算均值
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ = 100 + 0.1 + 50 + 200 4 = 87.525 \mu = \frac{100 + 0.1 + 50 + 200}{4} = 87.525 </math>μ=4100+0.1+50+200=87.525

步骤2:计算方差
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> σ 2 = ( 100 − 87.525 ) 2 + ( 0.1 − 87.525 ) 2 + ( 50 − 87.525 ) 2 + ( 200 − 87.525 ) 2 4 = 156.14 + 7644.62 + 1406.14 + 12650.14 4 = 5464.26 \begin{aligned} \sigma^2 &= \frac{(100-87.525)^2 + (0.1-87.525)^2 + (50-87.525)^2 + (200-87.525)^2}{4} \\ &= \frac{156.14 + 7644.62 + 1406.14 + 12650.14}{4} \\ &= 5464.26 \end{aligned} </math>σ2=4(100−87.525)2+(0.1−87.525)2+(50−87.525)2+(200−87.525)2=4156.14+7644.62+1406.14+12650.14=5464.26
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> σ = 5464.26 ≈ 73.92 \sigma = \sqrt{5464.26} \approx 73.92 </math>σ=5464.26 ≈73.92

步骤3:标准化
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x ^ 1 = 100 − 87.525 73.92 ≈ 0.169 x ^ 2 = 0.1 − 87.525 73.92 ≈ − 1.183 x ^ 3 = 50 − 87.525 73.92 ≈ − 0.507 x ^ 4 = 200 − 87.525 73.92 ≈ 1.521 \begin{aligned} \hat{x}_1 &= \frac{100 - 87.525}{73.92} \approx 0.169 \\ \hat{x}_2 &= \frac{0.1 - 87.525}{73.92} \approx -1.183 \\ \hat{x}_3 &= \frac{50 - 87.525}{73.92} \approx -0.507 \\ \hat{x}_4 &= \frac{200 - 87.525}{73.92} \approx 1.521 \end{aligned} </math>x^1x^2x^3x^4=73.92100−87.525≈0.169=73.920.1−87.525≈−1.183=73.9250−87.525≈−0.507=73.92200−87.525≈1.521
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x ^ = [ 0.169 , − 1.183 , − 0.507 , 1.521 ] \hat{x} = [0.169, -1.183, -0.507, 1.521] </math>x^=[0.169,−1.183,−0.507,1.521]

验证:均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> ≈ 0 \approx 0 </math>≈0,方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> ≈ 1 \approx 1 </math>≈1

步骤4:缩放和平移 (假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ = [ 1 , 1 , 1 , 1 ] \gamma=[1,1,1,1] </math>γ=[1,1,1,1], <math xmlns="http://www.w3.org/1998/Math/MathML"> β = [ 0 , 0 , 0 , 0 ] \beta=[0,0,0,0] </math>β=[0,0,0,0])
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y = x ^ = [ 0.169 , − 1.183 , − 0.507 , 1.521 ] y = \hat{x} = [0.169, -1.183, -0.507, 1.521] </math>y=x^=[0.169,−1.183,−0.507,1.521]

对比

维度 原始值 归一化后
1 100 0.169
2 0.1 -1.183
3 50 -0.507
4 200 1.521

所有维度现在都在相近的范围内!

LayerNorm vs BatchNorm:归一化的维度差异

在深度学习中,还有一种常见的归一化:Batch Normalization。它们的核心区别在于归一化的维度不同

直观理解:用矩阵来看

假设我们有一个batch的数据,形状为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( N , d ) (N, d) </math>(N,d):

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N:batch大小(样本数量),比如32
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d:特征维度(每个样本的向量长度),比如768

数据可以表示为一个矩阵:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = [ x 1 , 1 x 1 , 2 ⋯ x 1 , 768 x 2 , 1 x 2 , 2 ⋯ x 2 , 768 ⋮ ⋮ ⋱ ⋮ x 32 , 1 x 32 , 2 ⋯ x 32 , 768 ] ← 样本1的768维特征 ← 样本2的768维特征 ← 样本32的768维特征 X = \begin{bmatrix} x_{1,1} & x_{1,2} & \cdots & x_{1,768} \\ x_{2,1} & x_{2,2} & \cdots & x_{2,768} \\ \vdots & \vdots & \ddots & \vdots \\ x_{32,1} & x_{32,2} & \cdots & x_{32,768} \end{bmatrix} \begin{array}{l} \leftarrow \text{样本1的768维特征} \\ \leftarrow \text{样本2的768维特征} \\ \\ \leftarrow \text{样本32的768维特征} \end{array} </math>X= x1,1x2,1⋮x32,1x1,2x2,2⋮x32,2⋯⋯⋱⋯x1,768x2,768⋮x32,768 ←样本1的768维特征←样本2的768维特征←样本32的768维特征

BatchNorm(纵向归一化)

每一列(同一特征维度的所有样本)计算均值和方差:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ j = 1 N ∑ i = 1 N x i , j (第j维特征在所有样本上的均值) σ j 2 = 1 N ∑ i = 1 N ( x i , j − μ j ) 2 \begin{aligned} \mu_j &= \frac{1}{N} \sum_{i=1}^{N} x_{i,j} \quad \text{(第j维特征在所有样本上的均值)} \\ \sigma_j^2 &= \frac{1}{N} \sum_{i=1}^{N} (x_{i,j} - \mu_j)^2 \end{aligned} </math>μjσj2=N1i=1∑Nxi,j(第j维特征在所有样本上的均值)=N1i=1∑N(xi,j−μj)2

可视化:

makefile 复制代码
BatchNorm: 对每一列归一化
        维度1  维度2  维度3  ... 维度768
样本1    x     x      x    ...  x
样本2    x     x      x    ...  x
样本3    x     x      x    ...  x
...
样本32   x     x      x    ...  x
         ↓     ↓      ↓         ↓
        μ₁    μ₂     μ₃   ...  μ₇₆₈  (对列求均值)

LayerNorm(横向归一化)

每一行(单个样本的所有特征维度)计算均值和方差:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ i = 1 d ∑ j = 1 d x i , j (第i个样本的所有维度的均值) σ i 2 = 1 d ∑ j = 1 d ( x i , j − μ i ) 2 \begin{aligned} \mu_i &= \frac{1}{d} \sum_{j=1}^{d} x_{i,j} \quad \text{(第i个样本的所有维度的均值)} \\ \sigma_i^2 &= \frac{1}{d} \sum_{j=1}^{d} (x_{i,j} - \mu_i)^2 \end{aligned} </math>μiσi2=d1j=1∑dxi,j(第i个样本的所有维度的均值)=d1j=1∑d(xi,j−μi)2

可视化:

makefile 复制代码
LayerNorm: 对每一行归一化
        维度1  维度2  维度3  ... 维度768
样本1    x     x      x    ...  x      → μ₁ (对行求均值)
样本2    x     x      x    ...  x      → μ₂
样本3    x     x      x    ...  x      → μ₃
...
样本32   x     x      x    ...  x      → μ₃₂

具体数值例子

假设有3个样本,每个样本4维特征:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = [ 1 2 3 4 5 6 7 8 2 4 6 8 ] X = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 2 & 4 & 6 & 8 \end{bmatrix} </math>X= 152264376488

BatchNorm计算(对每一列):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 第1维: μ 1 = 1 + 5 + 2 3 = 2.67 标准化: [ 1 , 5 , 2 ] → [ − 0.87 , 1.31 , − 0.44 ] 第2维: μ 2 = 2 + 6 + 4 3 = 4.00 标准化: [ 2 , 6 , 4 ] → [ − 1.00 , 1.00 , 0.00 ] 第3维: μ 3 = 3 + 7 + 6 3 = 5.33 标准化: [ 3 , 7 , 6 ] → [ − 1.07 , 0.93 , 0.13 ] 第4维: μ 4 = 4 + 8 + 8 3 = 6.67 标准化: [ 4 , 8 , 8 ] → [ − 1.15 , 0.58 , 0.58 ] \begin{aligned} \text{第1维:} \quad &\mu_1 = \frac{1+5+2}{3} = 2.67 \\ &\text{标准化:} [1, 5, 2] \to [-0.87, 1.31, -0.44] \\ \\ \text{第2维:} \quad &\mu_2 = \frac{2+6+4}{3} = 4.00 \\ &\text{标准化:} [2, 6, 4] \to [-1.00, 1.00, 0.00] \\ \\ \text{第3维:} \quad &\mu_3 = \frac{3+7+6}{3} = 5.33 \\ &\text{标准化:} [3, 7, 6] \to [-1.07, 0.93, 0.13] \\ \\ \text{第4维:} \quad &\mu_4 = \frac{4+8+8}{3} = 6.67 \\ &\text{标准化:} [4, 8, 8] \to [-1.15, 0.58, 0.58] \end{aligned} </math>第1维:第2维:第3维:第4维:μ1=31+5+2=2.67标准化:[1,5,2]→[−0.87,1.31,−0.44]μ2=32+6+4=4.00标准化:[2,6,4]→[−1.00,1.00,0.00]μ3=33+7+6=5.33标准化:[3,7,6]→[−1.07,0.93,0.13]μ4=34+8+8=6.67标准化:[4,8,8]→[−1.15,0.58,0.58]

结果:每个特征维度在batch中被归一化

LayerNorm计算(对每一行):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 样本1: μ 1 = 1 + 2 + 3 + 4 4 = 2.5 标准化: [ 1 , 2 , 3 , 4 ] → [ − 1.34 , − 0.45 , 0.45 , 1.34 ] 样本2: μ 2 = 5 + 6 + 7 + 8 4 = 6.5 标准化: [ 5 , 6 , 7 , 8 ] → [ − 1.34 , − 0.45 , 0.45 , 1.34 ] 样本3: μ 3 = 2 + 4 + 6 + 8 4 = 5.0 标准化: [ 2 , 4 , 6 , 8 ] → [ − 1.34 , − 0.45 , 0.45 , 1.34 ] \begin{aligned} \text{样本1:} \quad &\mu_1 = \frac{1+2+3+4}{4} = 2.5 \\ &\text{标准化:} [1, 2, 3, 4] \to [-1.34, -0.45, 0.45, 1.34] \\ \\ \text{样本2:} \quad &\mu_2 = \frac{5+6+7+8}{4} = 6.5 \\ &\text{标准化:} [5, 6, 7, 8] \to [-1.34, -0.45, 0.45, 1.34] \\ \\ \text{样本3:} \quad &\mu_3 = \frac{2+4+6+8}{4} = 5.0 \\ &\text{标准化:} [2, 4, 6, 8] \to [-1.34, -0.45, 0.45, 1.34] \end{aligned} </math>样本1:样本2:样本3:μ1=41+2+3+4=2.5标准化:[1,2,3,4]→[−1.34,−0.45,0.45,1.34]μ2=45+6+7+8=6.5标准化:[5,6,7,8]→[−1.34,−0.45,0.45,1.34]μ3=42+4+6+8=5.0标准化:[2,4,6,8]→[−1.34,−0.45,0.45,1.34]

结果:每个样本内部的特征被归一化

关键区别总结

维度 BatchNorm LayerNorm
归一化方向 纵向(跨样本) 横向(跨特征)
均值/方差计算 同一特征在batch中的统计 同一样本的所有特征的统计
依赖关系 依赖batch中的其他样本 只依赖当前样本自己
batch大小影响 很大(小batch效果差) 无影响(每个样本独立)
训练vs推理 不一致(推理用移动平均) 一致(相同计算)
适用场景 CV(图像、batch稳定) NLP(序列、batch不稳定)

为什么Transformer用LayerNorm?

1. 序列长度可变

NLP任务中,不同句子长度差异很大:

arduino 复制代码
样本1: "你好" (2个Token)
样本2: "今天天气真好,我们一起去公园玩吧" (14个Token)

如果用BatchNorm:

  • 需要padding或truncate到相同长度
  • padding的Token会影响统计量(需要mask)
  • 实现复杂,效果不稳定

如果用LayerNorm:

  • 每个样本独立计算,长度无关
  • 无需padding的特殊处理
  • 简单高效

2. Batch统计不稳定

Transformer训练时:

  • batch大小通常较小(2-32,因为序列长)
  • 不同batch的序列长度、内容差异大
  • BatchNorm的统计量方差很大

LayerNorm避免了这个问题:每个样本自己归一化,不受batch影响。

3. 训练与推理一致

BatchNorm的推理问题

训练时:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ train = 1 N ∑ i = 1 N x i (当前batch的均值) \mu_{\text{train}} = \frac{1}{N} \sum_{i=1}^{N} x_i \quad \text{(当前batch的均值)} </math>μtrain=N1i=1∑Nxi(当前batch的均值)

推理时(batch=1):

  • 不能用单个样本的统计(方差为0!)
  • 必须使用训练时积累的移动平均

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ test = moving_average ( μ train ) \mu_{\text{test}} = \text{moving\average}(\mu{\text{train}}) </math>μtest=moving_average(μtrain)

这导致训练和推理行为不一致!

LayerNorm的一致性

训练时和推理时使用相同的公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ = 1 d ∑ i = 1 d x i (当前样本自己的均值) \mu = \frac{1}{d} \sum_{i=1}^{d} x_i \quad \text{(当前样本自己的均值)} </math>μ=d1i=1∑dxi(当前样本自己的均值)

完全一致,没有移动平均的复杂性。

实际应用

领域 常用归一化 原因
图像分类(CNN) BatchNorm 固定大小、batch稳定、通道维度有意义
目标检测 BatchNorm / GroupNorm 固定大小,但小batch时用GroupNorm
语言模型(Transformer) LayerNorm 序列长度可变、batch小
语音识别 LayerNorm 序列长度可变
强化学习 LayerNorm batch概念弱

现代趋势

即使在CV领域,也有向LayerNorm或GroupNorm转变的趋势(如Vision Transformer),因为:

  • 更容易迁移到不同batch大小
  • 训练推理一致
  • 分布式训练更简单(不需要跨GPU同步batch统计)

Post-LN vs Pre-LN:放在哪里更好?

LayerNorm在Transformer中的位置有两种方案,效果差异很大。

Post-LN(原始Transformer)

结构:先做变换,后归一化
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 注意力子层: X 1 = LayerNorm ( X + Attention ( X ) ) MLP子层: X 2 = LayerNorm ( X 1 + MLP ( X 1 ) ) \begin{aligned} \text{注意力子层:} \quad &X_1 = \text{LayerNorm}(X + \text{Attention}(X)) \\ \text{MLP子层:} \quad &X_2 = \text{LayerNorm}(X_1 + \text{MLP}(X_1)) \end{aligned} </math>注意力子层:MLP子层:X1=LayerNorm(X+Attention(X))X2=LayerNorm(X1+MLP(X1))

流程图

css 复制代码
X → [Attention] → [+ (残差)] → [LayerNorm] → X₁
X₁ → [MLP] → [+ (残差)] → [LayerNorm] → X₂

特点

  • ✅ 原始Transformer论文的方案
  • ✅ 理论上更符合ResNet的设计
  • ❌ 训练不稳定,需要warmup
  • ❌ 深层网络(>12层)容易梯度爆炸

问题分析

残差相加后,值的范围可能很大,然后才归一化。在深层网络中,累积效应会导致:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∥ X + f ( X ) ∥ ≫ ∥ X ∥ \|X + f(X)\| \gg \|X\| </math>∥X+f(X)∥≫∥X∥

梯度在反向传播时可能放大,导致训练不稳定。

Pre-LN(现代Transformer)

结构:先归一化,后做变换
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 注意力子层: X 1 = X + Attention ( LayerNorm ( X ) ) MLP子层: X 2 = X 1 + MLP ( LayerNorm ( X 1 ) ) \begin{aligned} \text{注意力子层:} \quad &X_1 = X + \text{Attention}(\text{LayerNorm}(X)) \\ \text{MLP子层:} \quad &X_2 = X_1 + \text{MLP}(\text{LayerNorm}(X_1)) \end{aligned} </math>注意力子层:MLP子层:X1=X+Attention(LayerNorm(X))X2=X1+MLP(LayerNorm(X1))

流程图

css 复制代码
X → [LayerNorm] → [Attention] → [+ (残差)] → X₁
X₁ → [LayerNorm] → [MLP] → [+ (残差)] → X₂

特点

  • ✅ 训练更稳定,不需要warmup
  • ✅ 可以训练更深的网络(100+层)
  • ✅ 梯度更平滑
  • ⚠️ 理论上可能略损失一点性能(但实践中差异很小)

为什么更稳定?

  1. 归一化在变换前

    • 每个子层的输入都经过归一化
    • 激活值范围稳定在合理区间
    • 不会因为深度增加而爆炸
  2. 残差连接更直接

    • 原始信息直接加到子层输出
    • 梯度传播路径更清晰

对比总结

特性 Post-LN Pre-LN
归一化位置 残差相加之后 子层输入之前
训练稳定性 较差,需要warmup 好,不需要warmup
适用深度 适合浅层(<24层) 适合深层(100+层)
学习率敏感度 高,需仔细调整 低,更鲁棒
使用模型 原始Transformer, BERT GPT-2/3/4, LLaMA

现代趋势

  • GPT-2开始采用Pre-LN
  • GPT-3、GPT-4:Pre-LN
  • LLaMA系列:Pre-LN
  • 几乎所有新的大模型:Pre-LN

原因:规模越来越大(从几层到上百层),稳定性比理论上的小幅性能差异更重要。

完整的Transformer层

综合所有组件,一个完整的Transformer层(Pre-LN版本):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 步骤1:LayerNorm + 多头注意力 + 残差 X 1 = X + MultiHeadAttention ( LayerNorm ( X ) ) 步骤2:LayerNorm + MLP + 残差 X 2 = X 1 + MLP ( LayerNorm ( X 1 ) ) \begin{aligned} \text{步骤1:LayerNorm + 多头注意力 + 残差} \\ X_1 &= X + \text{MultiHeadAttention}(\text{LayerNorm}(X)) \\ \\ \text{步骤2:LayerNorm + MLP + 残差} \\ X_2 &= X_1 + \text{MLP}(\text{LayerNorm}(X_1)) \end{aligned} </math>步骤1:LayerNorm + 多头注意力 + 残差X1步骤2:LayerNorm + MLP + 残差X2=X+MultiHeadAttention(LayerNorm(X))=X1+MLP(LayerNorm(X1))

详细展开
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> // 注意力子层 X ^ = LayerNorm ( X ) Q , K , V = X ^ W Q , X ^ W K , X ^ W V Attn = softmax ( Q K T d k ) V X 1 = X + Attn (残差连接) // MLP子层 X ^ 1 = LayerNorm ( X 1 ) h = Activation ( W 1 X ^ 1 + b 1 ) MLP_out = W 2 h + b 2 X 2 = X 1 + MLP_out (残差连接) \begin{aligned} &\text{// 注意力子层} \\ &\hat{X} = \text{LayerNorm}(X) \\ &Q, K, V = \hat{X} W_Q, \hat{X} W_K, \hat{X} W_V \\ &\text{Attn} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V \\ &X_1 = X + \text{Attn} \quad \text{(残差连接)} \\ \\ &\text{// MLP子层} \\ &\hat{X}_1 = \text{LayerNorm}(X_1) \\ &h = \text{Activation}(W_1 \hat{X}_1 + b_1) \\ &\text{MLP\_out} = W_2 h + b_2 \\ &X_2 = X_1 + \text{MLP\_out} \quad \text{(残差连接)} \end{aligned} </math>// 注意力子层X^=LayerNorm(X)Q,K,V=X^WQ,X^WK,X^WVAttn=softmax(dk QKT)VX1=X+Attn(残差连接)// MLP子层X^1=LayerNorm(X1)h=Activation(W1X^1+b1)MLP_out=W2h+b2X2=X1+MLP_out(残差连接)

代码实现

python 复制代码
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """
    Pre-LN版本的Transformer块
    """
    def __init__(self, d_model=768, n_heads=12, d_ff=3072):
        super().__init__()

        # 两个LayerNorm
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # 多头注意力
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            batch_first=True
        )

        # MLP(两层全连接)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        """
        Args:
            x: shape (batch, seq_len, d_model)
        Returns:
            output: shape (batch, seq_len, d_model)
        """
        # 子层1:LayerNorm → Attention → 残差
        x_norm = self.ln1(x)
        attn_out, _ = self.attention(x_norm, x_norm, x_norm)
        x = x + attn_out  # 残差连接

        # 子层2:LayerNorm → MLP → 残差
        x_norm = self.ln2(x)
        mlp_out = self.mlp(x_norm)
        x = x + mlp_out  # 残差连接

        return x

# 测试
model = TransformerBlock(d_model=768, n_heads=12, d_ff=3072)
x = torch.randn(2, 10, 768)  # (batch=2, seq_len=10, d_model=768)
output = model(x)
print(f"输入 shape: {x.shape}")
print(f"输出 shape: {output.shape}")
print(f"维度保持不变: {x.shape == output.shape}")

信息流可视化

让我们追踪一个Token通过Transformer层的完整流程:

ini 复制代码
初始输入 x: [0.5, -0.3, 0.8, ..., 0.2]  (d_model维)
           ↓
      [LayerNorm]  归一化到均值0方差1
           ↓
      [Attention]  与其他Token交互,学习上下文
           ↓
   x + Attention   残差连接,保留原始信息
           ↓
        x₁: [0.6, -0.2, 0.9, ..., 0.3]  (加入了上下文信息)
           ↓
      [LayerNorm]  再次归一化
           ↓
         [MLP]     非线性变换,学习复杂模式
           ↓
     x₁ + MLP     残差连接,保留前面的信息
           ↓
        x₂: [0.7, -0.1, 1.0, ..., 0.4]  (最终输出)

关键点

  1. 每次变换后都有残差连接,保证信息不丢失
  2. 每次变换前都有LayerNorm,保证数值稳定
  3. 最终输出融合了:原始输入 + 上下文信息 + 非线性特征

实验:残差连接的重要性

让我们通过一个简单实验看看残差连接的效果。

实验设置

训练一个10层的小型Transformer:

  • 有残差连接版本
  • 无残差连接版本
python 复制代码
# 无残差版本(会失败)
class BadTransformerBlock(nn.Module):
    def forward(self, x):
        x_norm = self.ln1(x)
        attn_out, _ = self.attention(x_norm, x_norm, x_norm)
        x = attn_out  # 没有残差!

        x_norm = self.ln2(x)
        mlp_out = self.mlp(x_norm)
        x = mlp_out  # 没有残差!

        return x

# 有残差版本(会成功)
class GoodTransformerBlock(nn.Module):
    def forward(self, x):
        x_norm = self.ln1(x)
        attn_out, _ = self.attention(x_norm, x_norm, x_norm)
        x = x + attn_out  # 有残差!

        x_norm = self.ln2(x)
        mlp_out = self.mlp(x_norm)
        x = x + mlp_out  # 有残差!

        return x

实验结果

指标 无残差 有残差
训练loss收敛 ❌ 不收敛 ✅ 正常收敛
梯度范数 爆炸或消失 稳定
最终准确率 接近随机 85%+
训练稳定性 发散 稳定

观察到的现象(无残差版本):

  • 前几层梯度消失(接近0)
  • 后几层梯度爆炸(>1000)
  • Loss曲线剧烈震荡
  • 最终无法学习到有用的表示

结论 :对于10层以上的网络,残差连接是必需的,不是可选的!

小结

  1. 深度网络的困境

    • 梯度消失:连乘导致底层梯度趋近于0
    • 梯度爆炸:连乘导致梯度指数增长
    • 信息退化:深度变换导致原始信息丢失
  2. 残差连接的作用

    • 公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> X out = X in + f ( X in ) X_{\text{out}} = X_{\text{in}} + f(X_{\text{in}}) </math>Xout=Xin+f(Xin)
    • 提供梯度的"高速通道":梯度可以直接传播
    • 保留原始信息:输出始终包含输入
    • 创造多条并行路径:指数级的信息流路径
  3. LayerNorm的作用

    • 归一化每个样本的所有特征维度
    • 稳定激活值范围,防止数值问题
    • 加速训练收敛
    • 公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> LN ( x ) = γ x − μ σ 2 + ϵ + β \text{LN}(x) = \gamma \frac{x-\mu}{\sqrt{\sigma^2+\epsilon}} + \beta </math>LN(x)=γσ2+ϵ x−μ+β
  4. Pre-LN vs Post-LN

    • Post-LN:先变换后归一化,训练不稳定
    • Pre-LN:先归一化后变换,训练稳定
    • 现代大模型(GPT-3/4、LLaMA)都用Pre-LN
    • Pre-LN让100+层的超深网络成为可能
  5. 组合效果

    • 残差连接 + LayerNorm = 稳定的深度训练
    • 没有这两项技术,就没有今天的大模型
    • GPT-3(96层)、PaLM(118层)都依赖这些技术

历史意义

  • ResNet(2015):证明残差连接的有效性
  • LayerNorm(2016):为Transformer提供稳定性
  • Pre-LN(2018-2019):让超深Transformer成为可能
  • 这些看似简单的技术,是大模型革命的基石!
相关推荐
Assby2 小时前
深入理解Java:为什么String类要用final修饰?
后端·面试
Penge6662 小时前
Go 泛型中的 [0]func(T)
后端
Penge6662 小时前
Go-依赖注入
后端
斯瓦辛武2 小时前
webchat中间件的搭建过程
后端
Penge6662 小时前
Go 泛型:一行代码提升依赖注入的类型安全
后端
小熊巨离谱2 小时前
🔥从聊天到干活:三分钟搞懂 LLM、Agent、RAG、Skill
aigc
凌云拓界2 小时前
TypeWell全攻略(四):AI键位分析,让数据开口说话
前端·人工智能·后端·python·ai·交互
kyrie学java2 小时前
SpringBoot搭建项目调试与问题解决
java·spring boot·后端
SimonKing2 小时前
多数据源:CSV、内存对象可以通过SQL查询,甚至联查,你敢信!
java·后端·程序员