Transformer 核心机制拆解:自注意力、多头注意力、位置编码,一篇讲透

先搞清楚一个问题:为什么需要 Transformer?

在 Transformer 出现之前,NLP 领域是 RNN 和 LSTM 的天下。它们有个致命缺陷------只能顺序处理

想象一下你在读一句话:

"我 昨天 去 了 北京 ,因为 那里 的 烤鸭 特别 好吃"

RNN 读到"烤鸭"的时候,已经"走过"了好几个词,对"北京"的记忆早就衰减了。虽然 LSTM 用门控机制缓解了这个问题,但本质上还是"一个一个排队看"。

Transformer 的思路完全不同:让每个词同时看到句子里的所有词,想关注谁就关注谁,一步到位。

这就是 Self-Attention(自注意力机制) 的核心思想。


一、Self-Attention:让每个词都"眼观六路"

1.1 核心直觉

自注意力要解决的问题是:在一个序列中,每个元素应该"关注"哪些其他元素?

用人话来说就是:当模型处理"它"这个词的时候,它需要知道"它"指的是"猫"还是"垫子"。

"The cat sat on the mat because it was tired."

人类一眼就能看出 "it" 指的是 "cat"。自注意力机制就是让模型学会这种能力。

1.2 Q、K、V 三剑客

自注意力的核心是三个矩阵:Query(查询)Key(键)Value(值)

这个概念其实很好理解,拿搜索引擎来类比:

  • Query:你在搜索框输入的关键词 → "我想找什么?"
  • Key:每篇文章的标题/标签 → "我是什么?"
  • Value:文章的实际内容 → "这是我的内容"

匹配过程:用 Query 去和所有 Key 计算相似度,相似度高的 Value 就多关注一些。

less 复制代码
输入序列: [词1, 词2, 词3, ...]
    ↓ (乘以三个权重矩阵)
  Q矩阵    K矩阵    V矩阵
    ↓        ↓        ↓
  Q1,Q2..  K1,K2..  V1,V2..
    ↓
  Q·K^T → 注意力分数 → softmax → 加权求和Value → 输出

1.3 数学公式(别怕,很简单)

python 复制代码
# Scaled Dot-Product Attention
# 公式: Attention(Q, K, V) = softmax(QK^T / √d_k) · V

import torch
import torch.nn.functional as F
import math

def self_attention(Q, K, V):
    """
    Q: (seq_len, d_k)  查询矩阵
    K: (seq_len, d_k)  键矩阵
    V: (seq_len, d_v)  值矩阵
    """
    d_k = Q.size(-1)

    # 第一步:Q 和 K 做点积,得到注意力分数
    scores = torch.matmul(Q, K.transpose(-2, -1))

    # 第二步:除以 √d_k(缩放因子),防止分数过大导致 softmax 梯度消失
    scores = scores / math.sqrt(d_k)

    # 第三步:softmax 归一化,得到注意力权重(和为1)
    attention_weights = F.softmax(scores, dim=-1)

    # 第四步:用注意力权重对 V 加权求和
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

为什么除以 √d_k?

当 d_k 比较大的时候,Q·K^T 的结果会很大,导致 softmax 进入饱和区(梯度接近0)。除以 √d_k 就是把数值拉回到合理范围。

1.4 一个直观的例子

假设我们处理这句话:["我", "爱", "AI"]

makefile 复制代码
注意力权重矩阵(softmax后):
          我      爱      AI
    我   [0.7,   0.2,   0.1]   ← "我"主要关注自己
    爱   [0.3,   0.4,   0.3]   ← "爱"同时关注"我"和"AI"
    AI   [0.1,   0.3,   0.6]   ← "AI"主要关注自己和"爱"

看到没?模型自动学到了"爱"连接了"我"和"AI"这个语义关系。


二、Multi-Head Attention:多角度观察,信息更全面

2.1 为什么需要多头?

单头注意力只有一个 Q、K、V,意味着每个词只能从一个角度去关注其他词。

但语言是复杂的!一个词可能同时跟多个词有关系,而且关系的类型不同:

"The bank of the river" → bank 关注 river(位置关系)

