06-大模型如何"学习":从梯度下降到AdamW优化器

大模型如何"学习":从梯度下降到AdamW优化器

引言:什么是"学习"?

在前面的章节中,我们学习了Transformer的各个组件:注意力机制、MLP、残差连接、LM Head等。但有一个核心问题我们还没有回答:

这些参数( <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q W_Q </math>WQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W 1 W_1 </math>W1、 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β、Embedding等)是怎么得到的?

答案是:通过训练(Training)学习得到的!

类比

想象你在学习投篮:

  1. 初始状态:随便投,命中率很低(随机初始化)
  2. 观察结果:投偏了,偏左边10厘米(计算误差)
  3. 调整姿势:下次往右边调整一点(参数更新)
  4. 重复练习:不断投篮、观察、调整(迭代训练)
  5. 最终:命中率很高(模型收敛)

深度学习的训练过程就是这样:通过不断调整参数,让模型的输出越来越接近正确答案

前向传播(Forward Propagation)

定义

前向传播是指数据从输入层经过各层计算,最终得到输出的过程。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 输入 → 计算 输出 \text{输入} \xrightarrow{\text{计算}} \text{输出} </math>输入计算 输出

Transformer的前向传播

让我们用一个具体例子看完整的前向传播流程:

输入 :"今天天气" 目标:预测下一个词(应该是"很好"、"不错"等)

步骤1:Token Embedding

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 输入Token IDs: [ 1234 , 5678 ] 查表: X = E token [ [ 1234 , 5678 ] ] ∈ R 2 × 768 \begin{aligned} \text{输入Token IDs:} & \quad [1234, 5678] \\ \text{查表:} & \quad X = E_{\text{token}}[[1234, 5678]] \in \mathbb{R}^{2 \times 768} \end{aligned} </math>输入Token IDs:查表:[1234,5678]X=Etoken[[1234,5678]]∈R2×768

步骤2:位置编码

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X pos = X + PE ∈ R 2 × 768 X_{\text{pos}} = X + \text{PE} \in \mathbb{R}^{2 \times 768} </math>Xpos=X+PE∈R2×768

