05-从隐藏向量到文字:LM Head如何输出"下一个词"?

回顾:大模型的完整流程

在前面的章节中,我们学习了Transformer的各个组件。现在让我们回顾一下完整流程:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 输入: "今天天气" ↓ (Tokenization + Embedding) Token表示: X ∈ R n × d model ↓ (位置编码) 加入位置: X + PE ↓ (多层Transformer) Layer 1: Attention + MLP + Residual + LN Layer 2: Attention + MLP + Residual + LN ⋮ Layer N: Attention + MLP + Residual + LN ↓ 最终隐藏状态: H ∈ R n × d model \begin{aligned} &\text{输入:} \quad \text{"今天天气"} \\ &\quad \downarrow \text{(Tokenization + Embedding)} \\ &\text{Token表示:} \quad X \in \mathbb{R}^{n \times d_{\text{model}}} \\ &\quad \downarrow \text{(位置编码)} \\ &\text{加入位置:} \quad X + \text{PE} \\ &\quad \downarrow \text{(多层Transformer)} \\ &\text{Layer 1:} \quad \text{Attention + MLP + Residual + LN} \\ &\text{Layer 2:} \quad \text{Attention + MLP + Residual + LN} \\ &\quad \vdots \\ &\text{Layer N:} \quad \text{Attention + MLP + Residual + LN} \\ &\quad \downarrow \\ &\text{最终隐藏状态:} \quad H \in \mathbb{R}^{n \times d_{\text{model}}} \end{aligned} </math>输入:"今天天气"↓(Tokenization + Embedding)Token表示:X∈Rn×dmodel↓(位置编码)加入位置:X+PE↓(多层Transformer)Layer 1:Attention + MLP + Residual + LNLayer 2:Attention + MLP + Residual + LN⋮Layer N:Attention + MLP + Residual + LN↓最终隐藏状态:H∈Rn×dmodel

问题来了 : <math xmlns="http://www.w3.org/1998/Math/MathML"> H H </math>H 是一个连续的向量(768维或4096维),但我们需要输出的是具体的文字(如"很好"、"不错")。

如何从连续向量变成离散的词?这就是LM Head的作用!

LM Head:语言模型的"输出层"

LM Head(Language Model Head)是Transformer的最后一层,它的作用非常明确:

将Transformer输出的隐藏向量映射到词表空间,为每个词计算概率,然后选择最可能的下一个词

LM Head的结构

LM Head通常就是一个简单的线性层(不带激活函数):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> logits = H ⋅ W lm + b lm \text{logits} = H \cdot W_{\text{lm}} + b_{\text{lm}} </math>logits=H⋅Wlm+blm

参数解释

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> H ∈ R n × d model H \in \mathbb{R}^{n \times d_{\text{model}}} </math>H∈Rn×dmodel:Transformer最后一层的输出(每个Token的隐藏表示)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> W lm ∈ R d model × V W_{\text{lm}} \in \mathbb{R}^{d_{\text{model}} \times V} </math>Wlm∈Rdmodel×V:LM Head的权重矩阵
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> b lm ∈ R V b_{\text{lm}} \in \mathbb{R}^{V} </math>blm∈RV:偏置向量(很多模型不使用偏置)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V:词表大小(Vocabulary size),如50257(GPT-2)、32000(LLaMA)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> logits ∈ R n × V \text{logits} \in \mathbb{R}^{n \times V} </math>logits∈Rn×V:每个位置对所有词的"得分"(未归一化)

关键点

  • 输入: <math xmlns="http://www.w3.org/1998/Math/MathML"> d model d_{\text{model}} </math>dmodel 维的连续向量(如768维)
  • 输出: <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 维的分数向量(如32000维),每一维对应词表中的一个词

维度变化示例

假设 GPT-2 模型( <math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 768 d_{\text{model}}=768 </math>dmodel=768, <math xmlns="http://www.w3.org/1998/Math/MathML"> V = 50257 V=50257 </math>V=50257):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> H : ( n , 768 ) (最后一层输出) W lm : ( 768 , 50257 ) (LM Head权重) logits : ( n , 50257 ) (每个位置对所有词的分数) \begin{aligned} H &: (n, 768) \quad \text{(最后一层输出)} \\ W_{\text{lm}} &: (768, 50257) \quad \text{(LM Head权重)} \\ \text{logits} &: (n, 50257) \quad \text{(每个位置对所有词的分数)} \end{aligned} </math>HWlmlogits:(n,768)(最后一层输出):(768,50257)(LM Head权重):(n,50257)(每个位置对所有词的分数)

对于最后一个位置(预测下一个词):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h last : ( 768 , ) (最后一个Token的表示) logits last = h last ⋅ W lm : ( 50257 , ) \begin{aligned} h_{\text{last}} &: (768,) \quad \text{(最后一个Token的表示)} \\ \text{logits}{\text{last}} &= h{\text{last}} \cdot W_{\text{lm}} : (50257,) \\ \end{aligned} </math>hlastlogitslast:(768,)(最后一个Token的表示)=hlast⋅Wlm:(50257,)

