单头注意力(Single-Head Attention)

目录

  • 1.原理和公式
    • [1.1 核心三要素:Q、K、V](#1.1 核心三要素:Q、K、V)
    • [1.2 数学公式](#1.2 数学公式)
    • [1.3 流程图](#1.3 流程图)
  • 2.代码实现
    • [2.1 源码实现](#2.1 源码实现)
    • [2.2 用 PyTorch 内置模块实现](#2.2 用 PyTorch 内置模块实现)
    • [2.3 示例](#2.3 示例)
  • [3.关于单头 vs 多头区别](#3.关于单头 vs 多头区别)

1.原理和公式

单头注意力计算目的:计算一句话里的每个词,让每个词去"问"所有其他词"你跟我有多相关?",然后根据相关度,把其他词的信息加权汇总到自己身上。

1.1 核心三要素:Q、K、V

每个输入 token 会被投影成三个向量:

名称 全称 含义 类比
Q Query(查询) "我在找什么?" 你走进图书馆心里想找的那本书的类型
K Key(键) "我是什么?" 每本书封面上的标签
V Value(值) "我包含什么内容?" 书本身的正文内容

注意力做的事:用 Q 去和所有 K 做匹配,匹配度作为权重,去加权汇总所有 V。

复制代码
"我  爱  吃  苹果"
 ↓   ↓   ↓   ↓
每个词都投影出自己的 Q、K、V

1.2 数学公式

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

  • 四个步骤:

    输入:4个词的句子,每个词用 512维向量表示
    X = [4, 512] ← 4个token,每个512维

    第①步:线性投影得到 Q、K、V
    Q = X · W_Q → [4, 64]
    K = X · W_K → [4, 64]
    V = X · W_V → [4, 64]
    (64 是 d_k,注意力头的维度)

    第②步:计算注意力分数(Q 和 K 的点积)
    scores = Q · K^T → [4, 4]

    复制代码
    直观含义:
         我   爱   吃  苹果
    我  [0.8  0.1  0.3  0.5]   ← "我"跟各个词的相关度
    爱  [0.2  0.9  0.4  0.1]   ← "爱"跟各个词的相关度
    吃  [0.1  0.3  0.8  0.7]   ← "吃"跟各个词的相关度
    苹果 [0.3  0.1  0.6  0.9]   ← "苹果"跟各个词的相关度

    第③步:缩放 + Softmax(归一化成概率)
    scores = scores / √64 ← 除 √d_k,防止点积太大
    attn_weights = softmax(scores) → [4, 4],每行加起来=1

    复制代码
         我    爱    吃   苹果
    我  [0.60  0.05  0.15  0.20]   ← "我"60%关注自己,20%关注"苹果"
    爱  [0.10  0.60  0.20  0.10]
    吃  [0.05  0.10  0.50  0.35]
    苹果 [0.10  0.05  0.25  0.60]

    第④步:用权重加权汇总 V
    output = attn_weights · V → [4, 64]

    复制代码
    第 i 行的输出 = Σⱼ (第i行第j列的权重 × V的第j行)
    
    "苹果"的新表示 = 0.10×V_我 + 0.05×V_爱 + 0.25×V_吃 + 0.60×V_苹果

1.3 流程图

复制代码
                    X [4×512]
                 ┌──────┼──────┐
                W_Q    W_K    W_V
                 ↓      ↓      ↓
               Q[4×64] K[4×64] V[4×64]
                 │      │       │
                 └──┬───┘       │
                    ↓           │
              Q × K^T [4×4]     │
                    ↓           │
              ÷ √64 (缩放)      │
                    ↓           │
              softmax [4×4]     │
                    ↓           │
              × V ──────────────┘
                    ↓
              output [4×64]

2.代码实现

2.1 源码实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SingleHeadAttention(nn.Module):
    """单头注意力 ------ 从零实现"""
    
    def __init__(self, d_model=512, d_k=64):
        """
        d_model: 输入向量的维度(比如 512)
        d_k:     Q, K, V 投影后的维度(头维度)
        """
        super().__init__()
        self.d_k = d_k
        
        # 三个线性投影层
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)
    
    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        返回: [batch_size, seq_len, d_k]
        """
        # ──── 第①步:线性投影 ────
        Q = self.W_Q(x)   # [B, L, d_k]
        K = self.W_K(x)   # [B, L, d_k]
        V = self.W_V(x)   # [B, L, d_k]
        
        # ──── 第②步:计算注意力分数 ────
        # Q @ K^T,即 Q 和 K 的转置相乘
        # [B, L, d_k] × [B, d_k, L] → [B, L, L]
        scores = torch.matmul(Q, K.transpose(-2, -1))
        
        # ──── 第③步:缩放 + Softmax ────
        scores = scores / math.sqrt(self.d_k)     # 缩放
        attn_weights = F.softmax(scores, dim=-1)   # 归一化
        
        # ──── 第④步:加权汇总 ────
        # attn_weights @ V
        # [B, L, L] × [B, L, d_k] → [B, L, d_k]
        output = torch.matmul(attn_weights, V)
        
        return output

2.2 用 PyTorch 内置模块实现

python 复制代码
import torch.nn as nn

# PyTorch 内置的多头注意力,设置 num_heads=1 就是单头
attention = nn.MultiheadAttention(
    embed_dim=512,    # d_model
    num_heads=1,      # ← 关键:头数设为1
    batch_first=True, # 让输入格式为 [B, L, D]
)

# 使用
x = torch.randn(2, 10, 512)  # [batch=2, seq_len=10, d_model=512]
output, attn_weights = attention(x, x, x)  # Q, K, V 都来自 x(自注意力)

print(output.shape)        # [2, 10, 512]
print(attn_weights.shape)  # [2, 10, 10] ← 注意力权重矩阵

2.3 示例

python 复制代码
import torch
import torch.nn.functional as F
import math

# 造数据:1个句子,4个词,每个词8维
x = torch.randn(1, 4, 8)

# 手动定义权重(为了可复现)
torch.manual_seed(42)
W_Q = torch.randn(8, 4)   # [d_model=8, d_k=4]
W_K = torch.randn(8, 4)
W_V = torch.randn(8, 4)

# 第①步
Q = x @ W_Q   # [1, 4, 4]
K = x @ W_K   # [1, 4, 4]
V = x @ W_V   # [1, 4, 4]

# 第②步
scores = Q @ K.transpose(-2, -1)   # [1, 4, 4]

# 第③步
scores = scores / math.sqrt(4)      # 除 √4=2
weights = F.softmax(scores, dim=-1) # [1, 4, 4]

# 第④步
output = weights @ V

print("注意力权重矩阵(4×4):")
print(weights[0])
#          词0     词1     词2     词3
# 词0  [0.34   0.21   0.12   0.33  ]   ← 词0 最关注自己和词3
# 词1  [0.19   0.38   0.27   0.16  ]   ← 词1 最关注自己
# 词2  [0.08   0.30   0.45   0.17  ]   ← 词2 最关注自己
# 词3  [0.25   0.14   0.18   0.43  ]   ← 词3 最关注自己

print("\n每行之和(应该都=1):")
print(weights[0].sum(dim=-1))   # [1.0, 1.0, 1.0, 1.0] ✓

3.关于单头 vs 多头区别

单头 多头
头数 1 8(常见)
Q/K/V 维度 1个 64维 8个 8维(拼起来还是64)
关注模式 只能学到一种关系 同时学多种关系(语法、语义、位置...)
类比 一个人看一句话 8个人看一句话,每人关注不同角度

多头就是把 Q、K、V 各切成 N 份,每份独立做一遍上面这套流程,最后拼回来再过一遍线性层。

相关推荐
hudawei9966 个月前
W_q,W_k,W_v矩阵是怎么确定的?
矩阵·transformer·梯度下降·多头注意力·单头注意力