步骤3:通过第1层Transformer

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 注意力: X 1 ′ = X pos + Attention ( LN ( X pos ) ) MLP: X 1 = X 1 ′ + MLP ( LN ( X 1 ′ ) ) \begin{aligned} \text{注意力:} & \quad X_1' = X_{\text{pos}} + \text{Attention}(\text{LN}(X_{\text{pos}})) \\ \text{MLP:} & \quad X_1 = X_1' + \text{MLP}(\text{LN}(X_1')) \end{aligned} </math>注意力:MLP:X1′=Xpos+Attention(LN(Xpos))X1=X1′+MLP(LN(X1′))

步骤4:通过第2-12层

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X 2 , X 3 , ... , X 12 X_2, X_3, \ldots, X_{12} </math>X2,X3,...,X12

步骤5:LM Head

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> logits = X 12 [ − 1 ] ⋅ W lm ∈ R 50257 \text{logits} = X_{12}[-1] \cdot W_{\text{lm}} \in \mathbb{R}^{50257} </math>logits=X12[−1]⋅Wlm∈R50257

取最后一个位置的隐藏状态,映射到词表。

步骤6:Softmax

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> P ( w i ) = e logits i ∑ j e logits j P(w_i) = \frac{e^{\text{logits}_i}}{\sum_j e^{\text{logits}_j}} </math>P(wi)=∑jelogitsjelogitsi

得到每个词的概率。

假设输出

概率
很好 45%
不错 30%
真棒 15%
5%
... ...

关键点

前向传播的特点

  1. 确定性:给定输入和参数,输出是确定的
  2. 单向性:只能从输入到输出,不能反过来
  3. 快速:主要是矩阵乘法,GPU加速很快

用途

  • 训练时:计算输出,然后计算Loss
  • 推理时:直接用来生成文本

损失函数(Loss Function)

前向传播得到了输出(概率分布),但如何知道这个输出好不好?这就需要损失函数。

什么是损失函数?

**损失函数(Loss Function)**衡量模型输出与真实答案之间的差距:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Loss = f ( 模型输出 , 真实答案 ) \text{Loss} = f(\text{模型输出}, \text{真实答案}) </math>Loss=f(模型输出,真实答案)

Loss越小,说明模型越好!

交叉熵损失(Cross-Entropy Loss)

对于语言模型,最常用的是交叉熵损失
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = − log ⁡ P ( 正确答案 ) L = -\log P(\text{正确答案}) </math>L=−logP(正确答案)

直观理解

  • 如果模型给正确答案的概率很高(接近1),则 <math xmlns="http://www.w3.org/1998/Math/MathML"> − log ⁡ P ≈ 0 -\log P \approx 0 </math>−logP≈0(损失小)
  • 如果模型给正确答案的概率很低(接近0),则 <math xmlns="http://www.w3.org/1998/Math/MathML"> − log ⁡ P → ∞ -\log P \to \infty </math>−logP→∞(损失大)

具体例子

输入 :"今天天气" 真实答案:"很好"(Token ID: 9527)

模型输出(前向传播后):

概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> P ( w ) P(w) </math>P(w) <math xmlns="http://www.w3.org/1998/Math/MathML"> − log ⁡ P ( w ) -\log P(w) </math>−logP(w)
很好 0.45 0.80
不错 0.30 1.20
真棒 0.15 1.90
0.05 3.00
... ... ...

计算Loss
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = − log ⁡ P ( "很好" ) = − log ⁡ ( 0.45 ) = 0.80 L = -\log P(\text{"很好"}) = -\log(0.45) = 0.80 </math>L=−logP("很好")=−log(0.45)=0.80

如果模型预测得更准(概率0.9):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = − log ⁡ ( 0.9 ) = 0.11 (损失更小!) L = -\log(0.9) = 0.11 \quad \text{(损失更小!)} </math>L=−log(0.9)=0.11(损失更小!)

如果模型预测得很差(概率0.01):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = − log ⁡ ( 0.01 ) = 4.61 (损失很大!) L = -\log(0.01) = 4.61 \quad \text{(损失很大!)} </math>L=−log(0.01)=4.61(损失很大!)

数学形式

更正式地,对于词表大小为 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V 的分类问题:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = − ∑ i = 1 V y i log ⁡ ( p i ) L = -\sum_{i=1}^{V} y_i \log(p_i) </math>L=−i=1∑Vyilog(pi)

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi:真实标签的one-hot编码(正确答案为1,其余为0)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> p i p_i </math>pi:模型预测的概率

由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y 是one-hot,只有正确答案那一项不为0,所以简化为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = − log ⁡ ( p correct ) L = -\log(p_{\text{correct}}) </math>L=−log(pcorrect)

训练目标

训练的目标就是最小化Loss
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> min ⁡ θ E ( x , y ) ∼ Data [ L ( f θ ( x ) , y ) ] \min_{\theta} \mathbb{E}{(x, y) \sim \text{Data}} [L(f\theta(x), y)] </math>θminE(x,y)∼Data[L(fθ(x),y)]

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ:所有模型参数( <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q W_Q </math>WQ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> W 1 W_1 </math>W1、Embedding等)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( x ) f_\theta(x) </math>fθ(x):模型的输出
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> E \mathbb{E} </math>E:对所有训练数据的期望

通俗理解:找到一组参数,让模型在所有训练样本上的平均Loss最小。

反向传播(Backward Propagation)

知道了Loss,如何调整参数让Loss变小?这就需要反向传播

核心思想

**梯度(Gradient)**告诉我们:参数往哪个方向调整,Loss会下降最快。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 梯度 = ∂ L ∂ θ \text{梯度} = \frac{\partial L}{\partial \theta} </math>梯度=∂θ∂L

直观理解

想象你在山上迷雾中,想下山(最小化Loss):

  • 梯度:指向上坡最陡的方向
  • 负梯度:指向下坡最陡的方向
  • 沿着负梯度走:最快下山

链式法则(Chain Rule)

反向传播的数学基础是链式法则
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ W 1 = ∂ L ∂ y ⋅ ∂ y ∂ h ⋅ ∂ h ∂ W 1 \frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h} \cdot \frac{\partial h}{\partial W_1} </math>∂W1∂L=∂y∂L⋅∂h∂y⋅∂W1∂h

从后往前传

  1. 计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ y \frac{\partial L}{\partial y} </math>∂y∂L(Loss对输出的梯度)
  2. 计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ y ∂ h \frac{\partial y}{\partial h} </math>∂h∂y(输出对中间层的梯度)
  3. 计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ h ∂ W 1 \frac{\partial h}{\partial W_1} </math>∂W1∂h(中间层对参数的梯度)
  4. 相乘得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W 1 \frac{\partial L}{\partial W_1} </math>∂W1∂L(Loss对参数的梯度)

具体例子:简单MLP的反向传播

假设一个简单的MLP:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h = W 1 ⋅ x + b 1 h act = ReLU ( h ) y = W 2 ⋅ h act + b 2 L = ( y − y true ) 2 \begin{aligned} h &= W_1 \cdot x + b_1 \\ h_{\text{act}} &= \text{ReLU}(h) \\ y &= W_2 \cdot h_{\text{act}} + b_2 \\ L &= (y - y_{\text{true}})^2 \end{aligned} </math>hhactyL=W1⋅x+b1=ReLU(h)=W2⋅hact+b2=(y−ytrue)2

各步骤说明:

  1. 第一层线性变换
  2. ReLU激活函数
  3. 第二层线性变换(输出)
  4. 计算损失(均方误差)

前向传播 (假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> x = 2 x=2 </math>x=2, <math xmlns="http://www.w3.org/1998/Math/MathML"> y true = 5 y_{\text{true}}=5 </math>ytrue=5):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h = 3 × 2 + 1 = 7 h act = ReLU ( 7 ) = 7 y = 2 × 7 + 0 = 14 L = ( 14 − 5 ) 2 = 81 \begin{aligned} h &= 3 \times 2 + 1 = 7 \\ h_{\text{act}} &= \text{ReLU}(7) = 7 \\ y &= 2 \times 7 + 0 = 14 \\ L &= (14 - 5)^2 = 81 \end{aligned} </math>hhactyL=3×2+1=7=ReLU(7)=7=2×7+0=14=(14−5)2=81

反向传播
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ y = 2 ( y − y true ) = 2 ( 14 − 5 ) = 18 ∂ L ∂ W 2 = ∂ L ∂ y ⋅ ∂ y ∂ W 2 = 18 × h act = 18 × 7 = 126 ∂ L ∂ b 2 = ∂ L ∂ y ⋅ ∂ y ∂ b 2 = 18 × 1 = 18 ∂ L ∂ h act = ∂ L ∂ y ⋅ ∂ y ∂ h act = 18 × W 2 = 18 × 2 = 36 ∂ L ∂ h = ∂ L ∂ h act ⋅ ∂ h act ∂ h = 36 × 1 = 36 ∂ L ∂ W 1 = ∂ L ∂ h ⋅ ∂ h ∂ W 1 = 36 × x = 36 × 2 = 72 ∂ L ∂ b 1 = ∂ L ∂ h ⋅ ∂ h ∂ b 1 = 36 × 1 = 36 \begin{aligned} \frac{\partial L}{\partial y} &= 2(y - y_{\text{true}}) = 2(14 - 5) = 18 \\ \\ \frac{\partial L}{\partial W_2} &= \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial W_2} = 18 \times h_{\text{act}} = 18 \times 7 = 126 \\ \\ \frac{\partial L}{\partial b_2} &= \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial b_2} = 18 \times 1 = 18 \\ \\ \frac{\partial L}{\partial h_{\text{act}}} &= \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial h_{\text{act}}} = 18 \times W_2 = 18 \times 2 = 36 \\ \\ \frac{\partial L}{\partial h} &= \frac{\partial L}{\partial h_{\text{act}}} \cdot \frac{\partial h_{\text{act}}}{\partial h} = 36 \times 1 = 36 \\ \\ \frac{\partial L}{\partial W_1} &= \frac{\partial L}{\partial h} \cdot \frac{\partial h}{\partial W_1} = 36 \times x = 36 \times 2 = 72 \\ \\ \frac{\partial L}{\partial b_1} &= \frac{\partial L}{\partial h} \cdot \frac{\partial h}{\partial b_1} = 36 \times 1 = 36 \end{aligned} </math>∂y∂L∂W2∂L∂b2∂L∂hact∂L∂h∂L∂W1∂L∂b1∂L=2(y−ytrue)=2(14−5)=18=∂y∂L⋅∂W2∂y=18×hact=18×7=126=∂y∂L⋅∂b2∂y=18×1=18=∂y∂L⋅∂hact∂y=18×W2=18×2=36=∂hact∂L⋅∂h∂hact=36×1=36=∂h∂L⋅∂W1∂h=36×x=36×2=72=∂h∂L⋅∂b1∂h=36×1=36

