【NanoGPT 学习 02】model.py MLP 层和 Block 详解

前言

接上一篇文件末尾:【NanoGPT 学习 01】model.py 代码详解. 继续学习 model.py 的代码设计

Block 组合块

Block 将 LayerNorm、Attention、MLP 层组合成一个"模块",作为 Transformer 的处理块。

py 复制代码
class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

Block​ 类的组件及其作用

  1. 层归一化(LayerNorm)

    • 层归一化是在构造中每层的开始处应用的,用于规范化输入数据,确保数据在通过激活函数前具有稳定的分布。这有助于加快训练速度,减少训练过程中的梯度消失或爆炸问题。
    • Block类中,ln_1ln_2分别应用于自注意力模块和前馈神经网络模块的输入,为后续的处理步骤提供稳定的输入。
  2. 自注意力(CausalSelfAttention)

    • 自注意力模块允许模型在处理序列的每个元素时,考虑到序列中的所有元素,这是通过计算序列中每个元素对其它所有元素的加权重要性实现的。它是Transformer的核心,使模型能够捕获长距离依赖。
    • CausalSelfAttention是一种自注意力的变体,它通过掩码操作保证了在生成当前元素的表示时,只考虑到之前的元素(及自身),适用于如文本生成等任务,其中当前词的生成应仅依赖于之前的词。
  3. 前馈神经网络(MLP)

    • 在自注意力模块后,数据会通过一个前馈神经网络,该网络通常包括两个线性层和一个非线性激活函数,这里是GELU函数。这个网络可以在注意力机制处理过的数据上进一步学习复杂的表示。
    • MLP对每个位置的表示进行独立处理,增加了模型处理不同语义信息的能力。

forward 中的"残差连接"

py 复制代码
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

作用

1. 保持信息流畅

实际上,残差连接可以让输入x​直接"跳过"某些层,从一个层直接连接到后面的层。这意味着,在深层网络中,信息可以更直接地传递,减少了信息在传递过程中的损失,帮助模型在深层次上保持有效的梯度流动,这对于训练深层网络十分关键。

2. 缓解梯度消失/爆炸

深度学习网络在训练过程中常遇到的一个问题是梯度消失或爆炸。简单来说,当网络层数增加到一定程度时,反向传播过程中计算得到的梯度可能会变得极小或极大,从而导致网络难以训练。残差连接通过直接将输入添加到输出,使得梯度可以直接通过这些连接反向传播,即便是在非常深的网络中也能更稳定地训练。

3. 提高网络的学习能力

残差连接实际上是让网络学习输入与输出之间的残差(即变化)。如果一个特征不需要任何改变,理论上网络可以学习到一个非常小的变化量,实际上接近于直接将这个特征通过网络传递下去。这使得网络能够更加专注于学习输入与输出之间的差异,而不是完全从头开始学习输出,这样可以提高学习的效率和效果。

4. 支持更深层的网络构建

通过以上优点,残差连接使得构建更深层的网络成为可能。这在传统的深度学习模型中是个挑战,因为越深的网络越容易遇到训练困难和性能衰退的问题。而Transformer模型利用残差连接,成功训练了几十层乃至上百层的网络,实现了 state-of-art 的性能表现。

残差连接为什么使用"相加运算"

直观上的理解

  1. 简化学习目标:残差连接让网络专注于学习输入和期望输出之间的差异,而不是直接学习输出。相加操作使得如果需要的改变很小,网络可以学到一个接近零的差异,效果上等同于直接输出输入。这种方式自然而直观地实现了残差学习。
  2. 保持信息流:相加操作允许原始输入信号直接传递到后续层,没有任何变换,这有利于避免信息流在深层网络中的衰减,同时也有助于梯度在反向传播时直接流回,减少梯度消失的问题。

技术和理论上的考量

  1. 维持数据的尺度相加操作不改变数据的尺度(大小和范围),这对于保持网络各层的数据稳定性有重要作用。如果使用其他运算(如乘法或平均),可能会不断改变数据的尺度和分布,增加训练过程中的复杂度。
  2. 梯度传播的效率:在反向传播时,加法操作有一个很好的性质------梯度可以无改变地流过加法节点。这保证了即使在深层网络中,梯度也能有效地传播回每个层,有助于缓解梯度消失问题。

替代方案及其挑战

尽管"相加"运算是实现残巀连接最常见的方式,但也存在其他的替代方案,如元素级乘法(Hadamard乘法)或等,每种方法都有其优缺点:

  • 元素级乘法:这种方式可能导致输出尺度的变化,同时对输入信号的直接传播不友好。乘法操作会改变梯度分布,可能使训练变得更困难。
  • 连接:虽然通过拼接输入和输出保留了全部信息,但这会导致网络层的输入输出维度不断增加,给网络设计和计算带来挑战。

未完待续

在最新的文章 【NanoGPT 学习 03】GPT 代码架构分析 中,将前两篇文章的组件全部整合在一起进行分析。

相关推荐
数据智能老司机39 分钟前
Kubernetes 上的生成式 AI——模型数据
kubernetes·llm·agent
iceiceiceice39 分钟前
从零开始构建 RAG + DeepSeek Demo
人工智能·llm
302AI1 小时前
大白话聊一聊:为什么OpenClaw那么火
llm·agent·vibecoding
数据智能老司机2 小时前
AI 智能体与应用——使用 LangGraph 构建基于工具的智能体
llm·agent
数据智能老司机2 小时前
AI 智能体与应用——问题转换
llm·agent
数据智能老司机2 小时前
AI 智能体与应用——使用 LangGraph 构建智能体工作流
llm·agent
数据智能老司机2 小时前
AI 智能体与应用——构建研究摘要引擎
llm·agent
数据智能老司机3 小时前
AI 智能体与应用——使用 LangChain 和 LangSmith 构建 Q&A 聊天机器人
llm·agent
Pitayafruit3 小时前
OpenClaw 从装完到真正会用,成为专业养🦞户的攻略
llm·aigc
数据智能老司机4 小时前
AI 智能体与应用——使用 LangChain 进行文本摘要
llm·agent