残差连接与层归一化通俗易懂的详解

残差连接(Residual Connection)和层归一化(Layer Normalization)是深度神经网络(尤其是Transformer架构)中用于提升模型训练稳定性、加速收敛并防止梯度消失/爆炸的两项关键技术。它们通常协同工作,共同构成了Transformer中每个子层(如自注意力层、前馈神经网络层)的核心结构单元。

为了清晰地展示其核心作用与实现,下表对比了这两种技术:

特性 残差连接 (Residual Connection) 层归一化 (Layer Normalization)
核心目的 解决深度网络中的梯度消失/爆炸问题,使网络能够轻松学习恒等映射,从而允许构建极深的网络。 稳定每一层神经网络输入的分布,减少内部协变量偏移,加速模型训练收敛。
核心思想 "跳跃连接"。不直接让网络层拟合目标函数 H(x),而是改为拟合残差函数 F(x) = H(x) - x。最终的输出为 x + F(x) 对单一样本在某一层所有神经元的激活值进行标准化(均值为0,方差为1),然后进行缩放和平移。
在Transformer中的位置 通常应用于每个子层(如自注意力、前馈网络)周围。输入x直接"跳跃"到该子层的输出之后,进行相加。 通常应用于残差相加之后 。即 LayerNorm(x + Sublayer(x))。这种顺序称为"后归一化"(Post-LN)。
通俗比喻 如同在一条主路旁修建一条辅路(新网络层)。如果辅路修得好,车流可以更快;如果修得不好,车流仍可顺畅地走主路,不至于堵死。这降低了修建新路的风险。 如同老师批改试卷时,不是看绝对分数,而是根据全班同学的平均分和分数分布来调整("标准化")每位同学的得分,使得每次考试的难度差异不会对评价单个学生造成过大影响。

残差连接(Residual Connection)详解

1. 用途与解决的问题:

在非常深的神经网络中,信号(梯度)在前向传播和反向传播时,需要经过许多层。传统的堆叠层是让每一层直接学习一个复杂的映射 H(x)。当网络很深时,这个映射可能变得极其复杂且难以优化,容易导致梯度在反向传播时变得极小(消失)或极大(爆炸),使得底层网络的参数几乎无法更新。

残差连接的提出,将一个棘手的学习目标 H(x) 转换成了一个相对更容易学习的目标。它允许网络层学习输入与输出之间的残差(即差异部分)。

2. 实现原理:

对于一个基础的网络层(或一组层)F,其输入为 x,传统的输出是 F(x)

引入残差连接后,该模块的输出变为:
输出 = x + F(x)

这里的 x 就是"跳跃连接"过来的原始输入。F(x) 是网络层需要学习的残差映射。

3. 为什么有效?

  • 恒等映射的便捷性 :如果当前层是多余的(即最优的 H(x) 就是 x),那么网络只需简单地将 F(x) 的参数学习为接近0,即可轻松实现 输出 ≈ x。这比让一个非线性层直接学习恒等函数要容易得多。
  • 梯度回传的捷径 :在反向传播时,梯度不仅通过 F(x) 的路径回传,还会通过跳跃连接直接回传给上一层。这条"捷径"确保了梯度即使经过很深的网络,也不容易消失,因为总有一条梯度为1的直连路径存在。

4. 在Transformer中的具体应用:

在Transformer的编码器和解码器层中,每个子层(自注意力子层、前馈神经网络子层)都被一个残差连接所包裹,然后紧接着进行层归一化。以编码器的自注意力子层为例:

python 复制代码
# 伪代码示意
def encoder_sublayer(x):
    # 1. 自注意力计算
    attention_output = SelfAttention(x)
    # 2. 残差连接:将原始输入 x 与自注意力输出相加
    residual_output = x + attention_output  # 残差连接在此发生
    # 3. 层归一化
    normalized_output = LayerNorm(residual_output)
    return normalized_output

层归一化(Layer Normalization)详解

1. 用途与解决的问题:

在训练深度网络时,每一层输入的分布会随着前一层参数的更新而不断变化,这种现象称为"内部协变量偏移"(Internal Covariate Shift)。这会导致训练过程需要不断适应新的数据分布,从而变得缓慢且不稳定。批归一化(Batch Normalization)通过对一个批次(Batch)内所有样本的同一特征进行归一化来解决这个问题,但在序列长度可变(如NLP任务)或批次较小时效果不佳。

