从零实现Transformer:第 2 部分 - 缩放点积注意力(Scaled Dot-Product Attention)

从零实现Transformer:第 2 部分 - 缩放点积注意力(Scaled Dot-Product Attention)

flyfish

对于一些名词分不清的,我特写了一篇
Transformer 架构里关于 Attention 概念的澄清

欢迎继续阅读《从零实现Transformer》系列!在第1篇:输入嵌入与位置编码 中,完成了输入预处理:将词元ID转化为融合语义信息与位置信息的向量。

接下来看注意力机制

本篇将聚焦实现《Attention Is All You Need》论文中定义的基础模块:缩放点积注意力 。该机制能让模型在处理序列中某个元素时,自主衡量序列内其他内容的关联权重。本文将搭建缩放点积注意力的机制,完整实现掩码操作与缩放计算逻辑。

本篇学习目标

  1. 理解注意力三大向量:查询Q、键K、值V的底层逻辑( Query, Key, Value);
  2. 分步拆解缩放点积注意力计算公式;
  3. 基于 PyTorch 编写可复用的注意力函数;
  4. 利用测试数据验证代码有效性。

一、注意力:查询、键、值(Q/K/V)

可以把注意力机制类比为在图书馆检索资料:
查询向量 Q :代表「当前想查找的信息」。针对序列里的某个单词,查询向量会提出问题:序列中哪些单词和我相关?
键向量 K :代表「信息的标签(label)与标识」。序列中每个单词都有专属键向量,用来描述:我包含哪类信息?
值向量 V:代表「真实的内容信息」。每个单词的值向量,就是最终会被提取、融合的有效内容。

注意力的运行逻辑

  1. 将当前单词的查询向量 ,与序列所有单词的键向量逐一比对(包含自身);
  2. 通过向量相似度计算关联分数,分数越高,代表两者关联性越强;
  3. 把原始分数转化为总和为1的概率权重;
  4. 用权重对所有值向量做加权求和;
  5. 最终输出的向量,就融合了序列中所有相关单词的上下文信息。

二、缩放点积注意力

论文图中1,2,3 用的都是 缩放点积注意力

Attention部分放大了看选一个 看 蓝色的3

原版 Transformer 采用的注意力形式即为缩放点积注意力,计算公式源自论文《Attention Is All You Need》3.2.1章节:

Attention ( Q , K , V ) = softmax ( Q ⋅ K ⊤ d k ) ⋅ V \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q \cdot K^\top}{\sqrt{d_k}} \right) \cdot V Attention(Q,K,V)=softmax(dk Q⋅K⊤)⋅V

公式与图的对应关系

分步拆解公式:

1. 矩阵相乘: Q ⋅ K ⊤ Q \cdot K^\top Q⋅K⊤

举个例子就是

  • 计算每个查询向量Query所有键向量Key的点积相似度;
  • 维度规则:若 Q Q Q 形状为 [批次, 头数, 查询序列长度, 键维度], K K K 形状为 [批次, 头数, 键序列长度, 键维度],则键转置后完成矩阵运算;
  • 输出结果:注意力原始分数矩阵,矩阵中每个数值代表两个位置的关联度。

2. 缩放操作: ÷ d k \boldsymbol{\div \sqrt{d_k}} ÷dk

将分数除以键向量维度的平方根 进行缩放。

原因:当键向量维度 d k d_k dk 较大时,点积结果数值会急剧变大;

弊端:过大的数值输入 Softmax 后,梯度会趋近于0,导致模型训练困难;

作用:缩放可以限制点积数值范围,保证梯度稳定。

3. 可选:掩码 Mask

在执行 Softmax 归一化之前 ,通常需要添加掩码:

将需要屏蔽的位置,填充为极小负数(如 − ∞ -\infty −∞),经过 Softmax 后,这些位置的权重会趋近于0。

两种常