这个50257维的向量,每一维代表词表中对应词的"得分":
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> logits last = [ s 1 , s 2 , s 3 , ... , s 50257 ] \text{logits}{\text{last}} = [s_1, s_2, s_3, \ldots, s{50257}] </math>logitslast=[s1,s2,s3,...,s50257]

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> s 1 s_1 </math>s1:词表中第1个词(如 <pad>)的得分
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> s 2 s_2 </math>s2:词表中第2个词(如 <unk>)的得分
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> s 100 s_{100} </math>s100:词表中第100个词(如 "the")的得分
  • ...

得分越高,表示这个词越可能是下一个词。

从Logits到概率:Softmax归一化

Logits只是"得分",不是概率(可以是负数、可以很大)。我们需要将它们转换为概率分布
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> P ( w i ) = e s i ∑ j = 1 V e s j = softmax ( logits ) i P(w_i) = \frac{e^{s_i}}{\sum_{j=1}^{V} e^{s_j}} = \text{softmax}(\text{logits})_i </math>P(wi)=∑j=1Vesjesi=softmax(logits)i

性质

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> 0 ≤ P ( w i ) ≤ 1 0 \leq P(w_i) \leq 1 </math>0≤P(wi)≤1(每个概率在0-1之间)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ∑ i = 1 V P ( w i ) = 1 \sum_{i=1}^{V} P(w_i) = 1 </math>∑i=1VP(wi)=1(所有概率加起来等于1)

具体例子

假设最后一个位置的logits(简化为5个词):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> logits last = [ 2.3 , 1.5 , 3.8 , 0.5 , 1.2 ] \text{logits}_{\text{last}} = [2.3, 1.5, 3.8, 0.5, 1.2] </math>logitslast=[2.3,1.5,3.8,0.5,1.2]

对应词表中的5个词:["很", "好", "不错", "真", "差"]

应用Softmax
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> P ( "很" ) = e 2.3 e 2.3 + e 1.5 + e 3.8 + e 0.5 + e 1.2 = 9.97 9.97 + 4.48 + 44.70 + 1.65 + 3.32 = 0.155 P ( "好" ) = e 1.5 64.12 = 0.070 P ( "不错" ) = e 3.8 64.12 = 0.697 (最高!) P ( "真" ) = e 0.5 64.12 = 0.026 P ( "差" ) = e 1.2 64.12 = 0.052 \begin{aligned} P(\text{"很"}) &= \frac{e^{2.3}}{e^{2.3} + e^{1.5} + e^{3.8} + e^{0.5} + e^{1.2}} = \frac{9.97}{9.97 + 4.48 + 44.70 + 1.65 + 3.32} = 0.155 \\ P(\text{"好"}) &= \frac{e^{1.5}}{64.12} = 0.070 \\ P(\text{"不错"}) &= \frac{e^{3.8}}{64.12} = 0.697 \quad \text{(最高!)} \\ P(\text{"真"}) &= \frac{e^{0.5}}{64.12} = 0.026 \\ P(\text{"差"}) &= \frac{e^{1.2}}{64.12} = 0.052 \end{aligned} </math>P("很")P("好")P("不错")P("真")P("差")=e2.3+e1.5+e3.8+e0.5+e1.2e2.3=9.97+4.48+44.70+1.65+3.329.97=0.155=64.12e1.5=0.070=64.12e3.8=0.697(最高!)=64.12e0.5=0.026=64.12e1.2=0.052

概率分布:

Logit 概率
2.3 15.5%
1.5 7.0%
不错 3.8 69.7%
0.5 2.6%
1.2 5.2%

"不错"得分最高,概率最大,很可能被选为下一个词。

采样策略:如何选择下一个词?

有了概率分布后,如何选择下一个词?有多种策略:

1. Greedy Decoding(贪心解码)

最简单的方法:直接选概率最高的词
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> w next = arg ⁡ max ⁡ i P ( w i ) w_{\text{next}} = \arg\max_{i} P(w_i) </math>wnext=argimaxP(wi)

优点

  • 简单、确定性(每次输出相同)
  • 速度快

缺点

  • 输出单调、缺乏多样性
  • 容易陷入重复("我觉得我觉得我觉得...")
  • 可能错过全局最优解

代码

python 复制代码
# logits: (vocab_size,)
probs = torch.softmax(logits, dim=-1)
next_token = torch.argmax(probs)  # 选择概率最大的

2. Random Sampling(随机采样)

按概率分布随机采样
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> w next ∼ P ( w ) = softmax ( logits ) w_{\text{next}} \sim P(w) = \text{softmax}(\text{logits}) </math>wnext∼P(w)=softmax(logits)

概率高的词更可能被选中,但不是绝对的。

优点

  • 输出多样化
  • 可以探索不同的生成路径

缺点

  • 有时会采样到低概率的"坏词"
  • 输出质量不稳定

代码

python 复制代码
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # 按概率采样

3. Temperature Sampling(温度采样)

