GPT-2 中的残差权重初始化

GPT-2 中的残差权重初始化

1. 概述

在深度残差网络中,每一层的输出都会被加到其输入上。如果不对这些层的权重进行特殊处理,随着网络深度的增加,残差路径上累加的信号方差可能会不断增大,导致数值不稳定和训练发散。

为了解决这个问题,GPT-2 的作者在论文中提出了一种针对残差连接路径上的层进行特殊缩放的初始化方法

2. 哪些是"残差权重"?

在 GPT-2 的一个 Decoder Block 中,有两个子层的输出会直接被加到残差流(Residual Stream)上。因此,这两层的权重就是我们所说的"残差权重":

  1. 多头自注意力层的输出投影层 : 在 Hugging Face 的实现中,这层通常被称为 c_proj
  2. 前馈网络 (FFN) 的第二个线性层 : 这也是一个投影层,同样被称为 c_proj

3. GPT-2 的初始化策略

GPT-2 的权重初始化分为两个步骤:一个通用的标准初始化,和一个针对上述"残差权重"的特殊缩放。

步骤一:通用的标准初始化

模型中的所有权重(包括嵌入层、QKV 投影层、FFN第一层以及残差层)首先都会从一个均值为 0、标准差为 0.02 的正态分布中进行初始化。

  • 权重 (Weights) : <math xmlns="http://www.w3.org/1998/Math/MathML"> W ∼ N ( 0 , 0.0 2 2 ) W \sim \mathcal{N}(0, 0.02^2) </math>W∼N(0,0.022)
  • 偏置 (Biases) : 所有偏置都初始化为 0。

这是模型参数的基础初始化值。

步骤二:针对残差权重的特殊缩放

在完成通用初始化之后,GPT-2 会专门对"残差权重"进行一次额外的缩放操作。

缩放公式:

根据 GPT-2 论文的描述,这些残差层的权重会被乘以一个缩放因子:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 1 N \frac{1}{\sqrt{N}} </math>N 1

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 是模型中残差层(或说残差连接)的总数量

  • 在一个标准的 GPT-2 模型中,每个 Transformer Block(n_layer)都包含 2 个残差连接(一个在自注意力后,一个在 FFN 后)。
  • 因此,总的残差层数量 <math xmlns="http://www.w3.org/1998/Math/MathML"> N = 2 × n layer N = 2 \times n_{\text{layer}} </math>N=2×nlayer。

例如:

  • 对于 gpt2-base,它有 12 个 Block (n_layer=12),所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> N = 2 × 12 = 24 N = 2 \times 12 = 24 </math>N=2×12=24。缩放因子就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 24 \frac{1}{\sqrt{24}} </math>24 1。
  • 对于 gpt2-large,它有 36 个 Block (n_layer=36),所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> N = 2 × 36 = 72 N = 2 \times 36 = 72 </math>N=2×36=72。缩放因子就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 72 \frac{1}{\sqrt{72}} </math>72 1。

这个操作通常是在代码层面,将这些特定层的权重张量乘以该缩放因子来完成。

4. 为什么要进行特殊缩放?

核心目的:控制残差流中的方差累积。

  • 问题 : 在一个深度网络中,残差流 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 会经过多个 Block 的累加: <math xmlns="http://www.w3.org/1998/Math/MathML"> x final = x initial + output 1 + output 2 + ⋯ + output N x_{\text{final}} = x_{\text{initial}} + \text{output}_1 + \text{output}_2 + \dots + \text{output}_N </math>xfinal=xinitial+output1+output2+⋯+outputN。如果每个 <math xmlns="http://www.w3.org/1998/Math/MathML"> output i \text{output}_i </math>outputi 的方差是 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ 2 \sigma^2 </math>σ2,那么在理想情况下,最终输出的方差会累积到 <math xmlns="http://www.w3.org/1998/Math/MathML"> N × σ 2 N \times \sigma^2 </math>N×σ2。当 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 很大时,方差会爆炸,导致训练不稳定。
  • 解决方案 : 通过将每个残差层的权重乘以 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 N \frac{1}{\sqrt{N}} </math>N 1,其输出的方差大约会被缩放到原来的 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 N \frac{1}{N} </math>N1(因为方差与权重的平方成正比)。
  • 效果 : 这样,当 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 个残差输出累加时,总的方差大约保持在 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ 2 \sigma^2 </math>σ2 的水平( <math xmlns="http://www.w3.org/1998/Math/MathML"> N × σ 2 N = σ 2 N \times \frac{\sigma^2}{N} = \sigma^2 </math>N×Nσ2=σ2),从而保证了无论网络有多深,流经主干道的信息信号强度都能保持稳定。

