Transformer架构梳理

Transformer架构

Transformer作用

Transformer就像一个超级智能的输入法:它一边看已经写好的所有文字(自注意力),输出一个包含每一个候选词预测可能性的矩阵,调概率最高的最为下一个预测结果

Transformer 与传统架构(RNN/LSTM)对比

RNN 像"接力棒",依次传递上一个词的语义;LSTM 像"智能筛选器",用门控网络决定记什么、忘什么。

而 Transformer 则像"全体会议室"------每个人都能同时看到所有人,直接沟通。

下面从 核心机制、并行能力、长距离依赖、训练难度 四个维度详细对比。


1. 核心机制对比

架构 核心思想 通俗类比
RNN 隐藏状态 h_t = f(h_{t-1}, x_t),逐个词传递信息 接力跑:每一棒只从上一棒接过信息,再往后传
LSTM 引入遗忘门、输入门、输出门,控制信息流动 有记忆功能的传送带:可以决定保留/丢弃/输出哪些信息
Transformer 自注意力:每个词直接与序列中所有词交互 圆桌会议:每个人都能同时看到所有人,直接对话

2. 并行计算能力(速度关键)

架构 训练时 原因
RNN / LSTM 串行,无法并行 必须按时间步 t=1,2,...,T 依次计算,第 t 步依赖 t-1 步的结果
Transformer 完全并行 自注意力矩阵一次性计算所有位置对的关系,没有时间依赖

影响

  • 对于长序列(如 1000 个词),RNN/LSTM 需要 1000 步,Transformer 只需 1 步(矩阵乘法)。
  • 这也是为什么 GPT、BERT 等大模型都用 Transformer------能在 GPU 上高效并行训练。

3. 长距离依赖捕捉能力

架构 能力 问题
RNN 梯度消失/爆炸,很难记住 20+ 步前的信息
LSTM 较好 通过门控机制缓解梯度消失,但仍受限于串行路径长度(有效记忆约 100 步左右)
Transformer 优秀 自注意力直接计算任意两个位置的关联,路径长度 O(1),不受距离影响

例子

句子"我在北京上学......(隔了 50 个词)......我喜欢那座城市。"

  • RNN/LSTM:50 步后可能忘了"北京"。
  • Transformer:最后一个词可以直接关注到第一个词"北京",没有信息衰减。

4. 训练难度与稳定性

架构 梯度问题 解决方案
RNN 严重梯度消失/爆炸 需要梯度裁剪、特殊初始化,效果仍有限
LSTM 缓解但仍然存在 门控机制有效,但深层 LSTM 仍难训练
Transformer 无时间步上的梯度消失 残差连接 + 层归一化 + 位置编码,支持百层以上深度

注意 :Transformer 并非完美,它的 O(n²) 计算复杂度在处理超长文本(如 10 万词)时内存爆炸。这催生了稀疏注意力、Longformer 等变种。


5. 关键差异总结表

维度 RNN LSTM Transformer
信息传递方式 顺序传递 顺序传递 + 门控筛选 全局直接交互
并行训练 ❌ 串行 ❌ 串行 ✅ 完全并行
长距离依赖 差(≤10步) 较好(≤100步) 优秀(无距离衰减)
训练深度 浅(2-3层) 中等(4-8层) 深(数十到数百层)
时间复杂度 O(T·d²) O(T·d²) O(T²·d)(T长时劣势)
典型应用 早期语言模型 机器翻译、语音识别 GPT、BERT、几乎所有现代 LLM

6. 从"记忆"角度看本质差异

  • RNN:只记得上一句,像金鱼。
  • LSTM:有选择地记住重要内容,像普通人做笔记。
  • Transformer:拥有完整的会议记录,随时可以翻看任何一页,像超级学者。

面试常见追问

Q:Transformer 这么好,RNN/LSTM 还有用吗?

A:有。当序列极长(如 DNA 序列、时间序列预测)且资源受限时,线性复杂度的 RNN 变体(如 RWKV、Mamba)仍有优势。另外,推理时 RNN 的常数内存占用比 Transformer 的 KV 缓存更小。

Q:Transformer 如何解决位置信息缺失?

A:通过加入位置编码(正弦/余弦或可学习的),让自注意力知道"距离"概念。

Q:能否说 Transformer 完全取代了 RNN?

A:在 NLP 领域基本取代,但在某些流式数据(实时语音、传感器数据)中,RNN/LSTM 的低延迟仍有价值。


一句话总结:

RNN 是排队传话,LSTM 是聪明地传话,Transformer 是所有人同时开圆桌会------效率最高,看得最远。


Transformer 结构模块详解

Transformer 由 编码器(Encoder)解码器(Decoder) 组成,但现代大语言模型(如 GPT)通常只使用解码器部分。下面以 解码器结构 为例,拆解每个模块。


1. 输入表示:Token Embedding + 位置编码

目的:把文字转换成模型能理解的数字向量,并注入顺序信息。

  • Token Embedding:每个词(或子词)被映射为一个固定维度的向量(例如 512 维)。这就像查一个"词向量表"。
  • 位置编码 :因为 Transformer 没有天然的先后顺序(不像 RNN 一个接一个),需要额外告诉模型每个词的位置。常用 正弦/余弦函数 生成位置向量,与 Token Embedding 相加

公式输入 = TokenEmbedding + PositionalEncoding

形状(序列长度, 隐藏维度)


2. 前向传播(总体流程)

一个典型的解码器层(Decoder Layer)的前向传播顺序为:

复制代码
输入 → 掩码多头自注意力 → 残差连接 + 层归一化 → 前馈网络(FFN) → 残差连接 + 层归一化 → 输出到下一层

(如果用于机器翻译,解码器中间还会有一个"编码器-解码器注意力",但 GPT 类模型没有)


3. 残差注意力 + "加 embedding"(实际是 Add & Norm)

"残差注意力加 embedding" 通常指的是 多头注意力 + 残差连接 + 层归一化

  • 多头自注意力:让每个词看到它之前的所有词(通过掩码防止看到未来),计算相互关系。
  • 残差连接 :将输入(即未经过注意力的原始 x)与注意力输出相加:x = x + AttentionOutput。这能缓解梯度消失,让模型更容易训练。
  • 层归一化(LayerNorm):对残差连接的结果做归一化,稳定训练。

代码

python 复制代码
attn_out = multi_head_self_attention(x, mask)
x = layer_norm(x + attn_out)   # 这就是 "Add & Norm"

注意:这里并不是"加 embedding",而是加上本层的输入(残差)。原始输入 embedding 只在最开始加入一次。


4. FFN 残差连接

前馈网络(FFN) :一个两层的全连接网络,通常结构为 线性 -> ReLU(或GELU) -> 线性。它独立地处理每个位置的向量,增强模型的非线性表达能力。

  • 第一层将维度从 d_model 扩展到 d_ff(通常是 4 倍,如 512 → 2048)。
  • 第二层再从 d_ff 映射回 d_model

同样加上残差连接和层归一化

python 复制代码
ffn_out = ffn(x)               # 前馈网络
x = layer_norm(x + ffn_out)    # 第二次 Add & Norm

这样,一个解码器层就完成了。


5. 输出层

经过 N 层(例如 12 层、24 层)堆叠后,最后一步:

  • 线性映射(Linear) :将最后一层的输出向量(维度 d_model)映射到 词表大小(vocab_size),得到每个词的"得分"(logits)。
  • Softmax:将得分转换为概率分布,总和为 1。
  • 预测:取概率最高的词作为下一个输出(或者按采样策略)。

形状变化(序列长度, d_model)(序列长度, vocab_size) → 概率分布。


6. 多层堆叠的效果

为什么需要堆叠很多层(例如 6 层、12 层、96 层)?

  • 浅层(靠近输入):捕捉局部语法特征,比如词性、短语组合。
  • 中层:学习句法结构,比如主谓宾关系、从句边界。
  • 深层:理解全局语义、长距离依赖、推理和知识调用。

堆叠带来的能力

  1. 感受野扩大:每一层都会重新计算自注意力,高层能看到经过下层抽象后的全局信息。
  2. 非线性增强:每层都有 FFN,层数越多,模型的表达能力越强。
  3. 残差连接保证可训练:即使 100 层,梯度也能直接回传。

例子:GPT-3 有 96 层,每层都能在前一层的基础上进一步提炼信息,最终能完成复杂推理。


整体结构图(文本示意)

复制代码
输入序列: "我 爱 你"
      │
      ▼
[Token Embedding] + [位置编码]
      │
      ▼
┌─────────────────────────────┐
│    Decoder Layer 1          │
│  ┌─────────────────────┐    │
│  │ 掩码多头自注意力     │    │
│  └─────────┬───────────┘    │
│            ▼                 │
│      Add & Norm (残差+LN)    │
│            ▼                 │
│  ┌─────────────────────┐    │
│  │   FFN (前馈网络)     │    │
│  └─────────┬───────────┘    │
│            ▼                 │
│      Add & Norm (残差+LN)    │
└─────────────────────────────┘
      │
      ▼ (重复 N 次,上一层的输出作为下一层的输入)
┌─────────────────────────────┐
│    Decoder Layer N          │
└─────────────────────────────┘
      │
      ▼
   线性层 (Linear)
      │
      ▼
   Softmax → 概率分布 → 预测下一个词

常见面试追问

Q:为什么每个子层后都要加残差和层归一化?

A:残差让梯度直接流通,避免深层梯度消失;层归一化稳定数值范围,加速收敛。

Q:位置编码为什么是加而不是拼接?

A:加法不改变向量维度,效率更高;实验证明效果与拼接相近,且节省参数。

Q:多层堆叠会不会导致过拟合?

A:会,所以需要 dropout、权重衰减等正则化,以及大规模预训练数据。

相关推荐
独隅2 小时前
PyTorch转TFLite动态形状处理技巧
人工智能·pytorch·python
猫头虎2 小时前
一个插件,国内直接用Claude Opus 4.7
人工智能·langchain·开源·prompt·aigc·ai编程·agi
台XX2 小时前
Ollama+其他模型仓库
人工智能
思绪无限2 小时前
YOLOv5至YOLOv12升级:条形码二维码检测系统的设计与实现(完整代码+界面+数据集项目)
深度学习·yolo·目标检测·条形码二维码检测·yolov12·yolo全家桶
KC2702 小时前
老板主动给我涨薪!揭秘制造业数字化转型省300万的3招
人工智能·aigc
恒哥的爸爸2 小时前
GPT原理笔记
人工智能·笔记·gpt
咚咚王者3 小时前
人工智能之知识蒸馏 第五章 蒸馏优化技术:精度损失补偿方法
人工智能
kishu_iOS&AI3 小时前
Pytorch —— 自动微分模块
人工智能·pytorch·python·深度学习·算法·线性回归
星浩AI3 小时前
手把手带你在 Windows 安装 Hermess Agent,并接入飞书 [喂饭级教程含踩坑经验]
人工智能·后端·agent