在softmax之前,用温度参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 缩放logits:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> P ( w i ) = e s i / T ∑ j e s j / T P(w_i) = \frac{e^{s_i / T}}{\sum_{j} e^{s_j / T}} </math>P(wi)=∑jesj/Tesi/T

参数解释

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> T = 1 T = 1 </math>T=1:标准softmax(不改变)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> T < 1 T < 1 </math>T<1(如0.5):"降温",概率分布更陡峭,偏向高概率词
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> T > 1 T > 1 </math>T>1(如1.5):"升温",概率分布更平缓,增加多样性

直观理解

假设原始logits: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 2.3 , 1.5 , 3.8 , 0.5 , 1.2 ] [2.3, 1.5, 3.8, 0.5, 1.2] </math>[2.3,1.5,3.8,0.5,1.2]

低温 <math xmlns="http://www.w3.org/1998/Math/MathML"> T = 0.5 T=0.5 </math>T=0.5
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> scaled_logits = [ 4.6 , 3.0 , 7.6 , 1.0 , 2.4 ] \text{scaled\_logits} = [4.6, 3.0, 7.6, 1.0, 2.4] </math>scaled_logits=[4.6,3.0,7.6,1.0,2.4]

Softmax后:

原始概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> T = 0.5 T=0.5 </math>T=0.5概率
15.5% 3.8%
7.0% 0.8%
不错 69.7% 94.2% ⬆️
2.6% 0.1%
5.2% 0.5%

高概率词("不错")的概率被进一步放大!

高温 <math xmlns="http://www.w3.org/1998/Math/MathML"> T = 1.5 T=1.5 </math>T=1.5
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> scaled_logits = [ 1.53 , 1.0 , 2.53 , 0.33 , 0.8 ] \text{scaled\_logits} = [1.53, 1.0, 2.53, 0.33, 0.8] </math>scaled_logits=[1.53,1.0,2.53,0.33,0.8]

Softmax后:

原始概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> T = 1.5 T=1.5 </math>T=1.5概率
15.5% 21.2% ⬆️
7.0% 12.6% ⬆️
不错 69.7% 58.1% ⬇️
2.6% 6.4% ⬆️
5.2% 10.3% ⬆️

概率分布更均匀,其他词的机会增加!

使用场景

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> T < 1 T < 1 </math>T<1:需要确定性、准确性的任务(如翻译、摘要)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> T = 1 T = 1 </math>T=1:平衡点
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> T > 1 T > 1 </math>T>1:需要创意、多样性的任务(如故事生成、头脑风暴)

代码

python 复制代码
temperature = 0.8
logits_scaled = logits / temperature
probs = torch.softmax(logits_scaled, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

4. Top-K Sampling

只从概率最高的K个词中采样

  1. 对概率排序,保留前K个词
  2. 其余词的概率设为0
  3. 重新归一化
  4. 从这K个词中按概率采样

举例 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> K = 3 K=3 </math>K=3):

原始概率:

概率
不错 69.7%
15.5%
7.0%
5.2%
2.6%

保留Top-3,重新归一化:

新概率
不错 69.7% / (69.7+15.5+7.0) = 75.6%
15.5% / 92.2% = 16.8%
7.0% / 92.2% = 7.6%
0%
0%

优点

  • 过滤掉明显不合适的低概率词
  • 保持一定多样性

缺点

  • K是固定的,不够灵活
  • 有时Top-K之外还有合理的词

代码

python 复制代码
top_k = 50
# 获取top-k的索引和值
top_k_probs, top_k_indices = torch.topk(probs, top_k)
# 重新归一化
top_k_probs = top_k_probs / top_k_probs.sum()
# 从top-k中采样
next_token_idx = torch.multinomial(top_k_probs, num_samples=1)
next_token = top_k_indices[next_token_idx]

5. Top-P Sampling(Nucleus Sampling)

动态选择最小的词集合,使得累积概率达到P

  1. 对概率从高到低排序
  2. 累加概率,直到达到阈值P(如0.9)
  3. 只保留这些词
  4. 重新归一化并采样

举例 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> P = 0.9 P=0.9 </math>P=0.9):

概率 累积概率
不错 69.7% 69.7%
15.5% 85.2%
7.0% 92.2% ✅ 达到90%
5.2% 97.4%
2.6% 100%

保留前3个词(累积概率刚好超过90%):

新概率
不错 69.7% / 92.2% = 75.6%
15.5% / 92.2% = 16.8%
7.0% / 92.2% = 7.6%

优点

  • 自适应:概率分布陡峭时,选择少数词;平缓时,选择更多词
  • 更灵活than Top-K
  • 实践效果好

缺点

  • 计算稍复杂

代码

python 复制代码
top_p = 0.9
# 降序排序
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
# 计算累积概率
cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
# 找到累积概率>top_p的位置
mask = cumsum_probs > top_p
# 保留第一个超过top_p的词(确保至少有一个)
mask[1:] = mask[:-1].clone()
mask[0] = False
# 过滤
sorted_probs[mask] = 0
# 重新归一化
sorted_probs = sorted_probs / sorted_probs.sum()
# 采样
next_token_idx = torch.multinomial(sorted_probs, num_samples=1)
next_token = sorted_indices[next_token_idx]

