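Core Principles of the Transformer Architecture: From Attention to GPT
Preface
The Transformer architecture is the foundation of modern large language models. GPT, LLaMA, and BERT are all built from the Transformer's core components. Understanding how it works matters for using and optimizing large models well.
When I first studied Transformers, I went through piles of papers and tutorials, but many explanations were either too terse or buried in mathematical detail. In this post I want to lay out the core components of the Transformer, and how it evolved, as clearly as I can.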
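Attention in Detail
Origins of Attention
Attention first appeared in sequence-to-sequence models, as a remedy for RNNs' difficulty with long sequences. The core idea: when generating each output, the model should "attend to" different parts of the input sequence.
Self-Attention
Self-attention is the Transformer's core innovation. It lets every position in a sequence attend to every position in that sequence, itself included: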
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert self.head_dim * heads == embed_size, "embed_size must be divisible by heads"
        # Joint projection that produces Q, K, and V in one matmul
        self.qkv = nn.Linear(embed_size, embed_size * 3)
        # Output projection
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        N, seq_len, _ = x.shape
        # Linear projection to Q, K, V, then split into heads
        qkv = self.qkv(x)
        qkv = qkv.reshape(N, seq_len, 3, self.heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, N, heads, seq_len, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled attention scores; the subscripts match the
        # (N, heads, seq_len, head_dim) layout produced by the permute above
        energy = torch.einsum("nhqd,nhkd->nhqk", q, k) / math.sqrt(self.head_dim)
        # energy: (N, heads, seq_len, seq_len)
        if mask is not None:
            # mask holds 1 where attention is allowed and 0 where it is blocked
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        # Softmax over the key dimension gives the attention weights
        attention = F.softmax(energy, dim=-1)
        # Weighted sum of the values, moving heads next to head_dim for the merge
        out = torch.einsum("nhqk,nhkd->nqhd", attention, v)
        out = out.reshape(N, seq_len, self.heads * self.head_dim)
        return self.fc_out(out)
```
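To make the arithmetic inside the forward pass concrete, here is a minimal single-head sketch of scaled dot-product attention in plain Python. It leaves out the learned projections and the head split, and the names `softmax` and `attention` are illustrative helpers rather than part of the model code.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values):
    """Single-head scaled dot-product attention on lists of vectors."""
    d_k = len(keys[0])
    outputs, all_weights = [], []
    for q in queries:
        # One score per key: dot(q, k) / sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # The output is the weight-averaged value vector
        out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights


# Three toy token vectors used as Q, K, and V at once (self-attention)
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = attention(tokens, tokens, tokens)
for row in weights:
    print([round(w, 3) for w in row])
```

Each row of `weights` sums to 1, and the first token puts equal weight on itself and on the third token because both have the same dot product with its query.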
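Multi-Head Attention
Multi-head attention lets the model attend to different representation subspaces at different positions simultaneously. The split into heads already happens inside SelfAttention above; the wrapper below adds layer normalization, dropout, and the residual connection around it: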
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads, dropout=0.1):
        super().__init__()
        self.attention = SelfAttention(embed_size, num_heads)
        self.norm = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-LN style (common in modern models): normalize the sublayer input,
        # but keep the un-normalized x on the residual path
        attention_out = self.attention(self.norm(x), mask)
        return x + self.dropout(attention_out)
```
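The difference between the two LayerNorm placements is easiest to see with the sublayer stripped out. The sketch below is plain Python with a simplified `layer_norm` (no learned scale or shift) and a hypothetical `zero_sublayer` that contributes nothing, so only the residual path remains.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def pre_ln_block(x, sublayer):
    """Pre-LN: normalize the sublayer input; the residual path carries x untouched."""
    update = sublayer(layer_norm(x))
    return [a + b for a, b in zip(x, update)]


def post_ln_block(x, sublayer):
    """Post-LN: add the residual first, then normalize the sum."""
    update = sublayer(x)
    return layer_norm([a + b for a, b in zip(x, update)])


def zero_sublayer(v):
    """Hypothetical sublayer that outputs zeros, to isolate the residual path."""
    return [0.0 for _ in v]


x = [1.0, 2.0, 4.0]
pre_out = pre_ln_block(x, zero_sublayer)    # x passes through unchanged
post_out = post_ln_block(x, zero_sublayer)  # x is rescaled by the norm
print(pre_out, [round(v, 3) for v in post_out])
```

With Pre-LN the input reaches the output through an unmodified identity path, which is the usual explanation for why deep Pre-LN stacks are easier to train.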
Positional Encoding
Self-attention by itself carries no information about token order, so positional encodings have to be added explicitly:
```python
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute the sinusoidal table (embed_size is assumed to be even)
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_len, embed_size)
        # A buffer moves with the module across devices but is not trained
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the encodings for the first seq_len positions
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)
```
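The table built in `__init__` follows PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). As a rough check on what those values look like, this standalone sketch computes the encoding for a single position in plain Python; `positional_encoding` is an illustrative helper and assumes an even `embed_size`.

```python
import math


def positional_encoding(position, embed_size):
    """Sinusoidal encoding for one position; embed_size is assumed to be even."""
    encoding = []
    for i in range(0, embed_size, 2):
        # Same frequency as div_term in the module: 10000^(-i / embed_size)
        angle = position * math.exp(-math.log(10000.0) * i / embed_size)
        encoding.append(math.sin(angle))  # even index
        encoding.append(math.cos(angle))  # odd index
    return encoding


pe0 = positional_encoding(0, 4)
pe1 = positional_encoding(1, 4)
print(pe0)
print([round(v, 4) for v in pe1])
```

Position 0 encodes as alternating 0 and 1 (sin 0 and cos 0), while later positions trace sinusoids whose wavelength grows with the dimension index.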
Transformer Encoder
Feed-Forward Network
Each Transformer layer also contains a position-wise feed-forward network:
```python
class FFN(nn.Module):
    def __init__(self, embed_size, ff_dim, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(embed_size, ff_dim)  # expand to the hidden width
        self.linear2 = nn.Linear(ff_dim, embed_size)  # project back to embed_size
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))
```
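The FFN above uses GELU instead of the ReLU from the original Transformer. A small sketch in plain Python shows how the two differ near zero; `gelu` here is the exact form x * Φ(x), with Φ the standard normal CDF, written with `math.erf`.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def relu(x):
    """ReLU clips every negative input to zero."""
    return max(0.0, x)


for value in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={value:+.1f}  relu={relu(value):.4f}  gelu={gelu(value):.4f}")
```

ReLU is flat at zero for all negative inputs, whereas GELU lets small negative values through with a small negative output and approaches the identity for large positive inputs.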
Complete Encoder Layer
```python
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_size, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_size, num_heads, dropout)
        self.ffn = FFN(embed_size, ff_dim, dropout)
        # Only the FFN needs a norm here; MultiHeadAttention normalizes its own input
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention; MultiHeadAttention already adds the residual,
        # so adding x again here would count it twice
        x = self.attention(x, mask)
        # FFN + residual
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, embed_size, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_size, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        # Pre-LN stacks need one final norm on the output of the last layer
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
```
Transformer Decoder
Causal Mask
The decoder must not see future tokens: position i may attend only to positions up to and including i:
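```python
def create_causal_mask(seq_len, device):
    """Create a causal mask: 1 where attention is allowed, 0 for future positions."""
    # Lower-triangular matrix. SelfAttention blocks entries where mask == 0,
    # so the mask must mark allowed positions with 1 rather than hold -inf values.
    return torch.tril(torch.ones(seq_len, seq_len, device=device))
```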
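To see what the mask does to the attention weights, here is a small plain-Python sketch. It uses uniform scores so that the only thing shaping the result is the mask; `causal_mask` and `masked_softmax` are illustrative helpers, not part of the model code.

```python
import math


def causal_mask(seq_len):
    """Lower-triangular mask: row i may attend to columns 0..i (1 = allowed)."""
    return [[1 if col <= row else 0 for col in range(seq_len)] for row in range(seq_len)]


def masked_softmax(scores, mask_row):
    """Softmax over one row of scores, giving zero weight where the mask is 0."""
    allowed = [s for s, m in zip(scores, mask_row) if m]
    peak = max(allowed)
    exps = [math.exp(s - peak) if m else 0.0 for s, m in zip(scores, mask_row)]
    total = sum(exps)
    return [e / total for e in exps]


mask = causal_mask(3)
uniform_scores = [0.0, 0.0, 0.0]  # equal scores make the mask's effect easy to see
weights = [masked_softmax(uniform_scores, row) for row in mask]
for row in weights:
    print([round(w, 3) for w in row])
```

The first token can only attend to itself, the second splits its weight over two positions, and only the last token sees the whole sequence.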
Complete Decoder Layer
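```python
class TransformerDecoderLayer(nn.Module):
    def __init__(self, embed_size, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        # MultiHeadAttention already applies the norm, dropout, and residual
        self.self_attention = MultiHeadAttention(embed_size, num_heads, dropout)
        # SelfAttention derives Q, K, V from a single input, so cross-attention
        # uses PyTorch's built-in module, which takes separate query and key/value
        self.cross_attention = nn.MultiheadAttention(
            embed_size, num_heads, dropout=dropout, batch_first=True
        )
        self.ffn = FFN(embed_size, ff_dim, dropout)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output=None, src_mask=None, tgt_mask=None):
        # Causal self-attention (norm + residual happen inside the module)
        x = self.self_attention(x, tgt_mask)
        # Cross-attention over the encoder output; skipped for decoder-only models
        if encoder_output is not None:
            # src_mask: optional (N, src_len) bool tensor, True = padding to ignore
            attn_out, _ = self.cross_attention(
                self.norm2(x), encoder_output, encoder_output,
                key_padding_mask=src_mask, need_weights=False,
            )
            x = x + self.dropout(attn_out)
        # FFN + residual
        x = x + self.dropout(self.ffn(self.norm3(x)))
        return x
```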
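What sets cross-attention apart from self-attention is where the inputs come from: queries are taken from the decoder, keys and values from the encoder. The plain-Python sketch below illustrates this with the learned projections omitted; `cross_attention` and the toy state vectors are illustrative assumptions.

```python
import math


def cross_attention(decoder_states, encoder_states):
    """Unprojected cross-attention: queries come from the decoder,
    keys and values from the encoder (learned projections omitted)."""
    d_k = len(encoder_states[0])
    outputs = []
    for query in decoder_states:
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
                  for key in encoder_states]
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        outputs.append([sum(w * value[j] for w, value in zip(weights, encoder_states))
                        for j in range(d_k)])
    return outputs


encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]  # 4 source tokens
decoder_states = [[2.0, 0.0], [0.0, 2.0]]                          # 2 target tokens
context = cross_attention(decoder_states, encoder_states)
print(len(context), len(context[0]))
```

The output has one vector per decoder position, regardless of how many encoder positions there are, and each vector is a mixture of encoder states.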
GPT Architecture in Detail
How GPT Differs from BERT
GPT (Generative Pre-trained Transformer) and BERT are both built on the Transformer, but their architectures differ in important ways:
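| Property | GPT | BERT |
|---|---|---|
| Attention | Unidirectional (causal) | Bidirectional |
| Pre-training task | Next-token prediction | Masked language modeling |
| Typical use | Generation tasks | Understanding tasks |
| Depth | Scaled up across releases (12 layers in GPT-1, 96 in GPT-3) | 12 layers (base) or 24 (large) |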
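The two pre-training tasks in the table lead to differently shaped training examples. The sketch below builds both from the same sentence in plain Python. The function names are illustrative, and the masked position is fixed by hand, whereas BERT selects positions at random.

```python
def causal_lm_pairs(tokens):
    """GPT-style: at each position, predict the next token from the prefix."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


def masked_lm_example(tokens, masked_positions, mask_token="[MASK]"):
    """BERT-style: hide the chosen positions and predict them from both sides."""
    inputs = [mask_token if i in masked_positions else tok for i, tok in enumerate(tokens)]
    targets = {i: tokens[i] for i in masked_positions}
    return inputs, targets


sentence = ["the", "cat", "sat", "down"]
gpt_pairs = causal_lm_pairs(sentence)
bert_inputs, bert_targets = masked_lm_example(sentence, masked_positions={1})

for prefix, target in gpt_pairs:
    print(prefix, "->", target)
print(bert_inputs, bert_targets)
```

In the GPT case every target sees only the tokens to its left. In the BERT case the model sees context on both sides of the masked token, which is why the same network cannot be used directly for left-to-right generation.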
Implementing the GPT-2 Architecture
```python
class GPT2Model(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_size,
        num_heads,
        num_layers,
        ff_dim,
        max_seq_len,
        dropout=0.1
    ):
        super().__init__()
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        # GPT-2 itself learns its position embeddings; the sinusoidal
        # encoding defined earlier is reused here to keep the example short
        self.position_embedding = PositionalEncoding(embed_size, max_seq_len, dropout)
        # Stack of Transformer decoder layers
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(embed_size, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_size)
        self.head = nn.Linear(embed_size, vocab_size, bias=False)
        # Weight tying: the output head shares the token embedding matrix
        self.head.weight = self.token_embedding.weight

    def forward(self, x, targets=None):
        # Token embedding + positional encoding
        x = self.token_embedding(x)
        x = self.position_embedding(x)
        # Causal mask
        seq_len = x.shape[1]
        causal_mask = create_causal_mask(seq_len, x.device)
        # Transformer layers; a decoder-only model has no encoder output to pass
        for layer in self.layers:
            x = layer(x, None, None, causal_mask)
        x = self.norm(x)
        logits = self.head(x)  # (N, seq_len, vocab_size)
        loss = None
        if targets is not None:
            # targets are expected to be the input ids shifted left by one position
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
        return {"loss": loss, "logits": logits}
```
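A model trained this way generates text one token at a time, feeding each prediction back in as input. The sketch below shows that loop with greedy decoding in plain Python. `next_token_scores` is a hypothetical stand-in for a model's forward pass, played here by a toy bigram table rather than a real network.

```python
def greedy_generate(next_token_scores, prompt, max_new_tokens, eos_token=None):
    """Autoregressive decoding: repeatedly append the highest-scoring next token.

    next_token_scores is a hypothetical callable mapping a token list to a
    {token: score} dict; a real model would return logits over the vocabulary.
    """
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        scores = next_token_scores(tokens)
        best = max(scores, key=scores.get)
        tokens.append(best)
        if best == eos_token:
            break
    return tokens


# Toy stand-in for a language model: scores depend only on the last token
BIGRAMS = {
    "the": {"cat": 2.0, "dog": 1.0},
    "cat": {"sat": 3.0, "ran": 0.5},
    "sat": {"<eos>": 1.0},
}


def toy_model(tokens):
    return BIGRAMS.get(tokens[-1], {"<eos>": 0.0})


generated = greedy_generate(toy_model, ["the"], max_new_tokens=5, eos_token="<eos>")
print(generated)
```

Greedy decoding is the simplest choice; sampling-based strategies replace the `max` step but keep the same loop.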
Evolution: From the Transformer to Modern LLMs
Key Technical Developments
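- Post-LN → Pre-LN
  - The original Transformer uses Post-LN (LayerNorm applied after the residual addition)
  - Modern models mostly use Pre-LN (LayerNorm applied to the sublayer input, before the residual addition), which trains more stably
- Fixed positional encoding → RoPE
  - RoPE (Rotary Position Embedding) has become the mainstream choice
  - It generally extrapolates better, handling sequences longer than those seen in training
- GELU activation
  - Replaces the original ReLU; GELU(x) = x * Φ(x), where Φ is the standard normal CDF
- RMSNorm
  - A cheaper normalization that skips mean-centering, reducing compute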
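Of the items above, RoPE is the least obvious from a one-line description. A minimal sketch in plain Python, assuming the formulation that rotates adjacent (even, odd) pairs of dimensions, shows its defining property: the query-key score depends only on the distance between the two positions. The `rope` and `dot` helpers and the toy vectors are illustrative.

```python
import math


def rope(vector, position, base=10000.0):
    """Rotate each (even, odd) pair of dimensions by an angle proportional to position."""
    dim = len(vector)  # assumed to be even
    rotated = []
    for i in range(0, dim, 2):
        theta = position * base ** (-i / dim)
        x, y = vector[i], vector[i + 1]
        rotated.append(x * math.cos(theta) - y * math.sin(theta))
        rotated.append(x * math.sin(theta) + y * math.cos(theta))
    return rotated


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]
# Same offset (3) at two different absolute positions
score_near = dot(rope(q, 5), rope(k, 2))
score_far = dot(rope(q, 105), rope(k, 102))
# A different offset (2) changes the score
score_other = dot(rope(q, 5), rope(k, 3))
print(round(score_near, 6), round(score_far, 6), round(score_other, 6))
```

The first two scores agree up to floating-point error, since rotating both vectors by the same extra angle leaves their dot product unchanged. Some implementations pair dimensions differently (first half with second half), which gives the same property.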
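What GPT-3 Added
GPT-3 brought several notable features together at scale:
- Sparse attention: some layers use locally banded sparse attention patterns, in the style of the Sparse Transformer, so not every pair of tokens attends to each other, which reduces compute
- In-context learning: the model picks up new tasks from examples placed in the prompt, with no weight updates
- A larger model and dataset: 175 billion parameters, trained on 300 billion tokens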
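The 175 billion figure can be roughly reproduced from the model's shape. The sketch below uses the common approximation of 12 * d_model^2 parameters per layer, ignoring biases and LayerNorm weights, together with the configuration published for the largest GPT-3 model (96 layers, d_model of 12288, a 50257-token vocabulary, and a 2048-token context). The function name is illustrative.

```python
def approx_decoder_params(num_layers, d_model, vocab_size, context_len):
    """Rough parameter count for a GPT-style decoder, ignoring biases and LayerNorms.

    Per layer: 4 * d_model^2 for attention (Q, K, V, and output projections)
    plus 8 * d_model^2 for an FFN with hidden size 4 * d_model.
    """
    per_layer = 12 * d_model ** 2
    embeddings = (vocab_size + context_len) * d_model  # token + learned position tables
    return num_layers * per_layer + embeddings


total = approx_decoder_params(num_layers=96, d_model=12288, vocab_size=50257, context_len=2048)
print(f"{total / 1e9:.1f}B parameters")
```

The estimate lands within about half a percent of the headline number, and nearly all of it comes from the per-layer attention and FFN matrices rather than the embeddings.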
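Summary
At the heart of the Transformer architecture is self-attention, which lets the model flexibly attend to any part of the input sequence. Multi-head attention, positional encoding, residual connections, and related techniques together make up this powerful architecture.
Understanding these underlying principles not only helps us use existing models better, but also lays the groundwork for future optimization and innovation.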