用掩码

  1. 填充掩码:屏蔽批次内补齐序列用的无效填充符,避免模型关注无意义占位字符;
  2. 前瞻掩码(解码器专用) :禁止当前位置读取序列后续内容,保证解码器自回归生成的特性(只能用前文预测下文)。

本文代码中,统一使用 -1e9 作为掩码填充值。

前瞻掩码(Look-Ahead Mask),也叫因果掩码(Causal Mask),是 Transformer 解码器(Decoder)中用于保证自回归生成的机制。

cpp 复制代码
位置:    1    2    3    4    5
     ┌─────────────────────┐
  1  │  0   1   1   1   1  │  ← 位置1只能看自己
  2  │  0   0   1   1   1  │  ← 位置2能看1,2
  3  │  0   0   0   1   1  │  ← 位置3能看1,2,3
  4  │  0   0   0   0   1  │
  5  │  0   0   0   0   0  │  ← 位置5能看全部
     └─────────────────────┘
     0=允许关注  1=屏蔽(未来位置)

未来位置的 attention score 被设为 -1e9

softmax 后这些位置的概率趋近于 0

加权求和时,未来位置的 value 几乎不参与计算

训练时:防止信息泄露

让模型学会"根据前n个词预测第n+1个词"

没有掩码:位置3可以同时看到位置4、5的词 → 作弊!

有掩码:位置3只能看到位置1、2、3 → 正确学习因果依赖

4. Softmax 归一化

在键向量维度上做 Softmax 运算,把原始分数转为概率分布;

最终得到注意力权重矩阵,所有权重数值总和为1,直观体现每个位置的关联占比。

5. 加权求和:权重 × 值矩阵

用归一化后的注意力权重,与值矩阵 V V V 相乘;

最终输出:融合全局上下文的特征向量,每个位置都结合了全序列的关联信息。

三、代码实现(PyTorch)

python 复制代码
import torch
import torch.nn as nn
import math

def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor = None,
    dropout: nn.Dropout = None
):
    """
    实现缩放点积注意力机制
    参数:
        query: 查询张量,形状 [批次大小, 多头数, 查询序列长度, 键维度]
        key: 键张量,形状 [批次大小, 多头数, 键序列长度, 键维度]
        value: 值张量,形状 [批次大小, 多头数, 值序列长度, 值维度]
        mask: 掩码张量,支持广播适配分数矩阵,True 代表需要屏蔽的位置
        dropout: 可选dropout层,用于注意力权重正则化
    返回:
        output: 注意力输出特征
        attention_weights: 注意力权重矩阵
    """
    # 获取Q/K的向量维度
    d_k = query.shape[-1]

    # 1. 计算Q与K转置的点积,得到原始注意力分数
    attention_scores = torch.matmul(query, key.transpose(-2, -1))

    # 2. 维度缩放
    attention_scores = attention_scores / math.sqrt(d_k)

    # 3. 应用掩码
    if mask is not None:
        attention_scores = attention_scores.masked_fill(mask == True, -1e9)

    # 4. Softmax归一化,得到注意力权重
    attention_weights = torch.softmax(attention_scores, dim=-1)

    # 5. 可选dropout
    if dropout is not None:
        attention_weights = dropout(attention_weights)

    # 6. 权重与值向量加权求和,得到最终输出
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

代码说明

  1. key.transpose(-2, -1):仅转置张量最后两个维度,适配批量矩阵运算;
  2. masked_fill:精准屏蔽无效位置,保证掩码逻辑生效;
  3. 函数同时返回输出特征注意力权重,方便后续可视化分析与调试。

四、代码测试验证

我们构造随机测试张量、搭建解码器前瞻掩码,完整验证函数功能:

python 复制代码
import torch

# 超参数定义
batch_size_test = 2    # 批次大小
num_heads_test = 8     # 注意力头数(下篇多头注意力会用到)
seq_len_q_test = 5     # 查询序列长度
seq_len_k_test = 7     # 键/值序列长度
d_k_test = 64 // num_heads_test
d_v_test = 64 // num_heads_test