"I went to the bank" → bank 关注 went、money(语义关系)

多头注意力的思路:用多组 Q、K、V,让模型从不同的"视角"同时关注信息

2.2 结构图解

css 复制代码
输入 X
  │
  ├──→ Linear → Q₁, K₁, V₁ → Scaled Attention → Head₁ ──┐
  ├──→ Linear → Q₂, K₂, V₂ → Scaled Attention → Head₂ ──┤
  ├──→ Linear → Q₃, K₃, V₃ → Scaled Attention → Head₃ ──┼→ Concat → Linear → 输出
  ├──→ ...                                              │
  └──→ Linear → Q₈, K₈, V₈ → Scaled Attention → Head₈ ──┘

原论文用了 8 个头(h=8),每个头的维度是 d_model/h。比如 d_model=512,每个头的维度就是 64。

2.3 代码实现

python 复制代码
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # 每个头的维度 = 64

        # Q、K、V 的线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # 输出的线性变换层
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # 1. 线性变换得到 Q、K、V
        Q = self.W_q(x)  # (batch, seq_len, d_model)
        K = self.W_k(x)
        V = self.W_v(x)

        # 2. 拆分成多个头
        # reshape 成 (batch, n_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # 3. 计算缩放点积注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # 4. 拼接所有头
        # (batch, n_heads, seq_len, d_k) → (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )

        # 5. 最终线性变换
        output = self.W_o(attn_output)

        return output

踩坑提醒 :实现多头注意力的时候,transposeview 的维度顺序特别容易搞混。建议先跑一个小例子(seq_len=4, d_model=16, n_heads=4)打印中间张量的 shape,确认无误再上大模型。

2.4 多头到底学到了什么?

不同头会关注不同类型的信息:

头编号 可能关注的模式 例子
Head 1 相邻词的关系 "The cat sat"
Head 2 语法依赖 "The cat that ate the fish"
Head 3 指代关系 "The cat because it was tired"
Head 4 位置关系 句首词关注句尾词

这就是多头的魅力------不指定关注什么,让模型自己学


三、Positional Encoding:给词加上"位置坐标"

3.1 一个关键问题

前面说到,Transformer 是并行处理所有词的。这带来了一个严重问题:

它不知道词的顺序!

对,你没听错。把 "狗咬人" 和 "人咬狗" 输入 Transformer,如果不加位置信息,它觉得这两句话是一样的。

RNN/LSTM 天然有序(一个一个处理),所以不存在这个问题。但 Transformer 的并行计算是以丢失位置信息为代价的。

解决方案就是:手动给每个词加一个"位置编码"

3.2 位置编码怎么做?

原论文用的是 正弦-余弦位置编码(Sinusoidal Positional Encoding):

python 复制代码
# 位置编码公式
# PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model=512, max_len=5000):
        super().__init__()

        # 创建位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)

        # 计算分母项: 10000^(2i/d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # 偶数位置用 sin,奇数位置用 cos
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维度
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维度

        # (1, max_len, d_model) 方便广播
        pe = pe.unsqueeze(0)

        # 注册为 buffer(不参与梯度更新,但会保存到模型中)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        # 把位置编码加到输入上
        x = x + self.pe[:, :x.size(1), :]
        return x

3.3 为什么用 sin/cos?

你可能会问:为什么不用简单的 0, 1, 2, 3... 表示位置?

原因有三:

① 可以处理比训练序列更长的文本

sin/cos 是周期函数,天然支持外推。而学习到的位置编码(如 BERT 用的)对超长文本就不太友好。

② 每个位置有唯一的编码

不同位置的正弦波组合是唯一的,就像指纹一样。

③ 模型可以学到相对位置

这个是 sin/cos 编码最巧妙的地方。通过三角函数公式:

scss 复制代码
sin(α + β) = sin(α)cos(β) + cos(α)sin(β)

模型可以通过线性变换,从绝对位置编码中推导出相对位置关系

