【PyTorch】 nn.TransformerEncoderLayer 详解

PyTorch TransformerEncoderLayer 详解

在 Transformer 架构中,编码器层(Encoder Layer)是处理序列的核心模块,它能捕捉序列内部依赖,并提取每个 token 的高阶特征。在 PyTorch 中,nn.TransformerEncoderLayer 实现了这个层。本文将从源码层面数学原理实际张量流动三个角度做解析。


一、整体架构回顾

在开始之前,推荐回顾前面关于transformer的介绍:
【机器学习】21. Transformer: 最通俗易懂讲解

Transformer 编码器层主要包含三个核心模块:

  1. 多头自注意力(Multi-Head Self-Attention)

    • 输入序列中每个 token 都会"看"其他 token,捕捉全局依赖。
    • 多头机制让模型可以在不同子空间同时捕捉不同类型的依赖。
  2. 前馈神经网络(Feedforward Network, FFN)

    • 对每个 token 的特征进行非线性变换。
    • 提升表达能力。
  3. 残差连接 + 层归一化 + Dropout

    • 保证梯度流稳定,训练不发散。
    • 防止过拟合。

简化的层结构图如下:

复制代码
       +-------------------+
       |   Input Tensor    |
       +-------------------+
                 |
                 v
          Multi-Head Attention
                 |
            Dropout + Residual
                 |
             LayerNorm
                 |
             FeedForward
                 |
            Dropout + Residual
                 |
             LayerNorm
                 |
             Output Tensor

二、初始化方法 __init__ 分析

源码如下:

python 复制代码
def __init__(
    self,
    d_model: int,
    nhead: int,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
    layer_norm_eps: float = 1e-5,
    batch_first: bool = False,
    norm_first: bool = False,
    bias: bool = True,
    device=None,
    dtype=None,
) -> None:

1. 核心参数解析

参数 含义 例子
d_model 输入特征维度,每个 token embedding 的大小 512
nhead 多头注意力头数,每个 head 独立学习注意力 8
dim_feedforward 前馈网络隐藏层大小 2048
dropout dropout 比例 0.1
activation 前馈网络激活函数 relu / gelu
layer_norm_eps LayerNorm epsilon 1e-5
batch_first 输入是否是 (batch, seq, feature) False
norm_first 是否使用 pre-norm False
bias 是否在线性层中使用 bias True
device / dtype 张量设备和类型 cpu / float32

2. 多头自注意力

python 复制代码
self.self_attn = MultiheadAttention(
    d_model,
    nhead,
    dropout=dropout,
    bias=bias,
    batch_first=batch_first,
    **factory_kwargs,
)
  • 输入:Q,K,V 均为 (seq_len, batch_size, d_model)(batch_size, seq_len, d_model)

  • 处理逻辑:

    1. 将输入映射到 query/key/value:

      • 每个头的维度 head_dim = d_model / nhead
      • 每个头独立计算注意力权重。
    2. 计算注意力矩阵:softmax(QK^T / sqrt(head_dim))

    3. 将加权后的 V 进行 concat 并线性投影。

  • 输出与输入形状相同。

为什么要多头?

  • 单头注意力只能捕捉一个子空间的依赖关系,多头可以捕捉多种不同关系。
  • 例如,语言模型可能同时关注语法结构和词义相似性。

3. 前馈网络

python 复制代码
self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
  • 前馈网络 FFN 对每个 token 单独处理,不依赖其他 token。
  • 计算公式:

F F N ( x ) = W 2 ⋅ Dropout ( Activation ( W 1 ⋅ x + b 1 ) ) + b 2 FFN(x) = W_2 \cdot \text{Dropout}(\text{Activation}(W_1 \cdot x + b_1)) + b_2 FFN(x)=W2⋅Dropout(Activation(W1⋅x+b1))+b2

  • 维度变化:

    • 输入:(seq_len, batch, d_model)
    • linear1 → (seq_len, batch, dim_feedforward)
    • linear2 → (seq_len, batch, d_model),与输入维度一致

4. LayerNorm 和 Dropout

python 复制代码
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
  • LayerNorm:沿最后一维归一化,使输出均值为 0、方差为 1,提高训练稳定性。
  • Dropout:随机丢弃神经元,防止过拟合。

三、前向传播 forward 分析

python 复制代码
def forward(
    self,
    src: Tensor,
    src_mask: Optional[Tensor] = None,
    src_key_padding_mask: Optional[Tensor] = None,
    is_causal: bool = False,
) -> Tensor:

