【学习记录】Transformer 核心模块详解:多头注意力、前馈网络与词嵌入
Transformer 是现代大语言模型的基石,而多头注意力(MultiHeadAttention) 、前馈网络(FFN) 和词嵌入(Embedding) 是其最核心的三个组件。本文从原理到代码,逐层拆解这三个模块,并提供 Python(PyTorch)和 C++(LibTorch)实现,附带完整的复杂度分析。
📌 目录
一、多头注意力(MultiHeadAttention)
1.1 作用
多头注意力机制允许模型同时关注输入序列中不同位置的不同表示子空间。它通过将查询(Q)、键(K)、值(V)线性映射到多个头,分别计算注意力,最后拼接并映射回原维度。
1.2 数学公式
标准缩放点积注意力:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
多头注意力:
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中 head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)。
1.3 代码实现(Python/PyTorch)
python
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.Wq = nn.Linear(d_model, d_model)
self.Wk = nn.Linear(d_model, d_model)
self.Wv = nn.Linear(d_model, d_model)
self.Wo = nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 1. 线性映射并拆分为多头
Q = self.Wq(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.Wk(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.Wv(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 2. 缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# 3. 应用掩码(可选)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
# 4. 合并多头并输出
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
return self.Wo(output)
1.4 图解(文本示意)
输入: (B, T, D)
│
├─→ 线性映射 Wq, Wk, Wv → (B, T, D)
│
├─→ view + transpose → (B, n_head, T, d_k)
│
├─→ scores = Q @ K^T / sqrt(d_k) → (B, n_head, T, T)
│ │
│ └─→ mask (可选) 填充 -1e9
│
├─→ softmax → (B, n_head, T, T)
│
├─→ output = attn @ V → (B, n_head, T, d_k)
│
├─→ transpose + view → (B, T, D)
│
└─→ Wo 线性映射 → (B, T, D)
1.5 C++ 代码(LibTorch)
cpp
#include <torch/torch.h>
class MultiHeadAttentionImpl : public torch::nn::Module {
public:
int d_model, num_heads, d_k;
torch::nn::Linear Wq, Wk, Wv, Wo;
MultiHeadAttentionImpl(int d_model_, int num_heads_)
: d_model(d_model_), num_heads(num_heads_),
d_k(d_model_ / num_heads_),
Wq(torch::nn::Linear(d_model, d_model)),
Wk(torch::nn::Linear(d_model, d_model)),
Wv(torch::nn::Linear(d_model, d_model)),
Wo(torch::nn::Linear(d_model, d_model)) {
register_module("Wq", Wq);
register_module("Wk", Wk);
register_module("Wv", Wv);
register_module("Wo", Wo);
}
torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor mask = {}) {
int batch_size = Q.size(0);
// 线性映射
Q = Wq->forward(Q).view({batch_size, -1, num_heads, d_k}).transpose(1, 2);
K = Wk->forward(K).view({batch_size, -1, num_heads, d_k}).transpose(1, 2);
V = Wv->forward(V).view({batch_size, -1, num_heads, d_k}).transpose(1, 2);
// 注意力分数
auto scores = torch::matmul(Q, K.transpose(-2, -1)) / std::sqrt(d_k);
if (mask.defined()) {
scores = scores.masked_fill(mask == 0, -1e9);
}
auto attn = torch::softmax(scores, -1);
auto output = torch::matmul(attn, V);
output = output.transpose(1, 2).contiguous().view({batch_size, -1, d_model});
return Wo->forward(output);
}
};
TORCH_MODULE(MultiHeadAttention);
1.6 复杂度分析
| 操作 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 线性映射 (Q,K,V) | O(B×T×D²) | O(B×T×D) |
| 拆分多头 | O(B×T×D) | O(B×n_head×T×d_k) |
| 分数矩阵乘法 | O(B×n_head×T²×d_k) | O(B×n_head×T²) |
| Softmax | O(B×n_head×T²) | O(B×n_head×T²) |
| 加权求和 | O(B×n_head×T²×d_k) | O(B×n_head×T×d_k) |
| 合并与输出映射 | O(B×T×D²) | O(B×T×D) |
| 总计 | O(B × T² × D) | O(B × n_head × T²) |
其中
D = d_model,d_k = D / n_head。
二、前馈网络(FFN)
2.1 作用
FFN 对每个位置独立进行非线性变换,增加模型表达能力。标准结构:线性 → ReLU → 线性 ,通常中间维度 d_ff 是 d_model 的 4 倍左右。
2.2 数学公式
FFN ( x ) = ReLU ( x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2 FFN(x)=ReLU(xW1+b1)W2+b2
2.3 代码实现(Python/PyTorch)
python
class FFN(nn.Module):
def __init__(self, d_model, d_ff):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.activation = nn.ReLU()
def forward(self, x):
return self.linear2(self.activation(self.linear1(x)))
2.4 图解
输入 (B, T, D)
│
├─→ linear1 (D → d_ff) → (B, T, d_ff)
│
├─→ ReLU → (B, T, d_ff)
│
└─→ linear2 (d_ff → D) → (B, T, D)
2.5 C++ 代码(LibTorch)
cpp
class FFNImpl : public torch::nn::Module {
public:
torch::nn::Linear linear1, linear2;
FFNImpl(int d_model, int d_ff)
: linear1(d_model, d_ff), linear2(d_ff, d_model) {
register_module("linear1", linear1);
register_module("linear2", linear2);
}
torch::Tensor forward(torch::Tensor x) {
return linear2->forward(torch::relu(linear1->forward(x)));
}
};
TORCH_MODULE(FFN);
2.6 复杂度分析
| 操作 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| linear1 | O(B × T × D × d_ff) | O(B × T × d_ff) |
| ReLU | O(B × T × d_ff) | O(B × T × d_ff) |
| linear2 | O(B × T × d_ff × D) | O(B × T × D) |
| 总计 | O(B × T × D × d_ff) | O(B × T × max(D, d_ff)) |
当
d_ff = 4 × D时,复杂度约为O(4 × B × T × D²)。
三、词嵌入(Embedding)
3.1 作用
将离散的 token ID 序列映射为稠密的连续向量,并乘以 √d_model 进行缩放,以便与位置编码相加时尺度匹配。
3.2 代码实现(Python/PyTorch)
python
class Embedding(nn.Module):
def __init__(self, vocab_size, d_model):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.d_model = d_model
def forward(self, x):
return self.embedding(x) * math.sqrt(self.d_model)
3.3 图解
输入: (B, T) token IDs [ [1, 3, 0, ...] ]
│
└─→ nn.Embedding 查表 (vocab_size × D)
│
└─→ 输出 (B, T, D)
│
└─→ 乘以 √D → (B, T, D)
3.4 C++ 代码(LibTorch)
cpp
class EmbeddingImpl : public torch::nn::Module {
public:
torch::nn::Embedding embedding;
int d_model;
EmbeddingImpl(int vocab_size, int d_model_)
: embedding(vocab_size, d_model_), d_model(d_model_) {
register_module("embedding", embedding);
}
torch::Tensor forward(torch::Tensor x) {
return embedding->forward(x) * std::sqrt(d_model);
}
};
TORCH_MODULE(Embedding);
3.5 复杂度分析
| 操作 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 查表 | O(B × T) | O(B × T × D) |
| 乘法 | O(B × T × D) | O(B × T × D) |
| 总计 | O(B × T × D) | O(B × T × D) |
四、三个模块的组合使用
一个完整的 Transformer 编码器层通常由 多头注意力 + 残差连接 + 层归一化 和 FFN + 残差连接 + 层归一化 构成。
python
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.ffn = FFN(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
# 自注意力 + 残差 + 层归一化
attn_out = self.self_attn(x, x, x, mask)
x = self.norm1(x + attn_out)
# FFN + 残差 + 层归一化
ffn_out = self.ffn(x)
x = self.norm2(x + ffn_out)
return x
完整流程示例
python
vocab_size = 10000
d_model = 512
num_heads = 8
d_ff = 2048
batch_size = 2
seq_len = 10
# 输入 token IDs
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
# 嵌入层
embed = Embedding(vocab_size, d_model)
x = embed(input_ids) # (2,10,512)
# 位置编码(此处略,可加上)
# pos_enc = PositionalEncoding(d_model)
# x = pos_enc(x)
# Transformer 编码器层
encoder_layer = TransformerEncoderLayer(d_model, num_heads, d_ff)
output = encoder_layer(x) # (2,10,512)
print(output.shape) # torch.Size([2, 10, 512])
五、复杂度总结
| 模块 | 时间复杂度 | 空间复杂度 | 说明 |
|---|---|---|---|
| MultiHeadAttention | O(B × T² × D) | O(B × n_head × T²) | 核心瓶颈在 T²,长序列需优化 |
| FFN | O(B × T × D × d_ff) | O(B × T × max(D, d_ff)) | 通常 d_ff = 4D,复杂度约为 4× |
| Embedding | O(B × T × D) | O(B × T × D) | 查表操作,轻量级 |
优化建议:
- 对于长序列(T 很大),可使用稀疏注意力(如 FlashAttention)降低 T² 复杂度。
- FFN 的中间维度 d_ff 越大模型容量越大,但计算量线性增加。
- 嵌入层占参数量主要为
vocab_size × D,大词表时需考虑参数共享或压缩。
🎯 总结
本文详细拆解了 Transformer 的三个核心模块:
- 多头注意力:让模型关注不同位置的多种关系,是 Transformer 成功的核心。
- 前馈网络:提供非线性变换,增强模型表达能力。
- 词嵌入:将离散符号映射到连续空间,是深度学习处理文本的起点。
通过理解这些模块的输入输出、形状变化和复杂度,能轻松搭建并优化自己的 Transformer 模型。