采样策略对比

策略 多样性 质量稳定性 计算复杂度 适用场景
Greedy 翻译、问答
Random 不推荐
Temperature 可调 通用
Top-K 通用
Top-P 自适应 推荐⭐

实践中的组合

通常会结合多种策略

python 复制代码
# Temperature + Top-P(最常用)
temperature = 0.8
top_p = 0.9

logits_scaled = logits / temperature
probs = torch.softmax(logits_scaled, dim=-1)

# 应用Top-P过滤
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
mask = cumsum_probs > top_p
mask[1:] = mask[:-1].clone()
mask[0] = False
sorted_probs[mask] = 0
sorted_probs = sorted_probs / sorted_probs.sum()

# 采样
next_token = sorted_indices[torch.multinomial(sorted_probs, 1)]

Embedding权重共享(Weight Tying)

LM Head的权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W lm ∈ R d model × V W_{\text{lm}} \in \mathbb{R}^{d_{\text{model}} \times V} </math>Wlm∈Rdmodel×V 非常大!

对于GPT-3( <math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 12288 d_{\text{model}}=12288 </math>dmodel=12288, <math xmlns="http://www.w3.org/1998/Math/MathML"> V = 50257 V=50257 </math>V=50257):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 参数量 = 12288 × 50257 = 617 , 155 , 776 ≈ 617 M \text{参数量} = 12288 \times 50257 = 617{,}155{,}776 \approx 617M </math>参数量=12288×50257=617,155,776≈617M

这占了模型总参数的很大一部分!

Token Embedding:从词到向量的学习过程

在讨论权重共享之前,我们先理解Token Embedding层是如何训练的

什么是Token Embedding?

Token Embedding层是模型的第一层,它的作用是将离散的Token ID转换为连续的向量:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E token ∈ R V × d model E_{\text{token}} \in \mathbb{R}^{V \times d_{\text{model}}} </math>Etoken∈RV×dmodel

工作原理

对于输入的Token ID(如1234),Embedding层就是一个**查表(lookup)**操作:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> token_id = 1234 ⇒ vector = E token [ 1234 , : ] ∈ R d model \text{token\id} = 1234 \quad \Rightarrow \quad \text{vector} = E{\text{token}}[1234, :] \in \mathbb{R}^{d_{\text{model}}} </math>token_id=1234⇒vector=Etoken[1234,:]∈Rdmodel

这个向量就是Token 1234的表示。

举例(简化为5维向量):

假设词表有3个词:

Token ID Token Embedding向量
0 "今天" [0.12, -0.34, 0.56, 0.23, -0.45]
1 "天气" [0.87, 0.21, -0.32, 0.54, 0.11]
2 "很好" [-0.23, 0.67, 0.89, -0.12, 0.34]

输入序列:"今天 天气"(Token IDs: [0, 1])
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E token [ 0 ] = [ 0.12 , − 0.34 , 0.56 , 0.23 , − 0.45 ] ("今天"的向量) E token [ 1 ] = [ 0.87 , 0.21 , − 0.32 , 0.54 , 0.11 ] ("天气"的向量) \begin{aligned} E_{\text{token}}[0] &= [0.12, -0.34, 0.56, 0.23, -0.45] \quad \text{("今天"的向量)} \\ E_{\text{token}}[1] &= [0.87, 0.21, -0.32, 0.54, 0.11] \quad \text{("天气"的向量)} \end{aligned} </math>Etoken[0]Etoken[1]=[0.12,−0.34,0.56,0.23,−0.45]("今天"的向量)=[0.87,0.21,−0.32,0.54,0.11]("天气"的向量)

Token Embedding 需要训练吗?

答案:绝对需要!

Token Embedding矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> E token E_{\text{token}} </math>Etoken 是可学习的参数 ,和权重矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q W_Q </math>WQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W 1 W_1 </math>W1 等完全一样,通过梯度下降训练。

1. 初始化

训练开始前,Embedding矩阵需要随机初始化:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E token [ i , : ] ∼ N ( 0 , σ 2 ) E_{\text{token}}[i, :] \sim \mathcal{N}(0, \sigma^2) </math>Etoken[i,:]∼N(0,σ2)

通常使用:

  • 正态分布初始化 : <math xmlns="http://www.w3.org/1998/Math/MathML"> σ = 0.02 \sigma = 0.02 </math>σ=0.02(GPT系列)
  • 均匀分布初始化 : <math xmlns="http://www.w3.org/1998/Math/MathML"> U ( − 3 / d model , 3 / d model ) \mathcal{U}(-\sqrt{3/d_{\text{model}}}, \sqrt{3/d_{\text{model}}}) </math>U(−3/dmodel ,3/dmodel )

初始状态 :每个词的向量是随机的,完全没有语义!

python 复制代码
import torch
import torch.nn as nn

vocab_size = 50257
d_model = 768

# 创建Embedding层
token_embedding = nn.Embedding(vocab_size, d_model)

