导读
Transformer 之所以重要,不只是"更强的模型",而是它用注意力让序列计算可以并行化,并且把"依赖关系建模"变成了可解释、可控制的操作。本文从公式与结构讲清楚核心机制,再落到 PyTorch torch.nn.Transformer 的输入形状与 mask 语义,帮助你把概念转成可运行的工程实现。
一、从序列建模瓶颈说起
传统 RNN/GRU/LSTM 逐步处理序列,时间复杂度虽然线性,但计算无法并行,而且长距离依赖容易被梯度稀释。Transformer 的关键突破是:用注意力让每个位置都能直接"看见"全序列,从而减少对长链条的依赖,并把计算并行化。
二、核心公式:缩放点积注意力
论文给出的标准公式是:
scss
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
这里的缩放 sqrt(d_k) 是为了避免点积值过大导致 softmax 饱和,从而稳定训练。直觉上可以这样理解:Q 与 K 计算"相关性得分",softmax 把得分变成权重,最后用权重对 V 做加权求和,得到"当前位置应该看哪些信息"的聚合结果。
三、多头注意力:并行子空间的组合
多头注意力把原本单一的注意力分成多个头:每个头在不同的线性投影空间里计算注意力,最后再拼接起来。这样模型能同时捕捉不同类型的关系,例如位置邻近、语义相似或结构对齐。论文指出多头机制是提升表达能力的关键步骤,而不是简单的并行加速。
四、编码器/解码器的信息流
经典 Transformer 是 encoder-decoder 结构:
- 编码器层包含自注意力与前馈网络,负责把输入序列编码成上下文表示。
- 解码器层包含"带因果 mask 的自注意力"和"对编码器输出的交叉注意力",确保生成时只能看到已生成的历史,同时能读取源序列信息。
这种分工让模型既能做序列到序列的映射,也能在解码阶段保持自回归一致性。
五、位置编码:没有顺序就没有语义
注意力本身不包含顺序,因此 Transformer 必须显式加入位置编码。论文中的正弦位置编码提供了一个无需训练的方案,使模型能通过位置嵌入感知顺序与相对位移。无论用正弦还是可训练位置嵌入,本质目的都是让"相同 token 在不同位置"得到区分。
论文给出的正弦位置编码公式为:
scss
PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
这个设计的直觉是:不同维度用不同频率编码位置,使模型能够通过线性组合表达相对位移;同时也更容易外推到比训练时更长的序列位置。
六、落到工程:torch.nn.Transformer 的定位
PyTorch 官方文档明确指出:torch.nn.Transformer 是对原始架构的参考实现,功能相对有限,更多高效结构需基于底层模块自行组合(见官方文档说明)。因此它非常适合教学、原型或基线模型,但不一定是生产最优解。
参数方面,常用配置包括:
d_model:输入/输出的特征维度;nhead:多头注意力头数;num_encoder_layers/num_decoder_layers:层数深度;dim_feedforward:前馈网络维度;dropout、activation:训练稳定性与非线性选择;batch_first:是否使用(N, S, E)的张量布局;norm_first:是否在子层前应用 LayerNorm。
这些参数直接决定模型容量、训练稳定性与性能瓶颈。
七、输入形状与 batch_first
官方文档明确给出形状规则:
batch_first=False(默认)时:src是(S, N, E),tgt是(T, N, E);batch_first=True时:src是(N, S, E),tgt是(N, T, E);src_mask形状为(S, S),tgt_mask为(T, T);src_key_padding_mask为(N, S),tgt_key_padding_mask为(N, T)。
这里的 S/T 是序列长度,N 是 batch 大小,E 是特征维度。形状错一维,模型要么报错,要么更隐蔽地学错。
八、Mask 语义:布尔与浮点的差别
PyTorch 文档特别说明:
- 当 mask 是布尔张量时,
True表示"不可参与注意力"; - 当 mask 是浮点张量时,它会被加到注意力权重上;
src_is_causal/tgt_is_causal只是"因果 mask 提示",文档警告错误提示会导致不正确的执行结果。
这是一类非常常见的坑:不同库对 mask 的 True/False 语义不同,不能想当然。
九、实践代码:完整输入 + 因果 mask + padding mask
下面示例展示 batch_first=True 的配置方式,并使用 generate_square_subsequent_mask 生成因果 mask:
python
import torch
import torch.nn as nn
model = nn.Transformer(
d_model=256,
nhead=8,
num_encoder_layers=4,
num_decoder_layers=4,
batch_first=True,
)
# (N, S, E) 与 (N, T, E)
src = torch.rand(32, 10, 256)
tgt = torch.rand(32, 12, 256)
# (T, T) 因果 mask:被遮挡位置为 -inf
causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1))
# padding mask:True 表示忽略
src_key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)
y = model(
src,
tgt,
tgt_mask=causal_mask,
src_key_padding_mask=src_key_padding_mask,
)
这段代码与官方文档给出的 shape 规则一致,也是排查输入错误时最常用的基线。
十、从注意力矩阵理解 mask 的作用
理解 mask 的一个好方法,是把注意力权重当成一个"可视化的矩阵"。对自注意力来说,权重矩阵的形状通常是 (T, T) 或 (S, S);对交叉注意力,形状是 (T, S)。官方文档也给出了 src_mask、tgt_mask、memory_mask 的形状定义,这能帮助你判断一个 mask 应该对应"谁看谁"。
在实现上,mask 会在 softmax 之前影响注意力权重:布尔 mask 的 True 表示禁止关注,浮点 mask 会被直接加到权重上(通常用 -inf 做遮挡)。因此你可以把 mask 理解为"把某些边从注意力图中剪掉"。如果 mask 方向弄反,模型就会在不该看的位置上"偷看答案"。
十一、训练与推理建议(面向工程)
- 训练阶段 :确保
tgt是右移后的输入序列,并用因果 mask 避免泄漏未来信息。 - 推理阶段 :逐步生成时要持续更新
tgt,并复用generate_square_subsequent_mask的形状规则。 - 日志记录 :建议记录
batch_first、mask 类型、shape 与 dtype,便于定位训练不稳定或输出异常。
十二、性能与边界条件
注意力的计算复杂度是 O(n^2),在长序列场景会造成明显的显存和延迟压力。官方文档也指出 torch.nn.Transformer 是参考实现,功能有限,强调更高效的实现应基于底层模块或 PyTorch 生态库。
这意味着:当你遇到长上下文或实时推理瓶颈时,需要考虑稀疏注意力、分块注意力或 KV cache 等策略,而不是"盲目加层"。
十三、常见错误清单
- 把
batch_first当成默认 :默认是False,要显式设置并匹配输入形状。 - mask 语义搞反:布尔 mask 的 True 表示"不可关注"。
- 忽视
*_is_causal的提示语义:错误提示会导致不正确的执行结果。 - 只记得公式,不记得形状:Transformer 的大部分 bug 出在 shape 与 mask 组合上。