# 构造随机测试张量
dummy_q = torch.randn(batch_size_test, num_heads_test, seq_len_q_test, d_k_test)
dummy_k = torch.randn(batch_size_test, num_heads_test, seq_len_k_test, d_k_test)
dummy_v = torch.randn(batch_size_test, num_heads_test, seq_len_k_test, d_v_test)

# 1. 无掩码测试
print("===== 无掩码测试 =====")
output_no_mask, weights_no_mask = scaled_dot_product_attention(dummy_q, dummy_k, dummy_v)
print(f"输出特征形状:{output_no_mask.shape}")
print(f"注意力权重形状:{weights_no_mask.shape}")
# 校验:权重每行总和近似为1
assert torch.allclose(weights_no_mask.sum(dim=-1), torch.ones_like(weights_no_mask.sum(dim=-1)))
print("无掩码测试通过\n")

# 2. 解码器前瞻掩码测试
mask_size = seq_len_q_test
# 生成上三角掩码(屏蔽未来位置)
look_ahead_mask = torch.triu(torch.ones(mask_size, mask_size), diagonal=1).bool()
# 拓展维度,支持广播运算
dummy_mask = look_ahead_mask.unsqueeze(0).unsqueeze(0)

# 适配自注意力:统一Q/K/V序列长度
dummy_k_masked = torch.randn(batch_size_test, num_heads_test, seq_len_q_test, d_k_test)
dummy_v_masked = torch.randn(batch_size_test, num_heads_test, seq_len_q_test, d_v_test)

print("===== 前瞻掩码测试 =====")
output_masked, weights_masked = scaled_dot_product_attention(dummy_q, dummy_k_masked, dummy_v_masked, mask=dummy_mask)
print(f"掩码形状:{dummy_mask.shape}")
print(f"掩码输出特征形状:{output_masked.shape}")
# 校验:掩码位置权重强制为0
assert torch.all(weights_masked.masked_select(dummy_mask) == 0)
print("掩码测试通过")

测试逻辑说明

  1. 无掩码场景:校验张量维度、注意力权重归一化效果;
  2. 掩码场景:利用上三角矩阵屏蔽未来时间步,模拟解码器自回归约束;
  3. 通过断言校验逻辑,确保掩码、缩放、Softmax 全部生效。

完整源码

cpp 复制代码
import torch
import torch.nn as nn
import math

def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    d_k = query.shape[-1]
    attention_scores = torch.matmul(query, key.transpose(-2, -1))
    attention_scores = attention_scores / math.sqrt(d_k)
    
    if mask is not None:
        attention_scores = attention_scores.masked_fill(mask == True, -1e9)
    
    attention_weights = torch.softmax(attention_scores, dim=-1)
    
    if dropout is not None:
        attention_weights = dropout(attention_weights)
    
    output = torch.matmul(attention_weights, value)
    return output, attention_weights

if __name__ == "__main__":
    batch_size_test = 2
    num_heads_test = 8
    seq_len_q_test = 5
    seq_len_k_test = 7
    d_k_test = 64 // num_heads_test
    d_v_test = 64 // num_heads_test

    dummy_q = torch.randn(batch_size_test, num_heads_test, seq_len_q_test, d_k_test)
    dummy_k = torch.randn(batch_size_test, num_heads_test, seq_len_k_test, d_k_test)
    dummy_v = torch.randn(batch_size_test, num_heads_test, seq_len_k_test, d_v_test)

    print("--- 无掩码测试 ---")
    output_no_mask, weights_no_mask = scaled_dot_product_attention(dummy_q, dummy_k, dummy_v)
    print(f"输出形状: {output_no_mask.shape}")
    print(f"权重形状: {weights_no_mask.shape}")
    assert torch.allclose(weights_no_mask.sum(dim=-1), torch.ones_like(weights_no_mask.sum(dim=-1)))
    print("无掩码测试通过")

    print("\n--- 前瞻掩码测试 ---")
    mask_size = seq_len_q_test
    look_ahead_mask = torch.triu(torch.ones(mask_size, mask_size), diagonal=1).bool()
    dummy_mask = look_ahead_mask.unsqueeze(0).unsqueeze(0)

    dummy_k_masked = torch.randn(batch_size_test, num_heads_test, seq_len_q_test, d_k_test)
    dummy_v_masked = torch.randn(batch_size_test, num_heads_test, seq_len_q_test, d_v_test)

    output_masked, weights_masked = scaled_dot_product_attention(dummy_q, dummy_k_masked, dummy_v_masked, mask=dummy_mask)
    print(f"掩码输出形状: {output_masked.shape}")
    assert torch.all(weights_masked.masked_select(dummy_mask) == 0)
    print("掩码测试通过")