# 查看初始化后的值
print("Token 0 的初始embedding:", token_embedding.weight[0][:5])
# 输出:tensor([-0.0134,  0.0089, -0.0156,  0.0201, -0.0178])

print("Token 1 的初始embedding:", token_embedding.weight[1][:5])
# 输出:tensor([ 0.0167, -0.0145,  0.0123, -0.0098,  0.0134])

# 完全是随机值,没有任何语义!

2. 前向传播

在前向传播中,Embedding层将Token IDs转换为向量:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = E token [ token_ids ] X = E_{\text{token}}[\text{token\_ids}] </math>X=Etoken[token_ids]

举例

python 复制代码
# 输入序列:[1234, 5678, 9012]
token_ids = torch.tensor([1234, 5678, 9012])

# 查表得到embedding
X = token_embedding(token_ids)  # shape: (3, 768)

# X[0] = E_token[1234, :]
# X[1] = E_token[5678, :]
# X[2] = E_token[9012, :]

这些向量会通过Transformer层,最终产生输出。

3. 反向传播

当损失函数的梯度反向传播时,会传到Embedding层:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ E token [ i ] = ∂ L ∂ X [ j ] (如果token_ids[j] = i) \frac{\partial L}{\partial E_{\text{token}}[i]} = \frac{\partial L}{\partial X[j]} \quad \text{(如果token\_ids[j] = i)} </math>∂Etoken[i]∂L=∂X[j]∂L(如果token_ids[j] = i)

关键点

  • 只有出现在输入序列中的Token的embedding会收到梯度
  • 没出现的Token的embedding在这个batch中保持不变

举例

假设输入序列是 [1234, 5678, 9012]

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> E token [ 1234 ] E_{\text{token}}[1234] </math>Etoken[1234] 会收到梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ X [ 0 ] \frac{\partial L}{\partial X[0]} </math>∂X[0]∂L
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> E token [ 5678 ] E_{\text{token}}[5678] </math>Etoken[5678] 会收到梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ X [ 1 ] \frac{\partial L}{\partial X[1]} </math>∂X[1]∂L
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> E token [ 9012 ] E_{\text{token}}[9012] </math>Etoken[9012] 会收到梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ X [ 2 ] \frac{\partial L}{\partial X[2]} </math>∂X[2]∂L
  • 其他49999个Token的embedding在这个batch中不更新

4. 参数更新

使用优化器更新Embedding矩阵:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E token [ i ] ← E token [ i ] − η ⋅ ∂ L ∂ E token [ i ] E_{\text{token}}[i] \leftarrow E_{\text{token}}[i] - \eta \cdot \frac{\partial L}{\partial E_{\text{token}}[i]} </math>Etoken[i]←Etoken[i]−η⋅∂Etoken[i]∂L

和其他参数完全一样的训练过程!

5. 训练后的语义

经过大量数据的训练,Embedding矩阵会学到有意义的语义表示:

相似词的向量会接近

python 复制代码
# 训练后(示意)
E_token["国王"] ≈ [0.23, 0.56, -0.12, ..., 0.45]
E_token["女王"] ≈ [0.25, 0.54, -0.10, ..., 0.43]  # 很接近!

E_token["男人"] ≈ [0.67, 0.21, -0.34, ..., 0.12]
E_token["女人"] ≈ [0.69, 0.19, -0.32, ..., 0.10]  # 很接近!

# 著名的关系:king - man + woman ≈ queen

计算余弦相似度
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> similarity ( "国王" , "女王" ) = E token [ "国王" ] ⋅ E token [ "女王" ] ∥ E token [ "国王" ] ∥ ⋅ ∥ E token [ "女王" ] ∥ ≈ 0.85 \text{similarity}(\text{"国王"}, \text{"女王"}) = \frac{E_{\text{token}}[\text{"国王"}] \cdot E_{\text{token}}[\text{"女王"}]}{\|E_{\text{token}}[\text{"国王"}]\| \cdot \|E_{\text{token}}[\text{"女王"}]\|} \approx 0.85 </math>similarity("国王","女王")=∥Etoken["国王"]∥⋅∥Etoken["女王"]∥Etoken["国王"]⋅Etoken["女王"]≈0.85
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> similarity ( "国王" , "苹果" ) ≈ 0.12 (不相关) \text{similarity}(\text{"国王"}, \text{"苹果"}) \approx 0.12 \quad \text{(不相关)} </math>similarity("国王","苹果")≈0.12(不相关)

6. Embedding的参数量

对于GPT-3( <math xmlns="http://www.w3.org/1998/Math/MathML"> V = 50257 V=50257 </math>V=50257, <math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 12288 d_{\text{model}}=12288 </math>dmodel=12288):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Embedding参数量 = V × d model = 50257 × 12288 = 617 , 155 , 776 ≈ 617 M \text{Embedding参数量} = V \times d_{\text{model}} = 50257 \times 12288 = 617{,}155{,}776 \approx 617M </math>Embedding参数量=V×dmodel=50257×12288=617,155,776≈617M

和LM Head的参数量一样大!(如果不共享权重的话)