实际工作中 :BERT 用的是可学习位置编码(Learned Position Embedding),就是直接把位置编码当参数训练。GPT 也是。只有原始 Transformer 和一些长文本模型(如 Transformer-XL)还在用 sin/cos。两种方式各有优劣,sin/cos 泛化性更好,可学习的编码在训练长度内表现更优。

3.4 位置编码可视化

ini 复制代码
位置编码热力图(每个位置 × 每个维度):

维度 →  0   1   2   3   4   5   6   7
位置 ↓
  0   [ 0.0  1.0  0.0  1.0  0.0  1.0  0.0  1.0 ]
  1   [ 0.84 0.54 0.01 1.0  0.0  1.0  0.0  1.0 ]
  2   [ 0.91 -0.42 0.02 1.0  0.0  1.0  0.0  1.0 ]
  3   [ 0.14 -0.99 0.03 1.0  0.0  1.0  0.0  1.0 ]
  ...
 50   [...变化越来越慢的低频部分...]

可以看到:低维度变化快(高频),高维度变化慢(低频)。这种多尺度的设计让模型能同时捕捉局部和全局的位置关系。


四、三大组件如何协同工作?

把上面三个部分串起来,Transformer 的编码器(Encoder)一层的完整流程是:

scss 复制代码
输入词向量 (Word Embedding)
        │
        ▼
  ┌─────────────────────┐
  │ Positional Encoding │  ← 加上位置信息
  └────────┬────────────┘
           ▼
  ┌──────────────────────┐
  │ Multi-Head Attention │  ← 自注意力:词与词之间交互
  └────────┬─────────────┘
           ▼
  ┌───────────────────────────┐
  │ Add & Norm (残差+层归一化) │
  └────────┬──────────────────┘
           ▼
  ┌──────────────────────┐
  │ Feed Forward Network │  ← 全连接层:逐位置变换
  └────────┬─────────────┘
           ▼
  ┌───────────────────────────┐
  │ Add & Norm (残差+层归一化) │
  └────────┬──────────────────┘
           ▼
       输出到下一层

用代码串起来:

python 复制代码
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # 1. Multi-Head Self-Attention + 残差连接 + LayerNorm
        attn_out = self.self_attn(x)
        x = self.norm1(x + self.dropout(attn_out))

        # 2. Feed Forward + 残差连接 + LayerNorm
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))

        return x

残差连接(Residual Connection) 的作用:把输入直接加到输出上,梯度可以"跳过"中间层直接回传,有效缓解深层网络的梯度消失问题。这个设计来自 ResNet,被 Transformer 借鉴过来了。


五、速查对照表

组件 解决什么问题 核心思想 一句话记忆
Self-Attention 词与词之间怎么交互 Q·K^T 算相似度,加权取 V "你跟谁最相关?"
Multi-Head Attention 单一视角不够 多组 QKV 并行,多角度观察 "多看几眼,看得更全"
Positional Encoding 并行计算丢失位置 sin/cos 注入位置信息 "告诉模型你在哪"
相关推荐
云烟成雨TD2 小时前
Spring AI Alibaba 1.x 系列【60】检查点机制原理与全流程剖析
java·人工智能·spring
极光代码工作室2 小时前
基于机器学习的二手商品价格预测系统
人工智能·python·深度学习·机器学习
YueJoy.AI2 小时前
AI应用的隐私保护:从设计开始的隐私
人工智能·ai·语言模型
小当家.1052 小时前
PostgreSQL 做向量数据库:pgvector 在 RAG 中的实战与多场景适配
数据库·人工智能·postgresql·rag
ForgeAI码匠2 小时前
Maven 多模块项目如何避免越写越乱?Forge Admin 的模块边界实践
java·人工智能·开源·maven
Dola_Zou2 小时前
工业软件防破解避坑指南:CodeMeter 全流程入门与选型(上)
人工智能·自动化·视觉检测·软件工程·软件加密
生成论实验室2 小时前
我们给AI装上了判断力
人工智能·深度学习·语言模型·机器人·自动驾驶
掘金安东尼2 小时前
国内通用智能体(本地操作型 Agent)深度测评对比
人工智能
完成大叔2 小时前
Agent感知模式的情景化联想应用
人工智能