Transformer架构
Transformer作用
Transformer就像一个超级智能的输入法:它一边看已经写好的所有文字(自注意力),输出一个包含每一个候选词预测可能性的矩阵,调概率最高的最为下一个预测结果
Transformer 与传统架构(RNN/LSTM)对比
RNN 像"接力棒",依次传递上一个词的语义;LSTM 像"智能筛选器",用门控网络决定记什么、忘什么。
而 Transformer 则像"全体会议室"------每个人都能同时看到所有人,直接沟通。
下面从 核心机制、并行能力、长距离依赖、训练难度 四个维度详细对比。
1. 核心机制对比
| 架构 | 核心思想 | 通俗类比 |
|---|---|---|
| RNN | 隐藏状态 h_t = f(h_{t-1}, x_t),逐个词传递信息 |
接力跑:每一棒只从上一棒接过信息,再往后传 |
| LSTM | 引入遗忘门、输入门、输出门,控制信息流动 | 有记忆功能的传送带:可以决定保留/丢弃/输出哪些信息 |
| Transformer | 自注意力:每个词直接与序列中所有词交互 | 圆桌会议:每个人都能同时看到所有人,直接对话 |
2. 并行计算能力(速度关键)
| 架构 | 训练时 | 原因 |
|---|---|---|
| RNN / LSTM | 串行,无法并行 | 必须按时间步 t=1,2,...,T 依次计算,第 t 步依赖 t-1 步的结果 |
| Transformer | 完全并行 | 自注意力矩阵一次性计算所有位置对的关系,没有时间依赖 |
影响:
- 对于长序列(如 1000 个词),RNN/LSTM 需要 1000 步,Transformer 只需 1 步(矩阵乘法)。
- 这也是为什么 GPT、BERT 等大模型都用 Transformer------能在 GPU 上高效并行训练。
3. 长距离依赖捕捉能力
| 架构 | 能力 | 问题 |
|---|---|---|
| RNN | 差 | 梯度消失/爆炸,很难记住 20+ 步前的信息 |
| LSTM | 较好 | 通过门控机制缓解梯度消失,但仍受限于串行路径长度(有效记忆约 100 步左右) |
| Transformer | 优秀 | 自注意力直接计算任意两个位置的关联,路径长度 O(1),不受距离影响 |
例子 :
句子"我在北京上学......(隔了 50 个词)......我喜欢那座城市。"
- RNN/LSTM:50 步后可能忘了"北京"。
- Transformer:最后一个词可以直接关注到第一个词"北京",没有信息衰减。
4. 训练难度与稳定性
| 架构 | 梯度问题 | 解决方案 |
|---|---|---|
| RNN | 严重梯度消失/爆炸 | 需要梯度裁剪、特殊初始化,效果仍有限 |
| LSTM | 缓解但仍然存在 | 门控机制有效,但深层 LSTM 仍难训练 |
| Transformer | 无时间步上的梯度消失 | 残差连接 + 层归一化 + 位置编码,支持百层以上深度 |
注意 :Transformer 并非完美,它的 O(n²) 计算复杂度在处理超长文本(如 10 万词)时内存爆炸。这催生了稀疏注意力、Longformer 等变种。
5. 关键差异总结表
| 维度 | RNN | LSTM | Transformer |
|---|---|---|---|
| 信息传递方式 | 顺序传递 | 顺序传递 + 门控筛选 | 全局直接交互 |
| 并行训练 | ❌ 串行 | ❌ 串行 | ✅ 完全并行 |
| 长距离依赖 | 差(≤10步) | 较好(≤100步) | 优秀(无距离衰减) |
| 训练深度 | 浅(2-3层) | 中等(4-8层) | 深(数十到数百层) |
| 时间复杂度 | O(T·d²) | O(T·d²) | O(T²·d)(T长时劣势) |
| 典型应用 | 早期语言模型 | 机器翻译、语音识别 | GPT、BERT、几乎所有现代 LLM |
6. 从"记忆"角度看本质差异
- RNN:只记得上一句,像金鱼。
- LSTM:有选择地记住重要内容,像普通人做笔记。
- Transformer:拥有完整的会议记录,随时可以翻看任何一页,像超级学者。
面试常见追问
Q:Transformer 这么好,RNN/LSTM 还有用吗?
A:有。当序列极长(如 DNA 序列、时间序列预测)且资源受限时,线性复杂度的 RNN 变体(如 RWKV、Mamba)仍有优势。另外,推理时 RNN 的常数内存占用比 Transformer 的 KV 缓存更小。
Q:Transformer 如何解决位置信息缺失?
A:通过加入位置编码(正弦/余弦或可学习的),让自注意力知道"距离"概念。
Q:能否说 Transformer 完全取代了 RNN?
A:在 NLP 领域基本取代,但在某些流式数据(实时语音、传感器数据)中,RNN/LSTM 的低延迟仍有价值。
一句话总结:
RNN 是排队传话,LSTM 是聪明地传话,Transformer 是所有人同时开圆桌会------效率最高,看得最远。
Transformer 结构模块详解
Transformer 由 编码器(Encoder) 和 解码器(Decoder) 组成,但现代大语言模型(如 GPT)通常只使用解码器部分。下面以 解码器结构 为例,拆解每个模块。
1. 输入表示:Token Embedding + 位置编码
目的:把文字转换成模型能理解的数字向量,并注入顺序信息。
- Token Embedding:每个词(或子词)被映射为一个固定维度的向量(例如 512 维)。这就像查一个"词向量表"。
- 位置编码 :因为 Transformer 没有天然的先后顺序(不像 RNN 一个接一个),需要额外告诉模型每个词的位置。常用 正弦/余弦函数 生成位置向量,与 Token Embedding 相加。
公式 :
输入 = TokenEmbedding + PositionalEncoding
形状 :(序列长度, 隐藏维度)
2. 前向传播(总体流程)
一个典型的解码器层(Decoder Layer)的前向传播顺序为:
输入 → 掩码多头自注意力 → 残差连接 + 层归一化 → 前馈网络(FFN) → 残差连接 + 层归一化 → 输出到下一层
(如果用于机器翻译,解码器中间还会有一个"编码器-解码器注意力",但 GPT 类模型没有)
3. 残差注意力 + "加 embedding"(实际是 Add & Norm)
"残差注意力加 embedding" 通常指的是 多头注意力 + 残差连接 + 层归一化。
- 多头自注意力:让每个词看到它之前的所有词(通过掩码防止看到未来),计算相互关系。
- 残差连接 :将输入(即未经过注意力的原始 x)与注意力输出相加:
x = x + AttentionOutput。这能缓解梯度消失,让模型更容易训练。 - 层归一化(LayerNorm):对残差连接的结果做归一化,稳定训练。
代码
python
attn_out = multi_head_self_attention(x, mask)
x = layer_norm(x + attn_out) # 这就是 "Add & Norm"
注意:这里并不是"加 embedding",而是加上本层的输入(残差)。原始输入 embedding 只在最开始加入一次。
4. FFN 残差连接
前馈网络(FFN) :一个两层的全连接网络,通常结构为 线性 -> ReLU(或GELU) -> 线性。它独立地处理每个位置的向量,增强模型的非线性表达能力。
- 第一层将维度从
d_model扩展到d_ff(通常是 4 倍,如 512 → 2048)。 - 第二层再从
d_ff映射回d_model。
同样加上残差连接和层归一化:
python
ffn_out = ffn(x) # 前馈网络
x = layer_norm(x + ffn_out) # 第二次 Add & Norm
这样,一个解码器层就完成了。
5. 输出层
经过 N 层(例如 12 层、24 层)堆叠后,最后一步:
- 线性映射(Linear) :将最后一层的输出向量(维度
d_model)映射到 词表大小(vocab_size),得到每个词的"得分"(logits)。 - Softmax:将得分转换为概率分布,总和为 1。
- 预测:取概率最高的词作为下一个输出(或者按采样策略)。
形状变化 :
(序列长度, d_model)→(序列长度, vocab_size)→ 概率分布。
6. 多层堆叠的效果
为什么需要堆叠很多层(例如 6 层、12 层、96 层)?
- 浅层(靠近输入):捕捉局部语法特征,比如词性、短语组合。
- 中层:学习句法结构,比如主谓宾关系、从句边界。
- 深层:理解全局语义、长距离依赖、推理和知识调用。
堆叠带来的能力:
- 感受野扩大:每一层都会重新计算自注意力,高层能看到经过下层抽象后的全局信息。
- 非线性增强:每层都有 FFN,层数越多,模型的表达能力越强。
- 残差连接保证可训练:即使 100 层,梯度也能直接回传。
例子:GPT-3 有 96 层,每层都能在前一层的基础上进一步提炼信息,最终能完成复杂推理。
整体结构图(文本示意)
输入序列: "我 爱 你"
│
▼
[Token Embedding] + [位置编码]
│
▼
┌─────────────────────────────┐
│ Decoder Layer 1 │
│ ┌─────────────────────┐ │
│ │ 掩码多头自注意力 │ │
│ └─────────┬───────────┘ │
│ ▼ │
│ Add & Norm (残差+LN) │
│ ▼ │
│ ┌─────────────────────┐ │
│ │ FFN (前馈网络) │ │
│ └─────────┬───────────┘ │
│ ▼ │
│ Add & Norm (残差+LN) │
└─────────────────────────────┘
│
▼ (重复 N 次,上一层的输出作为下一层的输入)
┌─────────────────────────────┐
│ Decoder Layer N │
└─────────────────────────────┘
│
▼
线性层 (Linear)
│
▼
Softmax → 概率分布 → 预测下一个词
常见面试追问
Q:为什么每个子层后都要加残差和层归一化?
A:残差让梯度直接流通,避免深层梯度消失;层归一化稳定数值范围,加速收敛。
Q:位置编码为什么是加而不是拼接?
A:加法不改变向量维度,效率更高;实验证明效果与拼接相近,且节省参数。
Q:多层堆叠会不会导致过拟合?
A:会,所以需要 dropout、权重衰减等正则化,以及大规模预训练数据。