5. 初始化总结

下表总结了 GPT-2 中不同层的初始化方式:

层 / 参数 标准初始化 特殊缩放 (仅限残差层)
嵌入层 (wte, wpe) <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , 0.0 2 2 ) \mathcal{N}(0, 0.02^2) </math>N(0,0.022) 不适用
注意力 QKV 投影 (c_attn) <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , 0.0 2 2 ) \mathcal{N}(0, 0.02^2) </math>N(0,0.022) 不适用
注意力输出投影 (c_proj) <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , 0.0 2 2 ) \mathcal{N}(0, 0.02^2) </math>N(0,0.022) ,权重乘以 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 2 × n layer \frac{1}{\sqrt{2 \times n_{\text{layer}}}} </math>2×nlayer 1
FFN 第一个线性层 (c_fc) <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , 0.0 2 2 ) \mathcal{N}(0, 0.02^2) </math>N(0,0.022) 不适用
FFN 第二个线性层 (c_proj) <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , 0.0 2 2 ) \mathcal{N}(0, 0.02^2) </math>N(0,0.022) ,权重乘以 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 2 × n layer \frac{1}{\sqrt{2 \times n_{\text{layer}}}} </math>2×nlayer 1
所有偏置 (bias) 初始化为 0 不适用

结论: GPT-2 的残差权重初始化是一种精巧的设计,它通过在标准正态初始化之后,对特定的残差层权重应用一个与网络深度相关的缩放因子,成功地稳定了深度 Transformer 模型的训练过程,是其能够有效扩展到更多层数的关键技术之一。

相关推荐
CODECOLLECT几秒前
技术解析|MDM移动设备管理系统无终身买断制度的底层逻辑
人工智能
北京迅为5 分钟前
《【北京迅为】itop-3568开发板NPU使用手册》- 第 7章 使用RKNN-Toolkit-lite2
linux·人工智能·嵌入式·npu
我是一只puppy11 分钟前
使用AI进行代码审查
javascript·人工智能·git·安全·源代码管理
阿杰学AI12 分钟前
AI核心知识91——大语言模型之 Transformer 架构(简洁且通俗易懂版)
人工智能·深度学习·ai·语言模型·自然语言处理·aigc·transformer
esmap14 分钟前
ESMAP 智慧消防解决方案:以数字孪生技术构建全域感知消防体系,赋能消防安全管理智能化升级
人工智能·物联网·3d·编辑器·智慧城市
LaughingZhu19 分钟前
Product Hunt 每日热榜 | 2026-02-08
大数据·人工智能·经验分享·搜索引擎·产品运营
芷栀夏28 分钟前
CANN ops-math:筑牢 AI 神经网络底层的高性能数学运算算子库核心实现
人工智能·深度学习·神经网络
用户51914958484529 分钟前
CVE-2025-47812:Wing FTP Server 高危RCE漏洞分析与利用
人工智能·aigc
阿里云大数据AI技术34 分钟前
【AAAI2026】阿里云人工智能平台PAI视频编辑算法论文入选
人工智能
玄同76536 分钟前
我的 Trae Skill 实践|使用 UV 工具一键搭建 Python 项目开发环境
开发语言·人工智能·python·langchain·uv·trae·vibe coding