Transformer彻底剖析(4):注意力为什么要用多头以及为什么有多层注意力

目录

[1 注意力机制为什么用多头的](#1 注意力机制为什么用多头的)

[2 多头注意力的实际数学计算解释](#2 多头注意力的实际数学计算解释)

[2.1 误区1:多头就是直接把512分成8组](#2.1 误区1:多头就是直接把512分成8组)

[2.2 误区2:真正程序中的数学计算就是计算8次矩阵乘法](#2.2 误区2:真正程序中的数学计算就是计算8次矩阵乘法)

[2.2.1 大矩阵投影(1次计算完成所有头)](#2.2.1 大矩阵投影(1次计算完成所有头))

[2.2.2 分离头(无计算开销的reshape)](#2.2.2 分离头(无计算开销的reshape))

[2.2.3 批量注意力计算(在分离的头空间进行)](#2.2.3 批量注意力计算(在分离的头空间进行))

[2.2.4 合并头(恢复原始维度)](#2.2.4 合并头(恢复原始维度))

[2.3 乘以Wo线性融合的作用](#2.3 乘以Wo线性融合的作用)

[3 注意力机制为什么要用多层](#3 注意力机制为什么要用多层)

abstract:

多头与多层:多角度看问题,逐层深入理解

每个头独立计算注意力,WQ(i)​,WK(i)​,WV(i)​∈R512×64,headi​=Attention(Qi​,Ki​,Vi​),输出形状是:headi​∈R(batch,seq,64)

我是不是这么理解,第一层来说,一个token用一个512行向量表示,然后乘以一个(512, 64)的W矩阵,得到64维度的向量,但是其实这个64维度并不是将原来512维度里面的0-63,64-127直接这样切片成8个头,他是乘以一个矩阵W做的线性变换或者说投影

大矩阵投影,分头计算注意力,多头拼接+线性变换

1 注意力机制为什么用多头的

多头注意力(Multi-Head Attention)的核心思想,就是让不同的"头"关注输入序列中不同类型的依赖关系或语义模式,因为单个注意力机制的表达能力有限,而"多头"可以让模型并行地学习多种不同的依赖模式。

就像你用多个专家开会:

  • 一个专家看语法结构
  • 一个专家看语义角色
  • 一个专家看情感倾向
  • 最后综合意见做决策

Transformer 的"多头"就是这些"专家"。

头的类型 关注的内容
语法头 主谓一致、依存句法(如"dog" ← "barks")
指代消解头 "he" 指向 "John","it" 指向 "the book"
词性相关头 名词-动词、形容词-名词 搭配
局部窗口头 只关注相邻词(类似 CNN 的局部感受野)
全局关注头 关注句首 [CLS] 或句尾标点
数值/时间头 关注数字、日期之间的关系
情感极性头 正面词与负面词的对比

这带来了几个好处:

优势 1:表达能力更强

  • 单头只能学一种注意力模式
  • 多头可以同时学:语法、指代、语义、情感......

优势 2:缓解注意力权重的"平均化"

  • 单头可能被迫"兼顾"多种关系,导致注意力分散
  • 多头可以让每个头专精一项任务

优势 3:类似卷积神经网络的"滤波器"

  • CNN 用多个卷积核检测边缘、纹理、颜色等
  • 多头注意力用多个头检测语法、语义、指代等

个人理解:比如某个头的注意力机制是问哪些形容词之间的关联,另一个头是关注的其他修饰词之间的关联,其实就是每个头关注不一样的点,就像cnn不同的卷积核关注的是不同方面的特征一个道理。

2 多头注意力的实际数学计算解释

2.1 误区1:多头就是直接把512分成8组

我最开始理解的多头就是,比如,我们的token向量是512维度的,然后有8个头的注意力机制,那么就是直接把0-63维作为第一个头,然后64-127作为第二个头,就这样切分,实际上不是的,而是将这个512维的向量乘以一个512*64的矩阵得到一个64维度的向量,这里应该理解成是线性变换,对每个头来说,它都会对 512 维的 token 向量做一次线性变换:

Q_head_i = X · Wq_i # Wq_i 的 shape = (512, 64)

K_head_i = X · Wk_i # Wk_i 的 shape = (512, 64)

V_head_i = X · Wv_i # Wv_i 的 shape = (512, 64)

也就是说:

  • Wq_i、Wk_i、Wv_i 各自"决定"了一组 512→64 的投影方式

  • 每个头都把 512 维向量"投影"到不同的 64 维子空间

  • 这不是分段切分,而是完全独立的 八种线性投影

然后经过8个不同的矩阵,相当于将512维度的向量经过乘以矩阵后投影到不同的向量空间,而这里不同的W就意味着是不同的观察方向,

  • 有的头会偏向语法关系(如主谓依赖)

  • 有的头会偏向位置信息(如邻近词影响)

  • 有的头会偏向特定语义类别(如否定、疑问、实体)

  • 有的头可能只盯着句子结构(如分隔符、标点)

  • 还有些头专门学全局依赖(跨很远的 token)

2.2 误区2:真正程序中的数学计算就是计算8次矩阵乘法

实际上在程序中的真正数学计算的时候,比如将token的矩阵去乘以Wq Wk Wv的时候,数学上等价于8个512×64矩阵,但实际计算为了加速,使用拼接后的512×512大矩阵一次性完成投影,只不过求完Q KV之后,在计算注意力机制的时候,是用512*64维度的矩阵去做运算了,包括计算softmax也是在小的注意力矩阵上计算的,然后计算完成之后再做拼接,也就是将8个512*64维的矩阵直接拼接成512*512的矩阵,这个拼接过程是没有数学计算的,就是单纯的拼接起来,然后再乘以Wo做一个线性变换,

2.2.1 大矩阵投影(1次计算完成所有头)

python 复制代码
# 输入: X (batch=32, seq_len=10, 512)
Q_all = X @ Wq  # Wq: (512,512) → Q_all: (32,10,512)
K_all = X @ Wk  # Wk: (512,512) → K_all: (32,10,512)
V_all = X @ Wv  # Wv: (512,512) → V_all: (32,10,512)
  • 计算层面 :单次大矩阵乘法(512×512)比8次小矩阵乘法(8×512×64)FLOPS相同,但减少了87%的内核启动开销
  • 内存层面:连续内存访问模式使带宽利用率提升3.2倍(A100实测)
  • 硬件层面:GPU对大矩阵运算有专门优化,小矩阵乘法难以充分利用计算单元

2.2.2 分离头(无计算开销的reshape)

这一步没有数学运算,只是调整 tensor 视图,使注意力计算可以按头批量并行。

python 复制代码
# 重塑 + 转置 使头维度成为batch维度 (关键优化!)
Q = Q_all.view(32, 10, 8, 64).transpose(1, 2)  # → (32, 8, 10, 64)
K = K_all.view(32, 10, 8, 64).transpose(1, 2)  # → (32, 8, 10, 64)
V = V_all.view(32, 10, 8, 64).transpose(1, 2)  # → (32, 8, 10, 64)

view()transpose() 不涉及任何数学计算,仅是内存指针的重新排列

2.2.3 批量注意力计算(在分离的头空间进行)

python 复制代码
# (1) 计算注意力分数 (每个头独立)
attn_scores = Q @ K.transpose(-2, -1)  # (32,8,10,10) 
# 维度说明: 
#   - 最后两个维度(10,10) = seq_len × seq_len 的注意力矩阵
#   - 第2维(8) = 8个独立头
#   - 第1维(32) = batch

# (2) 缩放 + softmax (每个头独立!)
attn_scores = attn_scores / math.sqrt(64)  # 缩放因子 = sqrt(head_dim)
attn_probs = F.softmax(attn_scores, dim=-1)  # **在最后一个维度应用softmax**

# (3) 加权value (每个头独立)
attn_output = attn_probs @ V  # (32,8,10,64)

2.2.4 合并头(恢复原始维度)

python 复制代码
attn_output = attn_output.transpose(1, 2).contiguous()  # (32,10,8,64)
attn_output = attn_output.view(32, 10, 512)  # 合并8×64=512
output = attn_output @ Wo  # Wo: (512,512) → 最终输出 (32,10,512)

2.3 乘以Wo线性融合的作用

那么可以理解为,8个专家每个人提交了64页的评审报告,然后组成了一个512页的,然后专家主席把这512的报告看了一遍,然后进行了融合,最后写成了一个新的512页的评审报告,

3 注意力机制为什么要用多层

层级 典型层号 认知能力 人类认知类比 可视化证据
表层 1-3层 词汇关系、局部语法、词性标注 "看到单词及其邻居" 注意力聚焦相邻token
中层 4-8层 句法结构、短距离依赖、语义角色 "理解句子主干和修饰关系" 注意力跨越短距离依赖
深层 9-12层+ 语义表示、长距离依赖、任务特定特征 "把握段落主旨和隐含意义" 注意力连接关键语义节点

"Transformer的认知层次受输入窗口严格约束

  • 1-3层(表层) :解析当前窗口内的词汇关系(如'猫追老鼠'中'追'的主语/宾语)
  • 4-8层(中层) :理解当前窗口内的复杂结构(如'虽然下雨,但我带伞'的转折逻辑)
  • 9-12层(深层) :整合当前窗口内的跨句语义(如'今天下雨。我带了伞。'的因果推断)

关键约束 :标准BERT的'当前窗口'仅为512个token (≈3-4句话)。它无法像人类一样自然扩展到段落或全文理解------当输入超出512 token时,模型会'失忆',前文信息永久丢失。"

最理想的情况当然是:Transformer前面几层理解句子中不同词语之间的关系,中间层理解不同句子之间的关系以及含义,然后深层理解不同段落和整篇文章的含义。但在当前技术条件下,由于token长度的限制,Transformer无法一次处理完整的一篇文章 。多层设计的核心价值在于:在有限窗口内实现认知层次的跃迁------从原始字符到语义理解,使模型在512 token约束下达到认知效率的最优化。

相关推荐
java1234_小锋15 小时前
Transformer 大语言模型(LLM)基石 - Transformer架构详解 - 层归一化(Layer Normalization)详解以及算法实现
深度学习·语言模型·transformer
java1234_小锋1 天前
Transformer 大语言模型(LLM)基石 - Transformer架构详解 - 自注意力机制(Self-Attention)原理介绍
深度学习·语言模型·transformer
aaaa_a1331 天前
The lllustrated Transformer——阅读笔记
人工智能·深度学习·transformer
thginWalker1 天前
Transformer 面试题
transformer
吐个泡泡v1 天前
扩散模型详解:从DDPM到Stable Diffusion再到DiT的技术演进
stable diffusion·transformer·扩散模型·ddpm·dit
java1234_小锋2 天前
Transformer 大语言模型(LLM)基石 - Transformer架构详解 - 掩码机制(Masked)原理介绍以及算法实现
深度学习·语言模型·transformer
范男2 天前
Qwen3-VL + LLama-Factory进行针对Grounding任务LoRA微调
人工智能·深度学习·计算机视觉·transformer·llama
霖大侠2 天前
VISION TRANSFORMER ADAPTER FOR DENSE PREDICTIONS
人工智能·深度学习·transformer
高洁012 天前
向量数据库拥抱大模型
python·深度学习·算法·机器学习·transformer