输出

cpp 复制代码
--- 无掩码测试 ---
输出形状: torch.Size([2, 8, 5, 8])
权重形状: torch.Size([2, 8, 5, 7])
无掩码测试通过

--- 前瞻掩码测试 ---
掩码输出形状: torch.Size([2, 8, 5, 8])
掩码测试通过

至此,完成了 Transformer 最的缩放点积注意力 底层实现。

原论文证明:并行运行多组独立的注意力计算 效果会大幅提升------也就是多头注意力 。下一篇就是 实现多头注意力机制(Multi-Head Attention)

Transformer - 注意⼒机制 Scaled Dot-Product Attention 计算过程
Transformer - 注意⼒机制 代码实现
Transformer - 注意⼒机制 Scaled Dot-Product Attention不同的代码比较
Transformer - 注意⼒机制 代码解释
Transformer - 注意⼒机制 Attention 中的 Q, K, V 解释(1)
Transformer - 注意⼒机制 Attention 中的 Q, K, V 解释(2)
Transformer - 《Attention is All You Need》中的Scaled Dot-Product Attention,为什么要Scaled

注意力机制的输入

注意力机制的输入和输出

注意力机制的核心本质是什么?

扩展阅读

FlashInfer - SparseAttention(稀疏注意力)只计算部分有意义的注意力连接,而非全部 token 对

DeepSpeed-Ulysses 密集自注意力(Dense Self-Attention)和稀疏自注意力(Sparse Self-Attention)

相关推荐
小超同学你好6 小时前
OpenClaw 深度解析与源代码导读 · 第10篇:多 Agent 核心(agents.list、bindings 与隔离边界的可验证机制)
人工智能·深度学习·语言模型·transformer
机器学习之心7 小时前
IGWO-Transformer模型回归+SHAP分析+新数据预测+多输出!深度学习可解释分析(附MATLAB代码)
深度学习·回归·transformer·shap分析·igwo
code_pgf7 小时前
OpenPI / π₀ 系列算法详解、创新点及 Jetson Orin NX 16GB 边缘端部署
人工智能·transformer·agi·palm
qq_283720057 小时前
基于 Transformer,Python 搭建中文文本分类大模型:从零到一实现企业级文本分类
python·分类·transformer
AI技术增长8 小时前
Pytorch图像去噪实战(十一):Diffusion扩散模型去噪入门,从噪声预测理解生成式图像恢复
pytorch·深度学习·机器学习·cnn·transformer
AI木马人1 天前
2.人工智能实战:大模型接口并发低、GPU利用率上不去?基于 vLLM 重构推理服务的完整工程方案
人工智能·transformer·vllm
小超同学你好1 天前
Transformer 30. MoCo:用「动量编码器 + 队列字典」把对比学习做成可扩展的“字典查找”
深度学习·学习·transformer
yigan_Eins1 天前
Transformer|残差连接的技术演进:从CNN到ResNet
人工智能·深度学习·cnn·transformer
大江东去浪淘尽千古风流人物1 天前
【RT-1】面向真实世界规模化控制的机器人Transformer
深度学习·机器人·transformer