注意:ReLU的导数在 <math xmlns="http://www.w3.org/1998/Math/MathML"> h > 0 h > 0 </math>h>0 时为1,在 <math xmlns="http://www.w3.org/1998/Math/MathML"> h ≤ 0 h \leq 0 </math>h≤0 时为0。这里 <math xmlns="http://www.w3.org/1998/Math/MathML"> h = 7 > 0 h = 7 > 0 </math>h=7>0,所以导数为1。

得到梯度

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W 1 = 72 \frac{\partial L}{\partial W_1} = 72 </math>∂W1∂L=72
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ b 1 = 36 \frac{\partial L}{\partial b_1} = 36 </math>∂b1∂L=36
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W 2 = 126 \frac{\partial L}{\partial W_2} = 126 </math>∂W2∂L=126
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ b 2 = 18 \frac{\partial L}{\partial b_2} = 18 </math>∂b2∂L=18

自动微分(Automatic Differentiation)

好消息:我们不需要手动计算梯度!

现代深度学习框架(PyTorch、TensorFlow)会自动计算梯度

python 复制代码
import torch

# 定义参数(requires_grad=True表示需要计算梯度)
W1 = torch.tensor([[3.0]], requires_grad=True)
b1 = torch.tensor([1.0], requires_grad=True)
W2 = torch.tensor([[2.0]], requires_grad=True)
b2 = torch.tensor([0.0], requires_grad=True)

# 前向传播
x = torch.tensor([2.0])
y_true = torch.tensor([5.0])

h = W1 @ x + b1
h_act = torch.relu(h)
y = W2 @ h_act + b2
loss = (y - y_true) ** 2

# 反向传播(自动计算梯度)
loss.backward()

# 查看梯度
print(f"∂L/∂W1 = {W1.grad}")  # tensor([[72.]])
print(f"∂L/∂b1 = {b1.grad}")  # tensor([36.])
print(f"∂L/∂W2 = {W2.grad}")  # tensor([[126.]])
print(f"∂L/∂b2 = {b2.grad}")  # tensor([18.])

完全自动!我们只需要定义前向传播,反向传播由框架完成。

Transformer的反向传播

对于一个12层的GPT-2模型:

  1. 前向传播:输入 → Embedding → Layer1 → ... → Layer12 → LM Head → Loss
  2. 反向传播 :Loss → <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W lm \frac{\partial L}{\partial W_{\text{lm}}} </math>∂Wlm∂L → <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W 12 \frac{\partial L}{\partial W_{12}} </math>∂W12∂L → ... → <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W 1 \frac{\partial L}{\partial W_1} </math>∂W1∂L → <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ E token \frac{\partial L}{\partial E_{\text{token}}} </math>∂Etoken∂L

计算图

css 复制代码
前向:
x → E_token → Layer1 → Layer2 → ... → Layer12 → LM_Head → logits → softmax → loss
                ↓         ↓                ↓         ↓         ↓          ↓       ↓
反向:          ←         ←         ...     ←         ←         ←          ←       ←
             ∂L/∂E    ∂L/∂W1             ∂L/∂W12  ∂L/∂Wlm  ∂L/∂logits  ∂L/∂P   ∂L/∂L

每个参数都会得到一个梯度,告诉我们如何更新它。

梯度下降(Gradient Descent)

有了梯度,就可以更新参数了。

基本思想

沿着负梯度方向更新参数
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ new = θ old − η ⋅ ∂ L ∂ θ \theta_{\text{new}} = \theta_{\text{old}} - \eta \cdot \frac{\partial L}{\partial \theta} </math>θnew=θold−η⋅∂θ∂L

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ:参数(如 <math xmlns="http://www.w3.org/1998/Math/MathML"> W 1 W_1 </math>W1、 <math xmlns="http://www.w3.org/1998/Math/MathML"> b 1 b_1 </math>b1 等)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η:学习率(Learning Rate)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ θ \frac{\partial L}{\partial \theta} </math>∂θ∂L:梯度