7. 稀疏更新的效率

由于每个batch只更新出现的Token,Embedding层的更新是稀疏的

  • 总Token数:50257
  • 每个batch出现的Token:~100-1000
  • 更新比例:<2%

这也是为什么Embedding训练需要大量数据------需要让每个Token都有足够的训练机会。

8. 与位置编码的区别

Token Embedding(可学习):

  • 表示"这是什么词"
  • 通过训练学习
  • 每个Token有独立的向量

位置编码(一般是固定的):

  • 表示"词在哪个位置"
  • 可以是固定公式(Sinusoidal)或可学习参数
  • 不同位置有不同的向量

两者相加:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X input = E token [ token_ids ] + PE [ positions ] X_{\text{input}} = E_{\text{token}}[\text{token\_ids}] + \text{PE}[\text{positions}] </math>Xinput=Etoken[token_ids]+PE[positions]

完整的训练流程图

ini 复制代码
训练前:
E_token = 随机初始化

第1个epoch:
输入:"今天天气很好" → [1234, 5678, 9012, 4567]
  ↓
E_token[1234], E_token[5678], E_token[9012], E_token[4567] 被使用
  ↓
通过Transformer → 计算Loss
  ↓
反向传播 → 这4个Token的embedding收到梯度
  ↓
E_token[1234], E_token[5678], E_token[9012], E_token[4567] 被更新
(其他49,999个Token不变)

第2个epoch:
输入:"天气真不错" → [5678, 3456, 7890, 2345]
  ↓
又有4个Token的embedding被更新
...

经过数百万个样本:
  ↓
所有Token都被更新过很多次
  ↓
E_token 学到了丰富的语义表示!

总结:Token Embedding的训练

特性 Token Embedding 权重矩阵 (W₁, W₂)
是否可学习 ✅ 是 ✅ 是
初始化方式 正态分布(σ=0.02) He/Xavier初始化
训练方式 梯度下降(稀疏更新) 梯度下降(稠密更新)
参数量 <math xmlns="http://www.w3.org/1998/Math/MathML"> V × d model V \times d_{\text{model}} </math>V×dmodel <math xmlns="http://www.w3.org/1998/Math/MathML"> d in × d out d_{\text{in}} \times d_{\text{out}} </math>din×dout
语义 训练后学到词义 训练后学到变换规则

关键点

  • Token Embedding 不是预定义的,而是从随机初始化开始训练的
  • 它是模型参数的重要组成部分(占比可达20%+)
  • 训练后会自动学到语义相似性(相似词的向量接近)
  • 每个batch只更新出现的Token(稀疏更新)

什么是权重共享?

现在我们理解了Token Embedding也是训练出来的,让我们看权重共享的概念。

模型的第一层是Token Embedding层
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E token ∈ R V × d model E_{\text{token}} \in \mathbb{R}^{V \times d_{\text{model}}} </math>Etoken∈RV×dmodel

它将词表中的每个词映射到 <math xmlns="http://www.w3.org/1998/Math/MathML"> d model d_{\text{model}} </math>dmodel 维向量。

观察

  • Embedding矩阵: <math xmlns="http://www.w3.org/1998/Math/MathML"> ( V , d model ) (V, d_{\text{model}}) </math>(V,dmodel)
  • LM Head矩阵: <math xmlns="http://www.w3.org/1998/Math/MathML"> ( d model , V ) (d_{\text{model}}, V) </math>(dmodel,V)

两者是转置关系

权重共享(Weight Tying):让LM Head直接使用Embedding矩阵的转置:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> W lm = E token T W_{\text{lm}} = E_{\text{token}}^T </math>Wlm=EtokenT

为什么可以共享?

直观理解

  • Embedding :词 → 向量("猫" → <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.1 , 0.5 , . . . , 0.3 ] [0.1, 0.5, ..., 0.3] </math>[0.1,0.5,...,0.3])
  • LM Head :向量 → 词( <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.1 , 0.5 , . . . , 0.3 ] [0.1, 0.5, ..., 0.3] </math>[0.1,0.5,...,0.3] → "猫")

它们在做相反的事情 !如果一个词的embedding向量是 <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v,那么当隐藏状态接近 <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v 时,应该输出这个词。

数学上
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> logit i = h ⋅ W lm [ : , i ] = h ⋅ E token [ i , : ] T = h ⋅ e i \text{logit}i = h \cdot W{\text{lm}}[:, i] = h \cdot E_{\text{token}}[i, :]^T = h \cdot e_i </math>logiti=h⋅Wlm[:,i]=h⋅Etoken[i,:]T=h⋅ei

即:词 <math xmlns="http://www.w3.org/1998/Math/MathML"> w i w_i </math>wi 的logit等于隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 与该词的embedding <math xmlns="http://www.w3.org/1998/Math/MathML"> e i e_i </math>ei 的点积(相似度)。

越相似,logit越大,该词越可能被选中!

权重共享的优缺点

