前言
接上一篇文件末尾:【NanoGPT 学习 01】model.py 代码详解. 继续学习 model.py 的代码设计
Block 组合块
Block 将 LayerNorm、Attention、MLP 层组合成一个"模块",作为 Transformer 的处理块。
py
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
self.attn = CausalSelfAttention(config)
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
self.mlp = MLP(config)
def forward(self, x):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
Block
类的组件及其作用
-
层归一化(LayerNorm)
- 层归一化是在构造中每层的开始处应用的,用于规范化输入数据,确保数据在通过激活函数前具有稳定的分布。这有助于加快训练速度,减少训练过程中的梯度消失或爆炸问题。
Block
类中,ln_1
和ln_2
分别应用于自注意力模块和前馈神经网络模块的输入,为后续的处理步骤提供稳定的输入。
-
自注意力(CausalSelfAttention)
- 自注意力模块允许模型在处理序列的每个元素时,考虑到序列中的所有元素,这是通过计算序列中每个元素对其它所有元素的加权重要性实现的。它是Transformer的核心,使模型能够捕获长距离依赖。
CausalSelfAttention
是一种自注意力的变体,它通过掩码操作保证了在生成当前元素的表示时,只考虑到之前的元素(及自身),适用于如文本生成等任务,其中当前词的生成应仅依赖于之前的词。
-
前馈神经网络(MLP)
- 在自注意力模块后,数据会通过一个前馈神经网络,该网络通常包括两个线性层和一个非线性激活函数,这里是GELU函数。这个网络可以在注意力机制处理过的数据上进一步学习复杂的表示。
- MLP对每个位置的表示进行独立处理,增加了模型处理不同语义信息的能力。
forward 中的"残差连接"
py
def forward(self, x):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
作用
1. 保持信息流畅
实际上,残差连接可以让输入x
直接"跳过"某些层,从一个层直接连接到后面的层。这意味着,在深层网络中,信息可以更直接地传递,减少了信息在传递过程中的损失,帮助模型在深层次上保持有效的梯度流动,这对于训练深层网络十分关键。
2. 缓解梯度消失/爆炸
深度学习网络在训练过程中常遇到的一个问题是梯度消失或爆炸。简单来说,当网络层数增加到一定程度时,反向传播过程中计算得到的梯度可能会变得极小或极大,从而导致网络难以训练。残差连接通过直接将输入添加到输出,使得梯度可以直接通过这些连接反向传播,即便是在非常深的网络中也能更稳定地训练。
3. 提高网络的学习能力
残差连接实际上是让网络学习输入与输出之间的残差(即变化)。如果一个特征不需要任何改变,理论上网络可以学习到一个非常小的变化量,实际上接近于直接将这个特征通过网络传递下去。这使得网络能够更加专注于学习输入与输出之间的差异,而不是完全从头开始学习输出,这样可以提高学习的效率和效果。
4. 支持更深层的网络构建
通过以上优点,残差连接使得构建更深层的网络成为可能。这在传统的深度学习模型中是个挑战,因为越深的网络越容易遇到训练困难和性能衰退的问题。而Transformer模型利用残差连接,成功训练了几十层乃至上百层的网络,实现了 state-of-art 的性能表现。
残差连接为什么使用"相加运算"
直观上的理解
- 简化学习目标:残差连接让网络专注于学习输入和期望输出之间的差异,而不是直接学习输出。相加操作使得如果需要的改变很小,网络可以学到一个接近零的差异,效果上等同于直接输出输入。这种方式自然而直观地实现了残差学习。
- 保持信息流:相加操作允许原始输入信号直接传递到后续层,没有任何变换,这有利于避免信息流在深层网络中的衰减,同时也有助于梯度在反向传播时直接流回,减少梯度消失的问题。
技术和理论上的考量
- 维持数据的尺度 :相加操作不改变数据的尺度(大小和范围),这对于保持网络各层的数据稳定性有重要作用。如果使用其他运算(如乘法或平均),可能会不断改变数据的尺度和分布,增加训练过程中的复杂度。
- 梯度传播的效率:在反向传播时,加法操作有一个很好的性质------梯度可以无改变地流过加法节点。这保证了即使在深层网络中,梯度也能有效地传播回每个层,有助于缓解梯度消失问题。
替代方案及其挑战
尽管"相加"运算是实现残巀连接最常见的方式,但也存在其他的替代方案,如元素级乘法(Hadamard乘法)或等,每种方法都有其优缺点:
- 元素级乘法:这种方式可能导致输出尺度的变化,同时对输入信号的直接传播不友好。乘法操作会改变梯度分布,可能使训练变得更困难。
- 连接:虽然通过拼接输入和输出保留了全部信息,但这会导致网络层的输入输出维度不断增加,给网络设计和计算带来挑战。
未完待续
在最新的文章 【NanoGPT 学习 03】GPT 代码架构分析 中,将前两篇文章的组件全部整合在一起进行分析。