Transformer前世今生——使用pytorch实现多头注意力(八)

随着AI市场,生成模型的投资热,小编在这里也开了一个Transformer的讲解系列,因为目前主流的大模型其核心都是Transformer,attention is all you need.本系列将介绍Transformer的算法原理以及使用pytorch的实现.

本节我们要学习如何使用pytorch实现多头注意力。

用 PyTorch 实现多头注意力(Multi-Head Attention)


🧠 一、概念简介:我们要实现什么?

在本教程中,我们将一步步用 PyTorch 实现:

  1. 一个通用的 Attention 类

    → 可执行三种注意力:

    • 自注意力(Self-Attention)
    • 掩蔽自注意力(Masked Self-Attention)
    • 编码器-解码器注意力(Encoder-Decoder Attention)
  2. 一个 Multi-Head Attention 类

    → 模拟 Transformer 中多个"头"同时计算注意力的机制。


🧩 二、导入 PyTorch 模块

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
torch: 提供张量和计算操作

torch.nn: 提供网络模块(nn.Module, nn.Linear)

torch.nn.functional: 提供函数接口(如 F.softmax())

🏗️ 三、实现 Attention 类

该类可执行三种注意力机制:

Self-Attention

Masked Self-Attention

Encoder-Decoder Attention

python 复制代码
class Attention(nn.Module): 
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        # 定义线性变换矩阵 W_Q, W_K, W_V
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        # 1️⃣ 生成 Q, K, V
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # 2️⃣ 计算相似度矩阵 Q × Kᵀ
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # 3️⃣ 缩放(Scale)
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)

        # 4️⃣ 掩蔽(Masked Attention)
        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        # 5️⃣ Softmax 归一化
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        # 6️⃣ 加权求和输出
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

🧩 核心思想:

使用 Query、Key、Value 三个矩阵计算注意力权重:

Attention(Q,K,V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V

]

QKᵀ 计算相关性,Softmax 转为权重,乘以 V 得出上下文表示。

🧮 四、测试 Encoder-Decoder Attention

4.1三组输入编码(Query, Key, Value)
python 复制代码
encodings_for_q = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])
encodings_for_k = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])
encodings_for_v = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

torch.manual_seed(42)
4.2创建 Attention 对象
python 复制代码
attention = Attention(d_model=2, row_dim=0, col_dim=1)
4.3计算 Encoder-Decoder Attention
python 复制代码
print(attention(encodings_for_q, encodings_for_k, encodings_for_v))
输出结果:

lua
Copy code
tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]])

成功计算 Encoder-Decoder Attention!

🧠 五、实现 Multi-Head Attention 类

多头注意力(Multi-Head Attention) 的原理:

同时创建多个独立的 Attention 头,每个头学习不同的注意力模式,最后将结果拼接(concatenate)在一起。

python 复制代码
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=2, row_dim=0, col_dim=1, num_heads=1):
        super().__init__()

        # 创建多个注意力头(Attention 实例)
        self.heads = nn.ModuleList(
            [Attention(d_model, row_dim, col_dim) for _ in range(num_heads)]
        )

        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v):
        # 将输入依次传入各个 head,并拼接结果
        return torch.cat(
            [head(encodings_for_q, encodings_for_k, encodings_for_v)
             for head in self.heads],
            dim=self.col_dim
        )

🔢 六、验证 Multi-Head Attention

(1)单头测试

python 复制代码
torch.manual_seed(42)
multiHeadAttention = MultiHeadAttention(d_model=2, row_dim=0, col_dim=1, num_heads=1)
print(multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v))
输出:

lua
Copy code
tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]])
结果与单个 Attention 一致 ✅

(2)双头测试

python 复制代码
torch.manual_seed(42)
multiHeadAttention = MultiHeadAttention(d_model=2, row_dim=0, col_dim=1, num_heads=2)
print(multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v))
输出:

lua
Copy code
tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]])
✅ 输出列数翻倍(每个头输出 2 个值),

说明两个注意力头成功并行计算。

七.核心思想回顾

多头注意力 = 多个独立的 Attention + 拼接输出

每个头学习不同的语义模式,如:

一个头关注短距离关系;

一个头关注句法结构;

一个头关注情感或上下文。

最终通过线性层组合结果,让模型能从多个视角理解输入。

一句话总结

多头注意力让模型拥有多双"眼睛",

每个头专注不同的语义视角,

最终融合成更全面的理解 ------

这就是 Transformer 的强大之处。

📘 参考资料

Vaswani et al. (2017) Attention Is All You Need

Josh Starmer: Coding Attention in PyTorch

Jay Alammar: Illustrated Transformer

《动手学深度学习》第10章 注意力机制

相关推荐
l1t2 小时前
利用DeepSeek改写SQLite版本的二进制位数独求解SQL
数据库·人工智能·sql·sqlite
说私域2 小时前
开源AI智能名片链动2+1模式S2B2C商城小程序FAQ设计及其意义探究
人工智能·小程序
开利网络3 小时前
合规底线:健康产品营销的红线与避坑指南
大数据·前端·人工智能·云计算·1024程序员节
非著名架构师3 小时前
量化“天气风险”:金融与保险机构如何利用气候大数据实现精准定价与投资决策
大数据·人工智能·新能源风光提高精度·疾风气象大模型4.0
巫婆理发2224 小时前
评估指标+数据不匹配+贝叶斯最优误差(分析方差和偏差)+迁移学习+多任务学习+端到端深度学习
深度学习·学习·迁移学习
熙梦数字化4 小时前
2025汽车零部件行业数字化转型落地方案
大数据·人工智能·汽车
刘海东刘海东4 小时前
逻辑方程结构图语言的机器实现(草稿)
人工智能
亮剑20184 小时前
第2节:程序逻辑与控制流——让程序“思考”
开发语言·c++·人工智能
hixiong1234 小时前
C# OpenCVSharp使用 读光-票证检测矫正模型
人工智能·opencv·c#