优点

  1. 大幅减少参数量

    • 不共享:Embedding参数 + LM Head参数
    • 共享:只有Embedding参数
    • 节省: <math xmlns="http://www.w3.org/1998/Math/MathML"> 617 M 617M </math>617M 参数(对于GPT-3)
  2. 理论优雅

    • Embedding和LM Head在语义空间中对称
    • 鼓励模型学习一致的表示
  3. 正则化效果

    • 相当于对两个矩阵施加了约束
    • 可能提高泛化能力

缺点

  1. 灵活性降低

    • Embedding和LM Head被强制对称
    • 可能限制表达能力
  2. 实践中效果不总是最好

    • 小模型上效果好(参数少,需要正则化)
    • 大模型上效果不明显(参数足够,不需要强约束)

现代模型的选择

模型 是否共享 原因
BERT ✅ 共享 小模型(110M-340M),节省参数
GPT-2 ✅ 共享 中等模型(117M-1.5B),节省参数
GPT-3 ❌ 不共享 大模型(175B),参数足够
LLaMA ❌ 不共享 大模型(7B-65B),追求性能
T5 ❌ 不共享 编码器-解码器架构,更复杂

趋势 :随着模型规模增大,越来越多的模型选择不共享,以获得更大的灵活性和表达能力。

代码实现

不共享权重

python 复制代码
class TransformerLM(nn.Module):
    def __init__(self, vocab_size=50257, d_model=768):
        super().__init__()
        # Token Embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Transformer layers
        self.transformer = nn.ModuleList([...])

        # LM Head(独立的权重)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids):
        x = self.token_embedding(input_ids)  # (batch, seq, d_model)

        # 通过Transformer
        for layer in self.transformer:
            x = layer(x)

        # LM Head
        logits = self.lm_head(x)  # (batch, seq, vocab_size)
        return logits

共享权重

python 复制代码
class TransformerLM_Tied(nn.Module):
    def __init__(self, vocab_size=50257, d_model=768):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.ModuleList([...])

        # LM Head使用Embedding的转置
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # 关键:权重绑定
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, input_ids):
        x = self.token_embedding(input_ids)

        for layer in self.transformer:
            x = layer(x)

        # LM Head使用共享的权重
        logits = self.lm_head(x)  # 实际上是 x @ token_embedding.weight.T
        return logits

手动实现共享

python 复制代码
# 更显式的写法
def forward(self, input_ids):
    x = self.token_embedding(input_ids)  # (batch, seq, d_model)

    for layer in self.transformer:
        x = layer(x)

    # 手动使用embedding权重的转置
    logits = torch.matmul(x, self.token_embedding.weight.T)  # (batch, seq, vocab)
    return logits

完整的生成流程

让我们把所有内容串起来,看一个完整的文本生成例子:

输入:"今天天气"

目标:生成下一个词

步骤1:Tokenization

yaml 复制代码
"今天天气" → [1234, 5678, 9012]

步骤2:Embedding + 位置编码

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = E token [ [ 1234 , 5678 , 9012 ] ] + PE X = E_{\text{token}}[[1234, 5678, 9012]] + \text{PE} </math>X=Etoken[[1234,5678,9012]]+PE
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X ∈ R 3 × 768 X \in \mathbb{R}^{3 \times 768} </math>X∈R3×768

步骤3:通过Transformer层

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X 1 = TransformerLayer 1 ( X ) X 2 = TransformerLayer 2 ( X 1 ) ⋮ H = TransformerLayer 12 ( X 11 ) \begin{aligned} X_1 &= \text{TransformerLayer}1(X) \\ X_2 &= \text{TransformerLayer}2(X_1) \\ &\vdots \\ H &= \text{TransformerLayer}{12}(X{11}) \end{aligned} </math>X1X2H=TransformerLayer1(X)=TransformerLayer2(X1)⋮=TransformerLayer12(X11)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> H ∈ R 3 × 768 H \in \mathbb{R}^{3 \times 768} </math>H∈R3×768

步骤4:取最后一个位置

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h last = H [ 2 , : ] ∈ R 768 h_{\text{last}} = H[2, :] \in \mathbb{R}^{768} </math>hlast=H[2,:]∈R768

这个向量包含了"今天天气"后面应该接什么的所有信息。

步骤5:LM Head映射到词表

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> logits = h last ⋅ W lm ∈ R 50257 \text{logits} = h_{\text{last}} \cdot W_{\text{lm}} \in \mathbb{R}^{50257} </math>logits=hlast⋅Wlm∈R50257

假设结果(简化):

python 复制代码
logits = {
    "很": 2.3,
    "好": 1.5,
    "不错": 3.8,
    "真": 0.5,
    "差": 1.2,
    ...
}

步骤6:Softmax归一化

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> probs = softmax ( logits ) \text{probs} = \text{softmax}(\text{logits}) </math>probs=softmax(logits)

python 复制代码
probs = {
    "很": 15.5%,
    "好": 7.0%,
    "不错": 69.7%,  # 最高
    "真": 2.6%,
    "差": 5.2%,
    ...
}

步骤7:采样(Top-P,p=0.9)