直观理解

  • 梯度正:参数增大会让Loss增大 → 减小参数
  • 梯度负:参数增大会让Loss减小 → 增大参数

具体例子

继续上面的MLP例子:

当前参数 : <math xmlns="http://www.w3.org/1998/Math/MathML"> W 1 = 3 W_1 = 3 </math>W1=3,梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ W 1 = 72 \frac{\partial L}{\partial W_1} = 72 </math>∂W1∂L=72

选择学习率 : <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 0.01 \eta = 0.01 </math>η=0.01

更新参数
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> W 1 new = W 1 old − η ⋅ ∂ L ∂ W 1 = 3 − 0.01 × 72 = 2.28 W_1^{\text{new}} = W_1^{\text{old}} - \eta \cdot \frac{\partial L}{\partial W_1} = 3 - 0.01 \times 72 = 2.28 </math>W1new=W1old−η⋅∂W1∂L=3−0.01×72=2.28

同理
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> b 1 new = 1 − 0.01 × 36 = 0.64 W 2 new = 2 − 0.01 × 126 = 0.74 b 2 new = 0 − 0.01 × 18 = − 0.18 \begin{aligned} b_1^{\text{new}} &= 1 - 0.01 \times 36 = 0.64 \\ W_2^{\text{new}} &= 2 - 0.01 \times 126 = 0.74 \\ b_2^{\text{new}} &= 0 - 0.01 \times 18 = -0.18 \end{aligned} </math>b1newW2newb2new=1−0.01×36=0.64=2−0.01×126=0.74=0−0.01×18=−0.18

验证:用新参数再做一次前向传播,Loss应该变小了。

三种梯度下降

1. Batch Gradient Descent(批量梯度下降)

每次使用所有训练数据计算梯度
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ = θ − η ⋅ 1 N ∑ i = 1 N ∇ L i \theta = \theta - \eta \cdot \frac{1}{N} \sum_{i=1}^{N} \nabla L_i </math>θ=θ−η⋅N1i=1∑N∇Li

优点

  • 梯度准确,收敛稳定

缺点

  • 计算量大(N可能有百万级)
  • 内存不够(无法一次加载所有数据)
  • 更新慢(每个epoch只更新一次)
2. Stochastic Gradient Descent (SGD,随机梯度下降)

每次只用一个样本计算梯度
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ = θ − η ⋅ ∇ L i \theta = \theta - \eta \cdot \nabla L_i </math>θ=θ−η⋅∇Li

优点

  • 更新快(每个样本都更新)
  • 内存友好

缺点

  • 梯度噪声大,不稳定
  • 可能震荡,难以收敛到最优点
3. Mini-Batch Gradient Descent(小批量梯度下降)

每次用一小批样本(如32、64、128个)计算梯度
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ = θ − η ⋅ 1 B ∑ i = 1 B ∇ L i \theta = \theta - \eta \cdot \frac{1}{B} \sum_{i=1}^{B} \nabla L_i </math>θ=θ−η⋅B1i=1∑B∇Li

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> B B </math>B 是batch size(批量大小)。

优点

  • 平衡了Batch GD和SGD的优缺点
  • 梯度相对稳定
  • 可以利用GPU并行计算

缺点

  • 需要调整batch size

这是现代深度学习的标准做法!

训练循环

完整的训练过程:

python 复制代码
for epoch in range(num_epochs):  # 遍历多个epoch
    for batch in dataloader:     # 遍历所有batch
        # 1. 前向传播
        outputs = model(batch['input'])
        loss = loss_fn(outputs, batch['target'])

        # 2. 反向传播
        loss.backward()  # 计算梯度

        # 3. 参数更新
        optimizer.step()  # 更新参数

        # 4. 梯度清零(为下一个batch准备)
        optimizer.zero_grad()

学习率(Learning Rate)

学习率 <math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η 是梯度下降中最重要的超参数。

学习率的影响

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ new = θ old − η ⋅ ∂ L ∂ θ \theta_{\text{new}} = \theta_{\text{old}} - \eta \cdot \frac{\partial L}{\partial \theta} </math>θnew=θold−η⋅∂θ∂L

学习率太大(如 <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 1.0 \eta=1.0 </math>η=1.0)
  • 参数更新幅度太大
  • 可能"跳过"最优点
  • Loss震荡或发散
javascript 复制代码
Loss
  ^
  |     /\    /\
  |    /  \  /  \
  |   /    \/    \
  |  /
  +-----------------> Iteration
学习率太小(如 <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 0.0001 \eta=0.0001 </math>η=0.0001)
  • 参数更新幅度太小
  • 收敛极其缓慢
  • 可能卡在局部最优
markdown 复制代码
Loss
  ^
  |
  |   ___
  |  /
  | /
  |/_______________
  +-----------------> Iteration
     很长时间才下降一点
学习率合适(如 <math xmlns="http://www.w3.org/1998/Math/MathML"> η = 0.001 \eta=0.001 </math>η=0.001)
  • 稳定下降
  • 收敛速度合理
lua 复制代码
Loss
  ^
  |
  |\
  | \
  |  \___
  |      ----____
  +-----------------> Iteration

典型的学习率值

模型规模 典型学习率
小模型(<100M) 1e-3 ~ 1e-4
中型模型(100M-1B) 5e-4 ~ 1e-4
大模型(1B-100B+) 1e-4 ~ 1e-5