层归一化提供了另一种思路:针对单个样本,对其在该层所有神经元(或隐藏单元)的激活值进行归一化。这使得它对批次大小不敏感,非常适合Transformer这类处理变长序列的模型。

2. 实现原理:

对于一个输入向量 h(代表某一层某个样本的所有神经元输出),其维度为 [d_model](例如512或768)。层归一化的计算步骤如下:

  • 计算均值和方差 :统计该样本在该层所有 d_model 个维度上的均值 μ 和方差 σ²
    • μ = (1/d_model) * Σ_{i=1}^{d_model} h_i
    • σ² = (1/d_model) * Σ_{i=1}^{d_model} (h_i - μ)²
  • 归一化 :使用计算出的均值和方差对 h 进行标准化,得到均值为0、方差为1的向量。
    • ĥ_i = (h_i - μ) / sqrt(σ² + ε),其中 ε 是一个极小的数(如1e-5),防止除以零。
  • 缩放与平移 :引入两个可学习的参数向量 γβ(维度均为 [d_model]),对归一化后的结果进行缩放和平移。这一步至关重要,它使模型有能力恢复归一化可能破坏掉的特征表示能力。
    • output_i = γ_i * ĥ_i + β_i

3. 为什么有效?

  • 稳定训练:通过将每层的输入稳定在相似的分布(零均值、单位方差附近),减少了内部协变量偏移,使得可以使用更大的学习率,并加速模型收敛。
  • 缓解梯度问题:归一化操作在一定程度上也有助于缓解梯度消失或爆炸问题。
  • 适用于序列模型:其计算独立于批次内其他样本,因此完美适配变长序列和不同批次大小的训练场景。

4. 在Transformer中的具体应用:

如前所述,在标准的Transformer架构中,层归一化被应用于残差相加之后。这种设计使得归一化操作的输入包含了原始输入和子层输出,有助于稳定整个子层的输出分布。以下是一个简化的PyTorch风格代码示例,展示了一个完整的Transformer子层(如前馈网络)如何结合残差连接和层归一化:

python 复制代码
import torch
import torch.nn as nn

class TransformerSublayerWithResidualAndLN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        # 示例子层:一个简单的前馈网络
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        # 层归一化层
        self.layer_norm = nn.LayerNorm(d_model)
        
    def forward(self, x):
        # 保存原始输入,用于残差连接
        residual = x
        # 前馈网络计算
        sublayer_output = self.ffn(x)
        # 残差连接:原始输入 + 子层输出
        residual_output = residual + sublayer_output  # 残差连接
        # 层归一化
        output = self.layer_norm(residual_output)     # 层归一化
        return output

# 使用示例
d_model = 512
d_ff = 2048
batch_size = 2
seq_len = 10

model = TransformerSublayerWithResidualAndLN(d_model, d_ff)
input_tensor = torch.randn(batch_size, seq_len, d_model)
output_tensor = model(input_tensor)
print(f"输入形状: {input_tensor.shape}")
print(f"输出形状: {output_tensor.shape}")  # 应与输入形状一致

总结:残差连接通过引入"跳跃捷径",确保了信息(尤其是梯度)在深度网络中的顺畅流动,使得训练超深网络成为可能;层归一化则通过标准化每个样本层内的激活值,稳定了训练过程并加速收敛。在Transformer中,这两者以前文所述的顺序紧密结合,构成了其强大且可稳定训练的深层架构的基石,为后续BERT、GPT等大型语言模型的成功奠定了基础。


参考来源

相关推荐
csdn_aspnet1 小时前
Python 算法快闪 LeetCode 编号 70 - 爬楼梯
python·算法·leetcode·职场和发展
fantasy_arch2 小时前
pytorch人脸匹配模型
人工智能·pytorch·python
熊猫_豆豆2 小时前
广义相对论水星近日点进动完整详细数学推导
python·天体·广义相对论
科技那些事儿2 小时前
实时洞察,视觉赋能:国内情绪识别API公司推荐及计算机视觉流派深度解析
人工智能·计算机视觉
web3.08889992 小时前
1688 图搜接口(item_search_img / 拍立淘) 接入方法
开发语言·python
德思特2 小时前
从 Dify 配置页理解 RAG 的重要参数
java·人工智能·llm·dify·rag
火山引擎开发者社区2 小时前
ArkClaw AI 盯盘管家 —— 从手动口令到自动推送,4 套预置定时任务模版一键启用
人工智能
sxgzzn2 小时前
新能源场站数智化转型:基于数字孪生与AI的智慧运维管理平台解析
大数据·运维·人工智能