python 复制代码
# 累积概率达到90%的词:["不错", "很", "好"]
# 重新归一化后按概率采样
next_token = sample(["不错", "很", "好"], probs=[0.756, 0.168, 0.076])
# 结果:next_token = "不错"

步骤8:输出

arduino 复制代码
"今天天气" + "不错" = "今天天气不错"

继续生成(自回归)

如果要继续生成:

  1. 将"不错"加入输入序列
  2. 重复步骤2-7
  3. 生成下一个词(如",")
  4. 继续迭代...

最终可能生成:

arduino 复制代码
"今天天气不错,适合出去玩。"

LM Head的参数量和计算量

参数量

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> LM Head参数量 = d model × V \text{LM Head参数量} = d_{\text{model}} \times V </math>LM Head参数量=dmodel×V

模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> d model d_{\text{model}} </math>dmodel <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V LM Head参数量 占总参数比例
BERT-Base 768 30,522 23M 21%
GPT-2 768 50,257 39M 26%
GPT-3 12,288 50,257 617M 0.35%
LLaMA-7B 4,096 32,000 131M 1.9%
LLaMA-65B 8,192 32,000 262M 0.4%

观察

  • 小模型:LM Head占比很大(20%+)
  • 大模型:LM Head占比很小(<2%)

原因 :LM Head的参数量与词表大小 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 成正比,与模型深度无关。大模型通过增加层数和宽度来扩大规模,但词表大小基本不变,所以LM Head的占比下降。

计算量

每次生成一个Token的计算量:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 计算量 = d model × V \text{计算量} = d_{\text{model}} \times V </math>计算量=dmodel×V

对于LLaMA-7B( <math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 4096 d_{\text{model}}=4096 </math>dmodel=4096, <math xmlns="http://www.w3.org/1998/Math/MathML"> V = 32000 V=32000 </math>V=32000):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 计算量 = 4096 × 32000 = 131 , 072 , 000 ≈ 131 M FLOPs \text{计算量} = 4096 \times 32000 = 131{,}072{,}000 \approx 131M \text{ FLOPs} </math>计算量=4096×32000=131,072,000≈131M FLOPs

对比

  • 一个Transformer层的MLP: <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 × 4096 × 11008 × 2 ≈ 180 M 2 \times 4096 \times 11008 \times 2 \approx 180M </math>2×4096×11008×2≈180M FLOPs
  • LM Head与一个MLP层的计算量相当

虽然参数占比小,但计算量不可忽略!

小结

  1. LM Head的作用

    • 将Transformer输出的连续向量映射到词表空间
    • 为每个词计算logit(得分)
    • 通过softmax转换为概率分布
  2. 核心公式

    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> logits = H ⋅ W lm ∈ R n × V P ( w ) = softmax ( logits ) \begin{aligned} \text{logits} &= H \cdot W_{\text{lm}} \in \mathbb{R}^{n \times V} \\ P(w) &= \text{softmax}(\text{logits}) \end{aligned} </math>logitsP(w)=H⋅Wlm∈Rn×V=softmax(logits)

  3. 采样策略

    • Greedy:选最大概率(确定性,缺乏多样性)
    • Temperature:调整概率分布的陡峭程度
    • Top-K:只从前K个词中采样
    • Top-P:自适应选择词集合(推荐⭐)
  4. 权重共享

    • 小模型:常用权重共享,节省参数
    • 大模型:常用独立权重,追求性能
    • 公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> W lm = E token T W_{\text{lm}} = E_{\text{token}}^T </math>Wlm=EtokenT
  5. 参数量分析

    • 小模型中占比大(20%+)
    • 大模型中占比小(<2%)
    • 但计算量不可忽略
  6. 完整流程

    • Tokenization → Embedding → Transformer → LM Head → Softmax → Sampling → Token

LM Head看似简单(就是一个线性层),但它是连接模型内部表示和外部文字的关键桥梁,没有它,再强大的Transformer也无法输出一个字!

相关推荐
绝无仅有1 小时前
计算机网络核心面试知识深入解析
后端·面试·架构
树獭叔叔1 小时前
03-大模型的非线性变化:从MLP到MOE,大模型2/3的参数都在这里
后端·aigc·openai
mantch1 小时前
全网最全 Claude Skills 指南:从原理到应用,一篇搞定!
人工智能·aigc·agent
多恩Stone2 小时前
【3D-AICG 系列-15】Trellis 2 的 O-voxel Shape: Flexible Dual Grid 代码与论文对应
人工智能·python·算法·3d·aigc
短剑重铸之日2 小时前
《Seata从入门到实战》第七章:seata总结
java·后端·seata
李云龙炮击平安线程2 小时前
Python中的接口、抽象基类和协议
开发语言·后端·python·面试·跳槽
稻草猫.2 小时前
TCP与UDP:传输层协议深度解析
笔记·后端·网络协议
Moment3 小时前
此 KFC 不是肯德基,Kafka、Flink、ClickHouse 怎么搭、何时省掉 Flink
前端·后端·面试
Charlie_lll3 小时前
力扣解题-438. 找到字符串中所有字母异位词
后端·算法·leetcode