为什么大模型用更小的学习率?

  • 参数量大,梯度累积效应强
  • 需要更细致的调整
  • 避免训练不稳定

学习率调度(Learning Rate Scheduling)

问题:固定学习率不是最优的。

解决方案:在训练过程中动态调整学习率。

1. Warmup(预热)

训练初期使用很小的学习率,然后逐渐增大:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> η ( t ) = η max ⋅ min ⁡ ( 1 , t T warmup ) \eta(t) = \eta_{\text{max}} \cdot \min\left(1, \frac{t}{T_{\text{warmup}}}\right) </math>η(t)=ηmax⋅min(1,Twarmupt)

原因

  • 训练初期,参数是随机的,梯度可能很大
  • 小学习率避免参数被"破坏"
  • 通常warmup 1000-10000步

可视化

lua 复制代码
Learning Rate
  ^
  |        /----------
  |       /
  |      /
  |     /
  |    /
  +----+--------------> Step
     warmup
2. Cosine Annealing(余弦退火)

学习率按余弦曲线衰减:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> η ( t ) = η min ⁡ + 1 2 ( η max ⁡ − η min ⁡ ) ( 1 + cos ⁡ ( t π T ) ) \eta(t) = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{t \pi}{T}\right)\right) </math>η(t)=ηmin+21(ηmax−ηmin)(1+cos(Ttπ))

可视化

lua 复制代码
Learning Rate
  ^
  |     ___
  |    /   \___
  |   /        \___
  |  /             \___
  | /                  ---
  +----------------------->\__Step
  warmup   cosine decay
3. 组合:Warmup + Cosine Decay

GPT-3等大模型常用的策略:

python 复制代码
def get_lr(step, warmup_steps, max_steps, lr_max, lr_min):
    if step < warmup_steps:
        # Warmup阶段:线性增长
        return lr_max * step / warmup_steps
    else:
        # Cosine衰减阶段
        progress = (step - warmup_steps) / (max_steps - warmup_steps)
        return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))

AdamW优化器

虽然梯度下降很简单,但实践中有很多问题。AdamW是目前训练大模型的标准优化器。

SGD的问题

标准的梯度下降(SGD):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ t + 1 = θ t − η ⋅ g t \theta_{t+1} = \theta_t - \eta \cdot g_t </math>θt+1=θt−η⋅gt

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> g t = ∂ L ∂ θ g_t = \frac{\partial L}{\partial \theta} </math>gt=∂θ∂L 是当前梯度。

问题1:不同参数的梯度尺度差异大

  • 某些参数的梯度很大(如1000)
  • 某些参数的梯度很小(如0.001)
  • 用相同的学习率无法兼顾

问题2:梯度噪声

  • Mini-batch的梯度有噪声
  • 可能在最优点附近震荡

问题3:高维空间的"峡谷"

  • 某些方向梯度大,某些方向梯度小
  • SGD在峡谷中会"之字形"前进,效率低

Adam优化器

Adam(Adaptive Moment Estimation) 通过维护梯度的一阶矩 (均值)和二阶矩(方差)来自适应调整每个参数的学习率。

Adam的核心思想
  1. 动量(Momentum):累积历史梯度的指数移动平均
  2. 自适应学习率:根据梯度的历史方差调整学习率
Adam算法(简化版)

初始化
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> m 0 = 0 v 0 = 0 \begin{aligned} m_0 &= 0 \\ v_0 &= 0 \end{aligned} </math>m0v0=0=0

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> m 0 m_0 </math>m0:一阶矩初始值(梯度的均值)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> v 0 v_0 </math>v0:二阶矩初始值(梯度的方差)

每次迭代
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> g t = ∂ L ∂ θ t m t = β 1 ⋅ m t − 1 + ( 1 − β 1 ) ⋅ g t v t = β 2 ⋅ v t − 1 + ( 1 − β 2 ) ⋅ g t 2 m ^ t = m t 1 − β 1 t v ^ t = v t 1 − β 2 t θ t + 1 = θ t − η ⋅ m ^ t v ^ t + ϵ \begin{aligned} g_t &= \frac{\partial L}{\partial \theta_t} \\ \\ m_t &= \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t \\ \\ v_t &= \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\ \\ \hat{m}_t &= \frac{m_t}{1 - \beta_1^t} \\ \\ \hat{v}t &= \frac{v_t}{1 - \beta_2^t} \\ \\ \theta{t+1} &= \theta_t - \eta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \end{aligned} </math>gtmtvtm^tv^tθt+1=∂θt∂L=β1⋅mt−1+(1−β1)⋅gt=β2⋅vt−1+(1−β2)⋅gt2=1−β1tmt=1−β2tvt=θt−η⋅v^t +ϵm^t

各步骤说明:

  1. <math xmlns="http://www.w3.org/1998/Math/MathML"> g t g_t </math>gt:计算当前梯度
  2. <math xmlns="http://www.w3.org/1998/Math/MathML"> m t m_t </math>mt:更新一阶矩(梯度的指数移动平均)
  3. <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt:更新二阶矩(梯度平方的指数移动平均)
  4. <math xmlns="http://www.w3.org/1998/Math/MathML"> m ^ t \hat{m}_t </math>m^t:偏差修正后的一阶矩
  5. <math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ t \hat{v}_t </math>v^t:偏差修正后的二阶矩
  6. <math xmlns="http://www.w3.org/1998/Math/MathML"> θ t + 1 \theta_{t+1} </math>θt+1:根据修正后的矩更新参数