1. 输入

  • src:序列张量,shape [seq_len, batch, d_model]
  • src_mask:控制 token 之间注意力可见性,例如防止看到未来 token。
  • src_key_padding_mask:屏蔽 padding。
  • is_causal:是否应用因果 mask,用于解码器自回归。

2. 掩码规范化

python 复制代码
src_key_padding_mask = F._canonical_mask(...)
src_mask = F._canonical_mask(...)
  • PyTorch 内部函数,将不同类型掩码统一成标准 tensor。
  • 保证后续 self_attn 可以正确处理。

3. 自注意力 + 前馈的执行顺序

python 复制代码
x = src
if self.norm_first:
    x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
    x = x + self._ff_block(self.norm2(x))
else:
    x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
    x = self.norm2(x + self._ff_block(x))
  • norm_first=False 是默认设置,也叫 Post-Norm。

  • 执行步骤:

    1. 自注意力块 _sa_block

      • 输入 x → Q=K=V
      • 输出经过 dropout1
      • 残差连接:x + SA(x)
    2. LayerNorm1

      • 对残差后的输出进行归一化
    3. 前馈块 _ff_block

      • linear1 → activation → dropout → linear2 → dropout2
      • 残差连接:x + FFN(x)
    4. LayerNorm2

      • 对前馈输出归一化

4. 自注意力块 _sa_block

python 复制代码
def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
    x = self.self_attn(
        x, x, x,
        attn_mask=attn_mask,
        key_padding_mask=key_padding_mask,
        need_weights=False,
        is_causal=is_causal,
    )[0]
    return self.dropout1(x)
  • 调用 MultiheadAttention

    • Q=K=V=x
    • attn_mask 控制哪些 token 可见
    • key_padding_mask 屏蔽 padding token
  • 输出经过 Dropout

张量维度流

复制代码
Input x: (seq_len, batch, d_model)
Linear Q/K/V → (seq_len, batch, d_model)
Split heads → (seq_len, batch, nhead, head_dim)
Compute attention → (seq_len, batch, nhead, head_dim)
Concat heads → (seq_len, batch, d_model)
Dropout → (seq_len, batch, d_model)

5. 前馈块 _ff_block

python 复制代码
def _ff_block(self, x: Tensor) -> Tensor:
    x = self.linear2(self.dropout(self.activation(self.linear1(x))))
    return self.dropout2(x)
  • 逐 token 独立计算
  • 激活函数可以是 ReLU / GELU
  • 输出维度与输入一致

四、输入输出总结

输入

  • [seq_len, batch, d_model][batch, seq, d_model]

  • 可选掩码:

    • src_mask:序列之间的可见性
    • src_key_padding_mask:padding token 屏蔽

内部流程

  1. 自注意力:

    • Q/K/V=输入
    • 计算 attention
    • Dropout
    • 残差连接
  2. LayerNorm

  3. 前馈网络:

    • linear → activation → dropout → linear → dropout
    • 残差连接
  4. LayerNorm

输出

  • [seq_len, batch, d_model]
  • 包含自注意力和前馈后的特征表示
  • 可直接输入下一个 Transformer 编码器层

五、关键设计理念

  1. 残差连接 + LayerNorm

    • 保证梯度稳定
    • 支持深层网络训练
  2. 多头自注意力

    • 并行捕捉多种关系
    • 提升表达能力
  3. 前馈网络

    • 提供非线性映射
    • 逐位置增强特征
  4. Dropout + Mask

    • Dropout 防止过拟合
    • Mask 控制注意力范围
相关推荐
山土成旧客2 小时前
【Python学习打卡-Day44】站在巨人的肩膀上:玩转PyTorch预训练模型与迁移学习
pytorch·python·学习
星河天欲瞩2 小时前
【深度学习Day1】环境配置(CUDA、PyTorch)
人工智能·pytorch·python·深度学习·学习·机器学习·conda
前沿AI2 小时前
东风奕派×中关村科金 | 大模型外呼重塑汽车营销新链路,实现高效线索转化
大数据·人工智能
2501_941837262 小时前
莲花目标检测任务改进RetinaNet_R50-Caffe_FPN_MS-2x_COCO模型训练与性能优化
人工智能·目标检测·caffe
老周聊架构2 小时前
解构Claude Skills:可插拔的AI专业知识模块设计
人工智能·skills
Irene.ll2 小时前
DAY32 官方文档的阅读
python
Pyeako2 小时前
Opencv计算机视觉--轮廓检测&模板匹配
人工智能·python·opencv·计算机视觉·边缘检测·轮廓检测·模板匹配
清铎2 小时前
项目_一款基于RAG的金融保险业务多模态问答assistant
人工智能
DBBH2 小时前
DBBH的AI学习笔记
人工智能·笔记·学习