Transformer 核心模块详解:多头注意力、前馈网络与词嵌入

【学习记录】Transformer 核心模块详解:多头注意力、前馈网络与词嵌入

Transformer 是现代大语言模型的基石,而多头注意力(MultiHeadAttention)前馈网络(FFN)词嵌入(Embedding) 是其最核心的三个组件。本文从原理到代码,逐层拆解这三个模块,并提供 Python(PyTorch)和 C++(LibTorch)实现,附带完整的复杂度分析。


📌 目录

  1. MultiHeadAttention(多头注意力)
  2. FFN(前馈网络)
  3. Embedding(词嵌入)
  4. 三个模块的组合使用
  5. 复杂度总结

一、多头注意力(MultiHeadAttention)

1.1 作用

多头注意力机制允许模型同时关注输入序列中不同位置的不同表示子空间。它通过将查询(Q)、键(K)、值(V)线性映射到多个头,分别计算注意力,最后拼接并映射回原维度。

1.2 数学公式

标准缩放点积注意力:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

多头注意力:

MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO

其中 head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)。

1.3 代码实现(Python/PyTorch)

python 复制代码
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # 1. 线性映射并拆分为多头
        Q = self.Wq(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.Wk(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.Wv(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 2. 缩放点积注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 3. 应用掩码(可选)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        # 4. 合并多头并输出
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.Wo(output)

1.4 图解(文本示意)

复制代码
输入: (B, T, D)
  │
  ├─→ 线性映射 Wq, Wk, Wv → (B, T, D)
  │
  ├─→ view + transpose → (B, n_head, T, d_k)
  │
  ├─→ scores = Q @ K^T / sqrt(d_k) → (B, n_head, T, T)
  │         │
  │         └─→ mask (可选) 填充 -1e9
  │
  ├─→ softmax → (B, n_head, T, T)
  │
  ├─→ output = attn @ V → (B, n_head, T, d_k)
  │
  ├─→ transpose + view → (B, T, D)
  │
  └─→ Wo 线性映射 → (B, T, D)

1.5 C++ 代码(LibTorch)

cpp 复制代码
#include <torch/torch.h>

class MultiHeadAttentionImpl : public torch::nn::Module {
public:
    int d_model, num_heads, d_k;
    torch::nn::Linear Wq, Wk, Wv, Wo;

    MultiHeadAttentionImpl(int d_model_, int num_heads_)
        : d_model(d_model_), num_heads(num_heads_),
          d_k(d_model_ / num_heads_),
          Wq(torch::nn::Linear(d_model, d_model)),
          Wk(torch::nn::Linear(d_model, d_model)),
          Wv(torch::nn::Linear(d_model, d_model)),
          Wo(torch::nn::Linear(d_model, d_model)) {
        register_module("Wq", Wq);
        register_module("Wk", Wk);
        register_module("Wv", Wv);
        register_module("Wo", Wo);
    }

    torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor mask = {}) {
        int batch_size = Q.size(0);
        // 线性映射
        Q = Wq->forward(Q).view({batch_size, -1, num_heads, d_k}).transpose(1, 2);
        K = Wk->forward(K).view({batch_size, -1, num_heads, d_k}).transpose(1, 2);
        V = Wv->forward(V).view({batch_size, -1, num_heads, d_k}).transpose(1, 2);

        // 注意力分数
        auto scores = torch::matmul(Q, K.transpose(-2, -1)) / std::sqrt(d_k);
        if (mask.defined()) {
            scores = scores.masked_fill(mask == 0, -1e9);
        }
        auto attn = torch::softmax(scores, -1);
        auto output = torch::matmul(attn, V);

        output = output.transpose(1, 2).contiguous().view({batch_size, -1, d_model});
        return Wo->forward(output);
    }
};
TORCH_MODULE(MultiHeadAttention);

1.6 复杂度分析

操作 时间复杂度 空间复杂度
线性映射 (Q,K,V) O(B×T×D²) O(B×T×D)
拆分多头 O(B×T×D) O(B×n_head×T×d_k)
分数矩阵乘法 O(B×n_head×T²×d_k) O(B×n_head×T²)
Softmax O(B×n_head×T²) O(B×n_head×T²)
加权求和 O(B×n_head×T²×d_k) O(B×n_head×T×d_k)
合并与输出映射 O(B×T×D²) O(B×T×D)
总计 O(B × T² × D) O(B × n_head × T²)

其中 D = d_model, d_k = D / n_head


二、前馈网络(FFN)

2.1 作用

FFN 对每个位置独立进行非线性变换,增加模型表达能力。标准结构:线性 → ReLU → 线性 ,通常中间维度 d_ffd_model 的 4 倍左右。

2.2 数学公式

FFN ( x ) = ReLU ( x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2 FFN(x)=ReLU(xW1+b1)W2+b2

2.3 代码实现(Python/PyTorch)

python 复制代码
class FFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.linear2(self.activation(self.linear1(x)))

2.4 图解

复制代码
输入 (B, T, D)
   │
   ├─→ linear1 (D → d_ff) → (B, T, d_ff)
   │
   ├─→ ReLU → (B, T, d_ff)
   │
   └─→ linear2 (d_ff → D) → (B, T, D)

2.5 C++ 代码(LibTorch)

cpp 复制代码
class FFNImpl : public torch::nn::Module {
public:
    torch::nn::Linear linear1, linear2;
    FFNImpl(int d_model, int d_ff)
        : linear1(d_model, d_ff), linear2(d_ff, d_model) {
        register_module("linear1", linear1);
        register_module("linear2", linear2);
    }
    torch::Tensor forward(torch::Tensor x) {
        return linear2->forward(torch::relu(linear1->forward(x)));
    }
};
TORCH_MODULE(FFN);

2.6 复杂度分析

操作 时间复杂度 空间复杂度
linear1 O(B × T × D × d_ff) O(B × T × d_ff)
ReLU O(B × T × d_ff) O(B × T × d_ff)
linear2 O(B × T × d_ff × D) O(B × T × D)
总计 O(B × T × D × d_ff) O(B × T × max(D, d_ff))

d_ff = 4 × D 时,复杂度约为 O(4 × B × T × D²)


三、词嵌入(Embedding)

3.1 作用

将离散的 token ID 序列映射为稠密的连续向量,并乘以 √d_model 进行缩放,以便与位置编码相加时尺度匹配。

3.2 代码实现(Python/PyTorch)

python 复制代码
class Embedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_model)

3.3 图解

复制代码
输入: (B, T)  token IDs [ [1, 3, 0, ...] ]
   │
   └─→ nn.Embedding 查表 (vocab_size × D)
         │
         └─→ 输出 (B, T, D)
              │
              └─→ 乘以 √D → (B, T, D)

3.4 C++ 代码(LibTorch)

cpp 复制代码
class EmbeddingImpl : public torch::nn::Module {
public:
    torch::nn::Embedding embedding;
    int d_model;
    EmbeddingImpl(int vocab_size, int d_model_)
        : embedding(vocab_size, d_model_), d_model(d_model_) {
        register_module("embedding", embedding);
    }
    torch::Tensor forward(torch::Tensor x) {
        return embedding->forward(x) * std::sqrt(d_model);
    }
};
TORCH_MODULE(Embedding);

3.5 复杂度分析

操作 时间复杂度 空间复杂度
查表 O(B × T) O(B × T × D)
乘法 O(B × T × D) O(B × T × D)
总计 O(B × T × D) O(B × T × D)

四、三个模块的组合使用

一个完整的 Transformer 编码器层通常由 多头注意力 + 残差连接 + 层归一化FFN + 残差连接 + 层归一化 构成。

python 复制代码
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # 自注意力 + 残差 + 层归一化
        attn_out = self.self_attn(x, x, x, mask)
        x = self.norm1(x + attn_out)
        # FFN + 残差 + 层归一化
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x

完整流程示例

python 复制代码
vocab_size = 10000
d_model = 512
num_heads = 8
d_ff = 2048
batch_size = 2
seq_len = 10

# 输入 token IDs
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# 嵌入层
embed = Embedding(vocab_size, d_model)
x = embed(input_ids)                     # (2,10,512)

# 位置编码(此处略,可加上)
# pos_enc = PositionalEncoding(d_model)
# x = pos_enc(x)

# Transformer 编码器层
encoder_layer = TransformerEncoderLayer(d_model, num_heads, d_ff)
output = encoder_layer(x)               # (2,10,512)

print(output.shape)  # torch.Size([2, 10, 512])

五、复杂度总结

模块 时间复杂度 空间复杂度 说明
MultiHeadAttention O(B × T² × D) O(B × n_head × T²) 核心瓶颈在 T²,长序列需优化
FFN O(B × T × D × d_ff) O(B × T × max(D, d_ff)) 通常 d_ff = 4D,复杂度约为 4×
Embedding O(B × T × D) O(B × T × D) 查表操作,轻量级

优化建议

  • 对于长序列(T 很大),可使用稀疏注意力(如 FlashAttention)降低 T² 复杂度。
  • FFN 的中间维度 d_ff 越大模型容量越大,但计算量线性增加。
  • 嵌入层占参数量主要为 vocab_size × D,大词表时需考虑参数共享或压缩。

🎯 总结

本文详细拆解了 Transformer 的三个核心模块:

  1. 多头注意力:让模型关注不同位置的多种关系,是 Transformer 成功的核心。
  2. 前馈网络:提供非线性变换,增强模型表达能力。
  3. 词嵌入:将离散符号映射到连续空间,是深度学习处理文本的起点。

通过理解这些模块的输入输出、形状变化和复杂度,能轻松搭建并优化自己的 Transformer 模型。

相关推荐
灰灰勇闯IT3 小时前
catlass:昇腾NPU上的算子模板库
人工智能
桜吹雪3 小时前
所有智能体架构(2):ReAct(推理 + 行动)
人工智能
动物园猫3 小时前
面向智慧牧场的牛行为识别数据集分享(适用于YOLO系列深度学习分类检测任务)
深度学习·yolo·分类
埃菲尔铁塔_CV算法3 小时前
YOLO11 与传统纹理特征融合目标检测 完整实现教程
人工智能·神经网络·yolo·计算机视觉
快乐的哈士奇3 小时前
LangFuse 自托管实战:选型理由、Docker 部署与常用配置全解析
运维·人工智能·docker·容器
数智化管理手记3 小时前
精益生产3步实操,让现场从混乱变标杆
大数据·运维·网络·人工智能·精益工程
百度Geek说3 小时前
PRD → Goal → After-Goal:AI 主导全流程研发实践
人工智能
山西茄子3 小时前
DeepStream9.0 在DeepStream中使用VLM
人工智能
小小测试开发3 小时前
AI 水印攻防战:OpenAI 引入 SynthID 认证,GitHub 同步出现去水印工具
人工智能·github