参数解释

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> β 1 = 0.9 \beta_1 = 0.9 </math>β1=0.9:一阶矩的衰减率(通常)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> β 2 = 0.999 \beta_2 = 0.999 </math>β2=0.999:二阶矩的衰减率(通常)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ = 1 0 − 8 \epsilon = 10^{-8} </math>ϵ=10−8:防止除零的小常数
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> m ^ t \hat{m}_t </math>m^t:偏差修正后的一阶矩
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ t \hat{v}_t </math>v^t:偏差修正后的二阶矩
直观理解

一阶矩 <math xmlns="http://www.w3.org/1998/Math/MathML"> m t m_t </math>mt(动量)

  • 累积历史梯度的加权平均
  • 让参数更新更平滑,减少震荡
  • 类似物理中的"惯性"

二阶矩 <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt(方差)

  • 累积历史梯度平方的加权平均
  • 反映梯度的波动程度
  • 梯度波动大的参数,学习率自动减小

自适应学习率
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> effective_lr = η v ^ t + ϵ \text{effective\_lr} = \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} </math>effective_lr=v^t +ϵη

  • 梯度波动大( <math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ t \hat{v}_t </math>v^t大):学习率小
  • 梯度波动小( <math xmlns="http://www.w3.org/1998/Math/MathML"> v ^ t \hat{v}_t </math>v^t小):学习率大

AdamW:Adam + Weight Decay

AdamW 是Adam的改进版本,专门针对深度学习中的权重衰减(Weight Decay)

什么是权重衰减?

问题:如果不加限制,神经网络的参数可能会变得非常大。

python 复制代码
# 举个例子
# 假设模型学到了这样的参数:
W = [[100, -200],
     [300, -400]]

# 过大的参数会导致:
1. 过拟合:模型对训练数据"记住"而不是"理解"
2. 数值不稳定:计算时容易溢出
3. 泛化能力差:在新数据上表现不好

权重衰减是一种正则化技术,通过惩罚大参数来防止上述问题:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L total = L + λ 2 ∥ θ ∥ 2 L_{\text{total}} = L + \frac{\lambda}{2} \|\theta\|^2 </math>Ltotal=L+2λ∥θ∥2

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L:原始损失函数
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> λ \lambda </math>λ:权重衰减系数(如0.01、0.1)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ∥ θ ∥ 2 = θ 1 2 + θ 2 2 + ... + θ n 2 \|\theta\|^2 = \theta_1^2 + \theta_2^2 + \ldots + \theta_n^2 </math>∥θ∥2=θ12+θ22+...+θn2:参数的L2范数

直观理解

ini 复制代码
不用权重衰减:
Loss = 预测错误程度

使用权重衰减:
Loss = 预测错误程度 + λ × (参数大小的惩罚)

模型需要在"预测准确"和"参数不要太大"之间取得平衡
Adam中的权重衰减问题

在标准Adam中,权重衰减是通过把 <math xmlns="http://www.w3.org/1998/Math/MathML"> λ θ \lambda \theta </math>λθ 加到梯度上实现的:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> g t = ∂ L ∂ θ + λ θ g_t = \frac{\partial L}{\partial \theta} + \lambda \theta </math>gt=∂θ∂L+λθ

然后用这个"带权重衰减的梯度"来更新动量和方差:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> m t = β 1 ⋅ m t − 1 + ( 1 − β 1 ) ⋅ g t v t = β 2 ⋅ v t − 1 + ( 1 − β 2 ) ⋅ g t 2 \begin{aligned} m_t &= \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t \\ v_t &= \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 \end{aligned} </math>mtvt=β1⋅mt−1+(1−β1)⋅gt=β2⋅vt−1+(1−β2)⋅gt2

注意:这里的 <math xmlns="http://www.w3.org/1998/Math/MathML"> g t g_t </math>gt 包含了权重衰减项 <math xmlns="http://www.w3.org/1998/Math/MathML"> λ θ \lambda \theta </math>λθ,因此 <math xmlns="http://www.w3.org/1998/Math/MathML"> m t m_t </math>mt 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt 都会受到权重衰减的影响。

问题在哪里?

权重衰减项 <math xmlns="http://www.w3.org/1998/Math/MathML"> λ θ \lambda \theta </math>λθ 被混入了自适应学习率的计算中:

python 复制代码
# 假设某个参数 θ = 10(比较大)
# 真实梯度 g = 0.1(比较小)
# λ = 0.01

# Adam中:
g_with_decay = 0.1 + 0.01 × 10 = 0.2

# 这个0.2会被用来计算 v_t(梯度方差)
v_t = 0.999 × v_{t-1} + 0.001 × 0.2² = 0.999 × v_{t-1} + 0.00004

# 因为参数大 → 权重衰减项大 → g_with_decay大 → v_t变大
# 而实际学习率 = lr / √v_t
# v_t变大 → 实际学习率变小 → 权重衰减效果被削弱!

# 结果:权重衰减的效果被自适应学习率"对冲"了

本质问题:权重衰减应该独立于梯度大小,但在Adam中它却受到自适应学习率的影响。

AdamW的解决方案

AdamW将权重衰减从梯度中解耦,直接在参数更新时应用

标准Adam的做法(错误)

arduino 复制代码
步骤1:g_t = ∂L/∂θ + λθ          (梯度混入权重衰减)
步骤2:m_t = β₁×m_{t-1} + (1-β₁)×g_t
步骤3:v_t = β₂×v_{t-1} + (1-β₂)×g_t²
步骤4:θ_{t+1} = θ_t - lr × m̂_t/(√v̂_t + ε)
       ↑
    权重衰减效果被自适应学习率削弱

