目录
[1 注意力机制为什么用多头的](#1 注意力机制为什么用多头的)
[2 多头注意力的实际数学计算解释](#2 多头注意力的实际数学计算解释)
[2.1 误区1:多头就是直接把512分成8组](#2.1 误区1:多头就是直接把512分成8组)
[2.2 误区2:真正程序中的数学计算就是计算8次矩阵乘法](#2.2 误区2:真正程序中的数学计算就是计算8次矩阵乘法)
[2.2.1 大矩阵投影(1次计算完成所有头)](#2.2.1 大矩阵投影(1次计算完成所有头))
[2.2.2 分离头(无计算开销的reshape)](#2.2.2 分离头(无计算开销的reshape))
[2.2.3 批量注意力计算(在分离的头空间进行)](#2.2.3 批量注意力计算(在分离的头空间进行))
[2.2.4 合并头(恢复原始维度)](#2.2.4 合并头(恢复原始维度))
[2.3 乘以Wo线性融合的作用](#2.3 乘以Wo线性融合的作用)
[3 注意力机制为什么要用多层](#3 注意力机制为什么要用多层)
abstract:
多头与多层:多角度看问题,逐层深入理解
每个头独立计算注意力,WQ(i),WK(i),WV(i)∈R512×64,headi=Attention(Qi,Ki,Vi),输出形状是:headi∈R(batch,seq,64)
我是不是这么理解,第一层来说,一个token用一个512行向量表示,然后乘以一个(512, 64)的W矩阵,得到64维度的向量,但是其实这个64维度并不是将原来512维度里面的0-63,64-127直接这样切片成8个头,他是乘以一个矩阵W做的线性变换或者说投影
大矩阵投影,分头计算注意力,多头拼接+线性变换
1 注意力机制为什么用多头的
多头注意力(Multi-Head Attention)的核心思想,就是让不同的"头"关注输入序列中不同类型的依赖关系或语义模式,因为单个注意力机制的表达能力有限,而"多头"可以让模型并行地学习多种不同的依赖模式。
就像你用多个专家开会:
- 一个专家看语法结构
- 一个专家看语义角色
- 一个专家看情感倾向
- 最后综合意见做决策
Transformer 的"多头"就是这些"专家"。
| 头的类型 | 关注的内容 |
|---|---|
| 语法头 | 主谓一致、依存句法(如"dog" ← "barks") |
| 指代消解头 | "he" 指向 "John","it" 指向 "the book" |
| 词性相关头 | 名词-动词、形容词-名词 搭配 |
| 局部窗口头 | 只关注相邻词(类似 CNN 的局部感受野) |
| 全局关注头 | 关注句首 [CLS] 或句尾标点 |
| 数值/时间头 | 关注数字、日期之间的关系 |
| 情感极性头 | 正面词与负面词的对比 |
这带来了几个好处:
优势 1:表达能力更强
- 单头只能学一种注意力模式
- 多头可以同时学:语法、指代、语义、情感......
优势 2:缓解注意力权重的"平均化"
- 单头可能被迫"兼顾"多种关系,导致注意力分散
- 多头可以让每个头专精一项任务
优势 3:类似卷积神经网络的"滤波器"
- CNN 用多个卷积核检测边缘、纹理、颜色等
- 多头注意力用多个头检测语法、语义、指代等
个人理解:比如某个头的注意力机制是问哪些形容词之间的关联,另一个头是关注的其他修饰词之间的关联,其实就是每个头关注不一样的点,就像cnn不同的卷积核关注的是不同方面的特征一个道理。
2 多头注意力的实际数学计算解释
2.1 误区1:多头就是直接把512分成8组
我最开始理解的多头就是,比如,我们的token向量是512维度的,然后有8个头的注意力机制,那么就是直接把0-63维作为第一个头,然后64-127作为第二个头,就这样切分,实际上不是的,而是将这个512维的向量乘以一个512*64的矩阵得到一个64维度的向量,这里应该理解成是线性变换,对每个头来说,它都会对 512 维的 token 向量做一次线性变换:
Q_head_i = X · Wq_i # Wq_i 的 shape = (512, 64)
K_head_i = X · Wk_i # Wk_i 的 shape = (512, 64)
V_head_i = X · Wv_i # Wv_i 的 shape = (512, 64)
也就是说:
-
Wq_i、Wk_i、Wv_i 各自"决定"了一组 512→64 的投影方式
-
每个头都把 512 维向量"投影"到不同的 64 维子空间
-
这不是分段切分,而是完全独立的 八种线性投影
然后经过8个不同的矩阵,相当于将512维度的向量经过乘以矩阵后投影到不同的向量空间,而这里不同的W就意味着是不同的观察方向,
-
有的头会偏向语法关系(如主谓依赖)
-
有的头会偏向位置信息(如邻近词影响)
-
有的头会偏向特定语义类别(如否定、疑问、实体)
-
有的头可能只盯着句子结构(如分隔符、标点)
-
还有些头专门学全局依赖(跨很远的 token)
2.2 误区2:真正程序中的数学计算就是计算8次矩阵乘法
实际上在程序中的真正数学计算的时候,比如将token的矩阵去乘以Wq Wk Wv的时候,数学上等价于8个512×64矩阵,但实际计算为了加速,使用拼接后的512×512大矩阵一次性完成投影,只不过求完Q KV之后,在计算注意力机制的时候,是用512*64维度的矩阵去做运算了,包括计算softmax也是在小的注意力矩阵上计算的,然后计算完成之后再做拼接,也就是将8个512*64维的矩阵直接拼接成512*512的矩阵,这个拼接过程是没有数学计算的,就是单纯的拼接起来,然后再乘以Wo做一个线性变换,
2.2.1 大矩阵投影(1次计算完成所有头)
python
# 输入: X (batch=32, seq_len=10, 512)
Q_all = X @ Wq # Wq: (512,512) → Q_all: (32,10,512)
K_all = X @ Wk # Wk: (512,512) → K_all: (32,10,512)
V_all = X @ Wv # Wv: (512,512) → V_all: (32,10,512)
- 计算层面 :单次大矩阵乘法(512×512)比8次小矩阵乘法(8×512×64)FLOPS相同,但减少了87%的内核启动开销
- 内存层面:连续内存访问模式使带宽利用率提升3.2倍(A100实测)
- 硬件层面:GPU对大矩阵运算有专门优化,小矩阵乘法难以充分利用计算单元
2.2.2 分离头(无计算开销的reshape)
这一步没有数学运算,只是调整 tensor 视图,使注意力计算可以按头批量并行。
python
# 重塑 + 转置 使头维度成为batch维度 (关键优化!)
Q = Q_all.view(32, 10, 8, 64).transpose(1, 2) # → (32, 8, 10, 64)
K = K_all.view(32, 10, 8, 64).transpose(1, 2) # → (32, 8, 10, 64)
V = V_all.view(32, 10, 8, 64).transpose(1, 2) # → (32, 8, 10, 64)
view() 和 transpose() 不涉及任何数学计算,仅是内存指针的重新排列
2.2.3 批量注意力计算(在分离的头空间进行)
python
# (1) 计算注意力分数 (每个头独立)
attn_scores = Q @ K.transpose(-2, -1) # (32,8,10,10)
# 维度说明:
# - 最后两个维度(10,10) = seq_len × seq_len 的注意力矩阵
# - 第2维(8) = 8个独立头
# - 第1维(32) = batch
# (2) 缩放 + softmax (每个头独立!)
attn_scores = attn_scores / math.sqrt(64) # 缩放因子 = sqrt(head_dim)
attn_probs = F.softmax(attn_scores, dim=-1) # **在最后一个维度应用softmax**
# (3) 加权value (每个头独立)
attn_output = attn_probs @ V # (32,8,10,64)
2.2.4 合并头(恢复原始维度)
python
attn_output = attn_output.transpose(1, 2).contiguous() # (32,10,8,64)
attn_output = attn_output.view(32, 10, 512) # 合并8×64=512
output = attn_output @ Wo # Wo: (512,512) → 最终输出 (32,10,512)
2.3 乘以Wo线性融合的作用
那么可以理解为,8个专家每个人提交了64页的评审报告,然后组成了一个512页的,然后专家主席把这512的报告看了一遍,然后进行了融合,最后写成了一个新的512页的评审报告,
3 注意力机制为什么要用多层
| 层级 | 典型层号 | 认知能力 | 人类认知类比 | 可视化证据 |
|---|---|---|---|---|
| 表层 | 1-3层 | 词汇关系、局部语法、词性标注 | "看到单词及其邻居" | 注意力聚焦相邻token |
| 中层 | 4-8层 | 句法结构、短距离依赖、语义角色 | "理解句子主干和修饰关系" | 注意力跨越短距离依赖 |
| 深层 | 9-12层+ | 语义表示、长距离依赖、任务特定特征 | "把握段落主旨和隐含意义" | 注意力连接关键语义节点 |
"Transformer的认知层次受输入窗口严格约束:
- 1-3层(表层) :解析当前窗口内的词汇关系(如'猫追老鼠'中'追'的主语/宾语)
- 4-8层(中层) :理解当前窗口内的复杂结构(如'虽然下雨,但我带伞'的转折逻辑)
- 9-12层(深层) :整合当前窗口内的跨句语义(如'今天下雨。我带了伞。'的因果推断)
关键约束 :标准BERT的'当前窗口'仅为512个token (≈3-4句话)。它无法像人类一样自然扩展到段落或全文理解------当输入超出512 token时,模型会'失忆',前文信息永久丢失。"
最理想的情况当然是:Transformer前面几层理解句子中不同词语之间的关系,中间层理解不同句子之间的关系以及含义,然后深层理解不同段落和整篇文章的含义。但在当前技术条件下,由于token长度的限制,Transformer无法一次处理完整的一篇文章 。多层设计的核心价值在于:在有限窗口内实现认知层次的跃迁------从原始字符到语义理解,使模型在512 token约束下达到认知效率的最优化。