【NanoGPT 学习 02】model.py MLP 层和 Block 详解

前言

接上一篇文件末尾:【NanoGPT 学习 01】model.py 代码详解. 继续学习 model.py 的代码设计

Block 组合块

Block 将 LayerNorm、Attention、MLP 层组合成一个"模块",作为 Transformer 的处理块。

py 复制代码
class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

Block​ 类的组件及其作用

  1. 层归一化(LayerNorm)

    • 层归一化是在构造中每层的开始处应用的,用于规范化输入数据,确保数据在通过激活函数前具有稳定的分布。这有助于加快训练速度,减少训练过程中的梯度消失或爆炸问题。
    • Block类中,ln_1ln_2分别应用于自注意力模块和前馈神经网络模块的输入,为后续的处理步骤提供稳定的输入。
  2. 自注意力(CausalSelfAttention)

    • 自注意力模块允许模型在处理序列的每个元素时,考虑到序列中的所有元素,这是通过计算序列中每个元素对其它所有元素的加权重要性实现的。它是Transformer的核心,使模型能够捕获长距离依赖。
    • CausalSelfAttention是一种自注意力的变体,它通过掩码操作保证了在生成当前元素的表示时,只考虑到之前的元素(及自身),适用于如文本生成等任务,其中当前词的生成应仅依赖于之前的词。
  3. 前馈神经网络(MLP)

    • 在自注意力模块后,数据会通过一个前馈神经网络,该网络通常包括两个线性层和一个非线性激活函数,这里是GELU函数。这个网络可以在注意力机制处理过的数据上进一步学习复杂的表示。
    • MLP对每个位置的表示进行独立处理,增加了模型处理不同语义信息的能力。

forward 中的"残差连接"

py 复制代码
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

作用

1. 保持信息流畅

实际上,残差连接可以让输入x​直接"跳过"某些层,从一个层直接连接到后面的层。这意味着,在深层网络中,信息可以更直接地传递,减少了信息在传递过程中的损失,帮助模型在深层次上保持有效的梯度流动,这对于训练深层网络十分关键。

2. 缓解梯度消失/爆炸

深度学习网络在训练过程中常遇到的一个问题是梯度消失或爆炸。简单来说,当网络层数增加到一定程度时,反向传播过程中计算得到的梯度可能会变得极小或极大,从而导致网络难以训练。残差连接通过直接将输入添加到输出,使得梯度可以直接通过这些连接反向传播,即便是在非常深的网络中也能更稳定地训练。

3. 提高网络的学习能力

残差连接实际上是让网络学习输入与输出之间的残差(即变化)。如果一个特征不需要任何改变,理论上网络可以学习到一个非常小的变化量,实际上接近于直接将这个特征通过网络传递下去。这使得网络能够更加专注于学习输入与输出之间的差异,而不是完全从头开始学习输出,这样可以提高学习的效率和效果。

4. 支持更深层的网络构建

通过以上优点,残差连接使得构建更深层的网络成为可能。这在传统的深度学习模型中是个挑战,因为越深的网络越容易遇到训练困难和性能衰退的问题。而Transformer模型利用残差连接,成功训练了几十层乃至上百层的网络,实现了 state-of-art 的性能表现。

残差连接为什么使用"相加运算"

直观上的理解

  1. 简化学习目标:残差连接让网络专注于学习输入和期望输出之间的差异,而不是直接学习输出。相加操作使得如果需要的改变很小,网络可以学到一个接近零的差异,效果上等同于直接输出输入。这种方式自然而直观地实现了残差学习。
  2. 保持信息流:相加操作允许原始输入信号直接传递到后续层,没有任何变换,这有利于避免信息流在深层网络中的衰减,同时也有助于梯度在反向传播时直接流回,减少梯度消失的问题。

技术和理论上的考量

  1. 维持数据的尺度相加操作不改变数据的尺度(大小和范围),这对于保持网络各层的数据稳定性有重要作用。如果使用其他运算(如乘法或平均),可能会不断改变数据的尺度和分布,增加训练过程中的复杂度。
  2. 梯度传播的效率:在反向传播时,加法操作有一个很好的性质------梯度可以无改变地流过加法节点。这保证了即使在深层网络中,梯度也能有效地传播回每个层,有助于缓解梯度消失问题。

替代方案及其挑战

尽管"相加"运算是实现残巀连接最常见的方式,但也存在其他的替代方案,如元素级乘法(Hadamard乘法)或等,每种方法都有其优缺点:

  • 元素级乘法:这种方式可能导致输出尺度的变化,同时对输入信号的直接传播不友好。乘法操作会改变梯度分布,可能使训练变得更困难。
  • 连接:虽然通过拼接输入和输出保留了全部信息,但这会导致网络层的输入输出维度不断增加,给网络设计和计算带来挑战。

未完待续

在最新的文章 【NanoGPT 学习 03】GPT 代码架构分析 中,将前两篇文章的组件全部整合在一起进行分析。

相关推荐
zm-v-159304339862 小时前
ArcGIS+GPT:多领域地理分析与决策新方案
gpt·arcgis
winner888110 小时前
从 BERT 到 GPT:Encoder 的 “全局视野” 如何喂饱 Decoder 的 “逐词纠结”
人工智能·gpt·bert·encoder·decoder
q_q王1 天前
本地知识库工具FASTGPT的安装与搭建
python·大模型·llm·知识库·fastgpt
AI布道师Warren2 天前
AI 智能体蓝图:拆解认知、进化与协作核心
llm
JoernLee2 天前
Qwen3术语解密:读懂大模型黑话
人工智能·开源·llm
火云牌神2 天前
本地大模型编程实战(28)查询图数据库NEO4J(1)
python·llm·neo4j·langgraph
Goboy2 天前
用Trae,找初恋,代码写人生,Trae圆你初恋梦。
llm·trae
CoderJia程序员甲2 天前
MarkItDown:如何高效将各类文档转换为适合 LLM 处理的 Markdown 格式
ai·llm·markdown·文档转换
金木讲编程2 天前
用Function Calling让GPT查询数据库(含示例)
gpt·ai编程