AdamW的做法(正确)

arduino 复制代码
步骤1:g_t = ∂L/∂θ
      (纯梯度,不含权重衰减)

步骤2:m_t = β₁×m_{t-1} + (1-β₁)×g_t

步骤3:v_t = β₂×v_{t-1} + (1-β₂)×g_t²

步骤4:θ_{t+1} = θ_t - lr × [m̂_t/(√v̂_t + ε) + λ×θ_t]
                              ↑                ↑
                          Adam更新        权重衰减(独立)

数学形式
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ t + 1 = θ t − η ⋅ ( m ^ t v ^ t + ϵ + λ θ t ) \theta_{t+1} = \theta_t - \eta \cdot \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right) </math>θt+1=θt−η⋅(v^t +ϵm^t+λθt)

或者分两步写更清楚:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ t + 1 ′ = θ t − η ⋅ m ^ t v ^ t + ϵ θ t + 1 = θ t + 1 ′ − η ⋅ λ ⋅ θ t \begin{aligned} \theta_{t+1}' &= \theta_t - \eta \cdot \frac{\hat{m}t}{\sqrt{\hat{v}t} + \epsilon} \\ \theta{t+1} &= \theta{t+1}' - \eta \cdot \lambda \cdot \theta_t \end{aligned} </math>θt+1′θt+1=θt−η⋅v^t +ϵm^t=θt+1′−η⋅λ⋅θt

其中:

  • 第一步是标准的Adam更新
  • 第二步是权重衰减,独立应用

具体例子对比

python 复制代码
# 继续上面的例子
# θ = 10, 真实梯度 g = 0.1, λ = 0.01

# AdamW中:
g_t = 0.1  # 只用真实梯度
v_t = 0.999 × v_{t-1} + 0.001 × 0.1²  # 不受权重衰减影响

# Adam更新部分:
Δθ_adam = lr × 0.1 / √v_t

# 权重衰减部分(独立应用):
Δθ_decay = lr × λ × θ = lr × 0.01 × 10 = lr × 0.1

# 总更新:
θ_{t+1} = θ - Δθ_adam - Δθ_decay

# 关键:权重衰减不受自适应学习率影响!
# 参数越大,衰减越强,完全符合正则化的本意

可视化对比

arduino 复制代码
Adam (标准权重衰减):
┌─────────────────────────────────────┐
│  梯度 g + 权重衰减 λθ                │
└──────────┬──────────────────────────┘
           ↓
    ┌──────────────┐
    │  计算 m_t    │  ← 权重衰减混入动量
    │  计算 v_t    │  ← 权重衰减影响方差
    └──────┬───────┘
           ↓
    自适应学习率 lr/√v_t
           ↓
    参数更新 ← 权重衰减效果被削弱 ✗


AdamW (解耦权重衰减):
┌─────────────────────────────────────┐
│  纯梯度 g (不含权重衰减)             │
└──────────┬──────────────────────────┘
           ↓
    ┌──────────────┐
    │  计算 m_t    │  ← 只用真实梯度
    │  计算 v_t    │  ← 只用真实梯度
    └──────┬───────┘
           ↓
    ┌──────────────┐      ┌──────────────┐
    │  Adam更新    │  +   │ 权重衰减独立  │
    │  lr×m̂/√v̂    │      │   lr×λ×θ     │
    └──────┬───────┘      └──────┬───────┘
           └──────────┬───────────┘
                      ↓
              参数更新 ← 权重衰减效果完整 ✓

效果对比

对比项 Adam AdamW
权重衰减位置 混入梯度中 独立于梯度
权重衰减是否受自适应学习率影响 是(被削弱) 否(独立)
大参数的衰减效果 较弱 强(符合预期)
大模型训练效果 较差 更好
泛化能力 一般 更好
是否被广泛采用 较少 GPT-3、LLaMA标准

总结

  • Adam : <math xmlns="http://www.w3.org/1998/Math/MathML"> g t = ∂ L ∂ θ + λ θ g_t = \frac{\partial L}{\partial \theta} + \lambda \theta </math>gt=∂θ∂L+λθ,然后用 <math xmlns="http://www.w3.org/1998/Math/MathML"> g t g_t </math>gt 更新 <math xmlns="http://www.w3.org/1998/Math/MathML"> m t m_t </math>mt 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt → 权重衰减效果被自适应学习率削弱
  • AdamW : <math xmlns="http://www.w3.org/1998/Math/MathML"> g t = ∂ L ∂ θ g_t = \frac{\partial L}{\partial \theta} </math>gt=∂θ∂L,更新 <math xmlns="http://www.w3.org/1998/Math/MathML"> m t m_t </math>mt 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt 后,再独立应用权重衰减 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ← θ − η λ θ \theta \leftarrow \theta - \eta \lambda \theta </math>θ←θ−ηλθ → 权重衰减效果完整保留

这就是为什么现代大模型训练都使用AdamW而不是Adam!

AdamW的超参数

参数 典型值 说明
<math xmlns="http://www.w3.org/1998/Math/MathML"> η \eta </math>η(学习率) 1e-4 ~ 1e-5 最重要的超参数
<math xmlns="http://www.w3.org/1998/Math/MathML"> β 1 \beta_1 </math>β1 0.9 一阶矩衰减率
<math xmlns="http://www.w3.org/1998/Math/MathML"> β 2 \beta_2 </math>β2 0.999 或 0.95 二阶矩衰减率
<math xmlns="http://www.w3.org/1998/Math/MathML"> λ \lambda </math>λ(weight decay) 0.01 ~ 0.1 权重衰减系数
<math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ 1e-8 数值稳定性常数

