目录
- 1.原理和公式
-
- [1.1 核心三要素:Q、K、V](#1.1 核心三要素:Q、K、V)
- [1.2 数学公式](#1.2 数学公式)
- [1.3 流程图](#1.3 流程图)
- 2.代码实现
-
- [2.1 源码实现](#2.1 源码实现)
- [2.2 用 PyTorch 内置模块实现](#2.2 用 PyTorch 内置模块实现)
- [2.3 示例](#2.3 示例)
- [3.关于单头 vs 多头区别](#3.关于单头 vs 多头区别)
1.原理和公式
单头注意力计算目的:计算一句话里的每个词,让每个词去"问"所有其他词"你跟我有多相关?",然后根据相关度,把其他词的信息加权汇总到自己身上。
1.1 核心三要素:Q、K、V
每个输入 token 会被投影成三个向量:
| 名称 | 全称 | 含义 | 类比 |
|---|---|---|---|
| Q | Query(查询) | "我在找什么?" | 你走进图书馆心里想找的那本书的类型 |
| K | Key(键) | "我是什么?" | 每本书封面上的标签 |
| V | Value(值) | "我包含什么内容?" | 书本身的正文内容 |
注意力做的事:用 Q 去和所有 K 做匹配,匹配度作为权重,去加权汇总所有 V。
"我 爱 吃 苹果"
↓ ↓ ↓ ↓
每个词都投影出自己的 Q、K、V
1.2 数学公式
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
-
四个步骤:
输入:4个词的句子,每个词用 512维向量表示
X = [4, 512] ← 4个token,每个512维第①步:线性投影得到 Q、K、V
Q = X · W_Q → [4, 64]
K = X · W_K → [4, 64]
V = X · W_V → [4, 64]
(64 是 d_k,注意力头的维度)第②步:计算注意力分数(Q 和 K 的点积)
scores = Q · K^T → [4, 4]直观含义: 我 爱 吃 苹果 我 [0.8 0.1 0.3 0.5] ← "我"跟各个词的相关度 爱 [0.2 0.9 0.4 0.1] ← "爱"跟各个词的相关度 吃 [0.1 0.3 0.8 0.7] ← "吃"跟各个词的相关度 苹果 [0.3 0.1 0.6 0.9] ← "苹果"跟各个词的相关度第③步:缩放 + Softmax(归一化成概率)
scores = scores / √64 ← 除 √d_k,防止点积太大
attn_weights = softmax(scores) → [4, 4],每行加起来=1我 爱 吃 苹果 我 [0.60 0.05 0.15 0.20] ← "我"60%关注自己,20%关注"苹果" 爱 [0.10 0.60 0.20 0.10] 吃 [0.05 0.10 0.50 0.35] 苹果 [0.10 0.05 0.25 0.60]第④步:用权重加权汇总 V
output = attn_weights · V → [4, 64]第 i 行的输出 = Σⱼ (第i行第j列的权重 × V的第j行) "苹果"的新表示 = 0.10×V_我 + 0.05×V_爱 + 0.25×V_吃 + 0.60×V_苹果
1.3 流程图
X [4×512]
┌──────┼──────┐
W_Q W_K W_V
↓ ↓ ↓
Q[4×64] K[4×64] V[4×64]
│ │ │
└──┬───┘ │
↓ │
Q × K^T [4×4] │
↓ │
÷ √64 (缩放) │
↓ │
softmax [4×4] │
↓ │
× V ──────────────┘
↓
output [4×64]
2.代码实现
2.1 源码实现
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SingleHeadAttention(nn.Module):
"""单头注意力 ------ 从零实现"""
def __init__(self, d_model=512, d_k=64):
"""
d_model: 输入向量的维度(比如 512)
d_k: Q, K, V 投影后的维度(头维度)
"""
super().__init__()
self.d_k = d_k
# 三个线性投影层
self.W_Q = nn.Linear(d_model, d_k, bias=False)
self.W_K = nn.Linear(d_model, d_k, bias=False)
self.W_V = nn.Linear(d_model, d_k, bias=False)
def forward(self, x):
"""
x: [batch_size, seq_len, d_model]
返回: [batch_size, seq_len, d_k]
"""
# ──── 第①步:线性投影 ────
Q = self.W_Q(x) # [B, L, d_k]
K = self.W_K(x) # [B, L, d_k]
V = self.W_V(x) # [B, L, d_k]
# ──── 第②步:计算注意力分数 ────
# Q @ K^T,即 Q 和 K 的转置相乘
# [B, L, d_k] × [B, d_k, L] → [B, L, L]
scores = torch.matmul(Q, K.transpose(-2, -1))
# ──── 第③步:缩放 + Softmax ────
scores = scores / math.sqrt(self.d_k) # 缩放
attn_weights = F.softmax(scores, dim=-1) # 归一化
# ──── 第④步:加权汇总 ────
# attn_weights @ V
# [B, L, L] × [B, L, d_k] → [B, L, d_k]
output = torch.matmul(attn_weights, V)
return output
2.2 用 PyTorch 内置模块实现
python
import torch.nn as nn
# PyTorch 内置的多头注意力,设置 num_heads=1 就是单头
attention = nn.MultiheadAttention(
embed_dim=512, # d_model
num_heads=1, # ← 关键:头数设为1
batch_first=True, # 让输入格式为 [B, L, D]
)
# 使用
x = torch.randn(2, 10, 512) # [batch=2, seq_len=10, d_model=512]
output, attn_weights = attention(x, x, x) # Q, K, V 都来自 x(自注意力)
print(output.shape) # [2, 10, 512]
print(attn_weights.shape) # [2, 10, 10] ← 注意力权重矩阵
2.3 示例
python
import torch
import torch.nn.functional as F
import math
# 造数据:1个句子,4个词,每个词8维
x = torch.randn(1, 4, 8)
# 手动定义权重(为了可复现)
torch.manual_seed(42)
W_Q = torch.randn(8, 4) # [d_model=8, d_k=4]
W_K = torch.randn(8, 4)
W_V = torch.randn(8, 4)
# 第①步
Q = x @ W_Q # [1, 4, 4]
K = x @ W_K # [1, 4, 4]
V = x @ W_V # [1, 4, 4]
# 第②步
scores = Q @ K.transpose(-2, -1) # [1, 4, 4]
# 第③步
scores = scores / math.sqrt(4) # 除 √4=2
weights = F.softmax(scores, dim=-1) # [1, 4, 4]
# 第④步
output = weights @ V
print("注意力权重矩阵(4×4):")
print(weights[0])
# 词0 词1 词2 词3
# 词0 [0.34 0.21 0.12 0.33 ] ← 词0 最关注自己和词3
# 词1 [0.19 0.38 0.27 0.16 ] ← 词1 最关注自己
# 词2 [0.08 0.30 0.45 0.17 ] ← 词2 最关注自己
# 词3 [0.25 0.14 0.18 0.43 ] ← 词3 最关注自己
print("\n每行之和(应该都=1):")
print(weights[0].sum(dim=-1)) # [1.0, 1.0, 1.0, 1.0] ✓
3.关于单头 vs 多头区别
| 单头 | 多头 | |
|---|---|---|
| 头数 | 1 | 8(常见) |
| Q/K/V 维度 | 1个 64维 | 8个 8维(拼起来还是64) |
| 关注模式 | 只能学到一种关系 | 同时学多种关系(语法、语义、位置...) |
| 类比 | 一个人看一句话 | 8个人看一句话,每人关注不同角度 |
多头就是把 Q、K、V 各切成 N 份,每份独立做一遍上面这套流程,最后拼回来再过一遍线性层。