GPT-3的设置

  • Learning rate: 6e-5(with warmup and cosine decay)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> β 1 \beta_1 </math>β1 = 0.9
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> β 2 \beta_2 </math>β2 = 0.95
  • Weight decay: 0.1

PyTorch实现

python 复制代码
import torch
import torch.optim as optim

# 定义模型
model = GPT2Model(...)

# 创建AdamW优化器
optimizer = optim.AdamW(
    model.parameters(),
    lr=6e-5,              # 学习率
    betas=(0.9, 0.95),    # (β₁, β₂)
    weight_decay=0.1,     # 权重衰减
    eps=1e-8              # ε
)

# 训练循环
for batch in dataloader:
    # 前向传播
    outputs = model(batch['input'])
    loss = loss_fn(outputs, batch['target'])

    # 反向传播
    loss.backward()

    # 梯度裁剪(防止梯度爆炸)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # 参数更新(AdamW)
    optimizer.step()

    # 梯度清零
    optimizer.zero_grad()

完整的训练流程

让我们把所有概念串起来,看一个完整的训练流程:

伪代码

python 复制代码
# 1. 初始化
model = Transformer(...)  # 随机初始化参数
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)
lr_scheduler = CosineAnnealingWarmup(optimizer, warmup_steps=2000)

# 2. 训练循环
for epoch in range(num_epochs):
    for batch in dataloader:
        # (1) 前向传播
        input_ids = batch['input']       # [batch_size, seq_len]
        target_ids = batch['target']     # [batch_size, seq_len]

        logits = model(input_ids)        # [batch_size, seq_len, vocab_size]

        # (2) 计算Loss
        loss = cross_entropy(logits, target_ids)

        # (3) 反向传播
        loss.backward()  # 自动计算所有参数的梯度

        # (4) 梯度裁剪(可选,防止梯度爆炸)
        clip_grad_norm_(model.parameters(), max_norm=1.0)

        # (5) 参数更新(AdamW)
        optimizer.step()

        # (6) 学习率调整
        lr_scheduler.step()

        # (7) 梯度清零
        optimizer.zero_grad()

        # (8) 打印进度
        if step % 100 == 0:
            print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}, LR: {lr_scheduler.get_last_lr()[0]:.6f}")

完整流程图

scss 复制代码
初始化模型参数(随机)
    ↓
┌─────────────────────── Training Loop ───────────────────────┐
│                                                               │
│  加载一个batch的数据                                          │
│    ↓                                                         │
│  前向传播:input → Transformer → logits                       │
│    ↓                                                         │
│  计算Loss:Cross-Entropy(logits, target)                     │
│    ↓                                                         │
│  反向传播:计算梯度 ∂L/∂θ                                     │
│    ↓                                                         │
│  梯度裁剪:防止梯度爆炸                                        │
│    ↓                                                         │
│  参数更新:θ ← θ - AdamW(∂L/∂θ)                              │
│    ↓                                                         │
│  学习率调整:Warmup + Cosine Decay                            │
│    ↓                                                         │
│  梯度清零:准备下一个batch                                     │
│    ↓                                                         │
│  [循环]                                                       │
│                                                               │
└───────────────────────────────────────────────────────────────┘
    ↓
训练完成,得到优化后的参数

小结

  1. 前向传播

    • 数据从输入到输出的计算过程
    • 得到模型的预测结果
  2. 损失函数

    • 衡量模型输出与真实答案的差距
    • 语言模型常用交叉熵损失: <math xmlns="http://www.w3.org/1998/Math/MathML"> L = − log ⁡ P ( 正确答案 ) L = -\log P(\text{正确答案}) </math>L=−logP(正确答案)
  3. 反向传播

    • 利用链式法则计算梯度 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ θ \frac{\partial L}{\partial \theta} </math>∂θ∂L
    • 现代框架自动完成(AutoGrad)
  4. 梯度下降

    • 沿着负梯度方向更新参数: <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ← θ − η ∇ L \theta \leftarrow \theta - \eta \nabla L </math>θ←θ−η∇L
    • 实践中使用Mini-Batch GD
  5. 学习率

    • 控制参数更新的步长
    • 需要careful tuning
    • 通常使用Warmup + Cosine Decay
  6. AdamW优化器

    • 结合动量和自适应学习率
    • 独立的权重衰减
    • 是训练大模型的标准选择

训练的本质:通过不断的"前向计算-计算误差-反向传播-更新参数"循环,让模型的输出越来越接近正确答案,最终学会预测下一个词。

这就是大模型"学习"的秘密!

相关推荐
得鹿1 小时前
MySQL基础架构与存储引擎、索引、事务、锁、日志
后端
程序员飞哥2 小时前
Block科技公司裁员四千人,竟然是因为 AI ?
人工智能·后端·程序员
王小酱2 小时前
Everything Claude Code 速查指南
openai·ai编程·aiops
JavaEdge在掘金2 小时前
Claude Code 直连 Ollama / LM Studio:本地、云端开源模型都能跑
后端
LSTM972 小时前
使用 Python 将 TXT 转换为 PDF (自动分页)
后端
于眠牧北2 小时前
Java开发学习提高效率的辅助软件和插件:一键生成接口文档,AI制作原型等
后端
JordanHaidee2 小时前
Python 中 `if x:` 到底在判断什么?
后端·python
开心就好20252 小时前
不越狱能抓到 HTTPS 吗?在未越狱 iPhone 上抓取 HTTPS
后端·ios
用户908324602732 小时前
Spring Boot + MyBatis-Plus 多租户实战:从数据隔离到权限控制的完整方案
java·后端