In 2026, the Transformer architecture is still one of the most central components of modern deep learning. From BERT and the GPT series to LLaMA, Qwen, and Grok, nearly every frontier model is built on the Transformer (or a variant of it). In this post we will implement a standard Encoder-Decoder Transformer, the architecture from the original "Attention Is All You Need" paper, from scratch in pure PyTorch.
Goals: clear code, modular structure, thorough comments, and something you can use directly for sequence-to-sequence tasks such as machine translation or text summarization.
1. Architecture overview
A standard Transformer consists of the following main parts:
- Input Embedding + Positional Encoding
- Encoder (N layers)
  - Multi-Head Self-Attention
  - Feed-Forward Network
  - Add & Norm (residual connection + LayerNorm)
- Decoder (N layers)
  - Masked Multi-Head Self-Attention (causal mask)
  - Multi-Head Cross-Attention (Encoder-Decoder Attention)
  - Feed-Forward Network
  - Add & Norm
- Final linear layer + Softmax
We will implement these modules one by one.
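For reference, every attention sub-layer below computes the scaled dot-product attention from the paper (restated here; d_k is the per-head dimension):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

Multi-head attention simply runs h such attentions on learned projections of Q, K, V and concatenates the results before a final output projection.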
2. Complete code implementation
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    """Classic sinusoidal positional encoding."""
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the positional encodings
        position = torch.arange(max_len).unsqueeze(1)                                        # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))   # [d_model/2]
        pe = torch.zeros(max_len, d_model)                                                   # [max_len, d_model]
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Register as a buffer (no gradients, but saved/loaded with the model)
        self.register_buffer('pe', pe.unsqueeze(0))                                          # [1, max_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, seq_len, d_model]
        x = x + self.pe[:, :x.size(1), :]   # broadcast add
        return self.dropout(x)


class MultiHeadAttention(nn.Module):
"""多头注意力(包含三种用法:self-attn, masked self-attn, cross-attn)"""
def init(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().init()
assert d_model % num_heads == 0self.d_model = d_model self.num_heads = num_heads self.d_k = d_model // num_heads # Q, K, V 的线性投影 + 输出投影 self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) self.dropout = nn.Dropout(dropout) self.scale = math.sqrt(self.d_k) def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor = None) -> tuple: """ query, key, value : [batch, seq_len, d_model] attn_mask : [batch, seq_len_q, seq_len_k] 或 [batch*heads, seq_len_q, seq_len_k] """ batch_size = query.size(0) # 线性变换 + 分头 Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) K = self.W_k(key) .view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) V = self.W_v(value) .view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) # Q,K,V now: [batch, heads, seq_len, d_k] # 缩放点积注意力 scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale # [b, h, q_len, k_len] if attn_mask is not None: # 保证mask形状匹配 [b, h, q_len, k_len] if attn_mask.dim() == 3: attn_mask = attn_mask.unsqueeze(1) scores = scores.masked_fill(attn_mask == 0, -1e9) attn = F.softmax(scores, dim=-1) attn = self.dropout(attn) context = torch.matmul(attn, V) # [b, h, q_len, d_k] # 合并多头 & 输出投影 context = context.transpose(1, 2).contiguous()\ .view(batch_size, -1, self.d_model) output = self.W_o(context) return output, attnclass PositionWiseFeedForward(nn.Module):
"""逐位置前馈网络(两层线性 + ReLU)"""
def init(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().init()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)def forward(self, x): return self.linear2(self.dropout(F.relu(self.linear1(x))))class EncoderLayer(nn.Module):
"""单层Encoder"""
def init(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
super().init()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x, mask=None): # Self-Attention + Residual + Norm attn_output, _ = self.self_attn(x, x, x, mask) x = self.norm1(x + self.dropout(attn_output)) # Feed-Forward + Residual + Norm ff_output = self.feed_forward(x) x = self.norm2(x + self.dropout(ff_output)) return xclass DecoderLayer(nn.Module):
"""单层Decoder"""
def init(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
super().init()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x, enc_output, src_mask=None, tgt_mask=None): # Masked Self-Attention self_attn_out, _ = self.self_attn(x, x, x, tgt_mask) x = self.norm1(x + self.dropout(self_attn_out)) # Cross-Attention (query来自decoder, key/value来自encoder) cross_attn_out, _ = self.cross_attn(x, enc_output, enc_output, src_mask) x = self.norm2(x + self.dropout(cross_attn_out)) # Feed-Forward ff_out = self.feed_forward(x) x = self.norm3(x + self.dropout(ff_out)) return xclass Transformer(nn.Module):
"""
完整的Encoder-Decoder Transformer
参数示例:d_model=512, num_heads=8, num_layers=6, d_ff=2048
"""
    def __init__(self,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 d_model: int = 512,
                 num_heads: int = 8,
                 num_layers: int = 6,
                 d_ff: int = 2048,
                 max_seq_len: int = 5000,
                 dropout: float = 0.1):
        super().__init__()

        self.src_embed = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_seq_len, dropout)

        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])

        self.final_linear = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Embedding + positional encoding
        src = self.src_embed(src) * math.sqrt(self.d_model)
        src = self.pos_enc(src)
        tgt = self.tgt_embed(tgt) * math.sqrt(self.d_model)
        tgt = self.pos_enc(tgt)

        # Encoder
        memory = src
        for layer in self.encoder_layers:
            memory = layer(memory, src_mask)

        # Decoder
        output = tgt
        for layer in self.decoder_layers:
            output = layer(output, memory, src_mask, tgt_mask)

        # Project to the vocabulary
        output = self.final_linear(output)
        return output
3. Quick usage example (machine-translation pseudocode)
# Assume the data is already prepared
model = Transformer(
src_vocab_size=30000,
tgt_vocab_size=30000,
d_model=512,
num_heads=8,
num_layers=6,
d_ff=2048,
dropout=0.1
).cuda()

# Generate the causal mask (very important!)
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # 0/1 causal mask: position i may attend only to positions <= i.
    # MultiHeadAttention above fills positions where the mask is 0 with -1e9,
    # so we return a 0/1 mask rather than the additive -inf/0 style used by nn.Transformer.
    return torch.tril(torch.ones(sz, sz)).unsqueeze(0)   # [1, sz, sz]

tgt_mask = generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
output = model(src, tgt, src_mask=None, tgt_mask=tgt_mask) # [b, tgt_len, vocab]
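To sanity-check the shapes end to end, here is a minimal smoke test with random token ids; all sizes are arbitrary, padding is ignored, and it runs on CPU for simplicity:

# Minimal shape check with dummy data (hypothetical sizes)
model_cpu = Transformer(src_vocab_size=1000, tgt_vocab_size=1000, d_model=128,
                        num_heads=4, num_layers=2, d_ff=512)
src = torch.randint(0, 1000, (2, 10))      # [batch=2, src_len=10]
tgt = torch.randint(0, 1000, (2, 12))      # [batch=2, tgt_len=12]
tgt_mask = generate_square_subsequent_mask(tgt.size(1))
logits = model_cpu(src, tgt, src_mask=None, tgt_mask=tgt_mask)
print(logits.shape)                        # torch.Size([2, 12, 1000])

# A typical teacher-forcing training step would then be:
loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                       tgt[:, 1:].reshape(-1))
loss.backward()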
4. A few recommendations
- For real projects, add label smoothing, learning-rate warmup, and the Noam scheduler (see the sketch after this list)
- Modern implementations usually also adopt pre-norm (LayerNorm before each sub-layer), the SwiGLU activation, RoPE positional encoding, and similar improvements
- If you only want a Decoder-only model (GPT-style), simply delete the Encoder and remove all src-related code
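As a concrete starting point for the first bullet, here is a minimal sketch of the Noam schedule from the original paper together with label smoothing; the warmup value, optimizer hyperparameters, and pad id are illustrative, not tuned:

import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)

def noam_lambda(step, d_model=512, warmup=4000):
    step = max(step, 1)
    # lr factor = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)

# Label smoothing is built into PyTorch's cross entropy:
criterion = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=0)   # assumes pad id 0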
Hopefully this implementation takes you one step further on the road to understanding the Transformer from scratch.
If you would like to dig into any of the following directions, leave a comment:
- Adding RoPE positional encoding
- Implementing Grouped-Query Attention / Multi-Query Attention
- A Decoder-only (GPT-like) version
- Performance-optimization tricks (torch.compile, flash attention, etc.)
5. Evolving toward a modern, production-grade Transformer: the changes most worth making first (a 2026 view)
The points in the previous section are only a starting point. When actually pretraining or fine-tuning a large model, the following changes usually deliver the biggest returns (roughly ordered by return on effort):
5.1 Pre-LayerNorm (almost mandatory)
Post-LN becomes very unstable to train once depth exceeds roughly 24 layers, whereas Pre-LN lets models with 100+ layers converge with relatively little trouble.
How to change it (shown for EncoderLayer; the Decoder is analogous):
class EncoderLayer(nn.Module):
    def __init__(self, ...):
        ...
        # Keep both norms, but apply them before each sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Pre-LN formulation (the most common modern style)
        residual = x
        x = self.norm1(x)                        # normalize first
        x = self.self_attn(x, x, x, mask)[0]     # attention
        x = residual + self.dropout(x)           # residual

        residual = x
        x = self.norm2(x)
        x = self.feed_forward(x)
        x = residual + self.dropout(x)
        return x
Extra tip: many teams also place dropout inside the sub-layers themselves (inside the attention and FFN modules) rather than on the residual path, which further reduces variance.
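A rough sketch of that variant, reusing the modules defined earlier (which already contain their own internal dropout):

def forward(self, x, mask=None):
    # Rely on the dropout already inside MultiHeadAttention / PositionWiseFeedForward,
    # and keep the residual path itself dropout-free
    h = self.norm1(x)
    x = x + self.self_attn(h, h, h, mask)[0]
    x = x + self.feed_forward(self.norm2(x))
    return x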
5.2 Activation upgrade: SwiGLU (strongly recommended)
ReLU → GELU → SwiGLU has been one of the most significant single-point wins of the past few years. At equal compute, SwiGLU typically improves perplexity by 1-3% and adds essentially no inference latency.
Two common formulations (the second is recommended; it is clearer):
# Variant 1: a single Linear, then split + SiLU gating
class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff * 2)
        self.w2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = self.w1(x)
        x, gate = x.chunk(2, dim=-1)
        return self.w2(x * F.silu(gate))
# Variant 2: closer to the LLaMA / Qwen convention (separate projections)
class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)   # gate
        self.w3 = nn.Linear(d_model, d_ff, bias=False)   # value
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
The latter is more common: w1 and w3 can be initialized independently, and bias=False is the current mainstream choice.
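One sizing detail worth knowing: because this block has three weight matrices instead of two, LLaMA-style implementations typically shrink the hidden size to roughly two-thirds of the nominal d_ff (often rounded up to a hardware-friendly multiple) to keep the parameter count comparable. A sketch of that heuristic; the multiple_of value is a common but arbitrary choice:

def swiglu_hidden_dim(d_ff: int, multiple_of: int = 256) -> int:
    # Shrink to ~2/3 so that three d_model x hidden matrices cost about the same
    # as two d_model x d_ff matrices, then round up to a multiple for efficiency
    hidden = int(2 * d_ff / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

ffn = SwiGLUFFN(d_model=512, d_ff=swiglu_hidden_dim(2048))   # hidden dim 1536 instead of 2048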
5.3 Positional encoding: switch to RoPE (essential for medium-to-long context)
Sinusoidal positional encoding is passable at 4k context, but its ability to extrapolate drops off sharply beyond 8k-16k. RoPE (Rotary Position Embedding) has become the standard in virtually every high-quality open-source model released since 2023.
The most compact usable 1D-RoPE implementation (works for both training and inference):
def get_rotary_freqs(dim: int, max_position: int = 2048, base: float = 10000.0):
    theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    seq_idx = torch.arange(max_position).float()
    freqs = torch.outer(seq_idx, theta)
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)   # cos + i*sin
    return freqs_complex   # [max_position, dim//2]
def apply_rotary(x: torch.Tensor, freqs_cis):
    # x:         [bs, seq, heads, head_dim]
    # freqs_cis: [seq, head_dim//2] complex
    x_ = x.float().reshape(*x.shape[:-1], -1, 2)                 # split last dim into pairs
    x_complex = torch.view_as_complex(x_)
    rotated = x_complex * freqs_cis.unsqueeze(0).unsqueeze(2)    # broadcast over batch & heads
    x_rot = torch.view_as_real(rotated).flatten(-2)
    return x_rot.type_as(x)
Hook it into MultiHeadAttention:
# Apply to Q and K after their linear projections, before the attention matmul.
# Note: apply_rotary expects [batch, seq, heads, head_dim], i.e. before the transpose(1, 2).
Q = apply_rotary(Q, freqs_cis)
K = apply_rotary(K, freqs_cis)
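One practical detail the snippet glosses over: freqs_cis is precomputed for max_position positions and must be sliced to the current sequence length before the broadcast lines up. A hypothetical wiring into the MultiHeadAttention class above (the buffer name and max_position value are illustrative):

# In MultiHeadAttention.__init__ (hypothetical wiring):
self.register_buffer('freqs_cis',
                     get_rotary_freqs(self.d_k, max_position=4096),
                     persistent=False)

# In forward: split the .view(...) and .transpose(1, 2) steps so the rotation
# happens while Q/K are still [batch, seq, heads, d_k]:
seq_len = query.size(1)
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k)
Q = apply_rotary(Q, self.freqs_cis[:seq_len]).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k)
K = apply_rotary(K, self.freqs_cis[:seq_len]).transpose(1, 2)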
More advanced tricks (if you want even longer context):
- NTK-aware scaling (just change the base value); a sketch follows after this list
- Extrapolation schemes such as YaRN / PI / LongRoPE (require extra code)
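For the first option, one commonly used form of NTK-aware scaling simply inflates the RoPE base by the desired context-extension factor; the exponent dim/(dim-2) is the standard choice, and the concrete numbers here are illustrative:

def ntk_scaled_base(base: float, scale: float, dim: int) -> float:
    # Inflate the base so the lowest frequencies stretch over the longer context
    return base * scale ** (dim / (dim - 2))

# e.g. extending a model trained at 4k context to ~16k (scale = 4), head_dim = 64
freqs_cis = get_rotary_freqs(dim=64, max_position=16384,
                             base=ntk_scaled_base(10000.0, 4.0, 64))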
5.4 Attention efficiency: GQA / MQA + torch sdpa
The inference metrics modern deployments care about most are TTFT (time to first token) and tokens/s.
- MQA: a single KV head; fastest inference, but with some quality loss
- GQA: 8-16 KV heads; quality close to MHA with a large inference speedup (used by LLaMA-3, Qwen2, Mistral, Gemma, and others); a minimal sketch follows below
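To turn the MultiHeadAttention above into GQA, only the K/V projections change: they produce num_kv_heads heads, which are then repeated to match the query heads before the attention product. A minimal sketch under those assumptions (class and parameter names are illustrative):

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads, self.num_kv_heads = num_heads, num_kv_heads
        self.W_q = nn.Linear(d_model, num_heads * self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)   # fewer KV heads
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)
        self.W_o = nn.Linear(num_heads * self.d_k, d_model, bias=False)

    def forward(self, x, attn_mask=None):
        b, t, _ = x.shape
        Q = self.W_q(x).view(b, t, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(b, t, self.num_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(b, t, self.num_kv_heads, self.d_k).transpose(1, 2)
        # Repeat each KV head so every group of query heads shares one KV head
        rep = self.num_heads // self.num_kv_heads
        K = K.repeat_interleave(rep, dim=1)
        V = V.repeat_interleave(rep, dim=1)
        out = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask)
        out = out.transpose(1, 2).reshape(b, t, -1)
        return self.W_o(out)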
PyTorch 2.0+ ships scaled dot-product attention as a native fused kernel:
# Inside MultiHeadAttention.forward, replace the manual matmul + softmax:
output = F.scaled_dot_product_attention(
    Q, K, V,
    attn_mask=attn_mask,
    dropout_p=self.dropout.p if self.training else 0.0,
)
# For pure causal self-attention you can instead pass is_causal=True and let PyTorch
# build the mask; note that attn_mask and is_causal must not be passed together.
Forcing the flash-attention backend (requires torch>=2.2 and a supported GPU):
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    output = F.scaled_dot_product_attention(...)
6. Wrap-up: a recommended "getting started with modern Transformers in 2026" configuration checklist
- Overall architecture: Decoder-only (drop the Encoder and cross-attention)
- Normalization: Pre-LN
- Activation: SwiGLU (bias=False)
- Positional encoding: RoPE (base=10000, or NTK-adjusted)
- Attention: GQA (8 or 16 KV heads)
- Training hyperparameters: label smoothing 0.1, cosine decay, 3000-10000 warmup steps
- Inference speedups: torch.compile + F.scaled_dot_product_attention (flash backend)
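Collected into one place, the checklist might look like the hypothetical config object below; the class and the concrete numbers are illustrative (roughly LLaMA-7B scale), not recommendations for any specific dataset:

from dataclasses import dataclass

@dataclass
class ModernTransformerConfig:
    # Hypothetical config sketch; field values are illustrative only
    vocab_size: int = 32000
    d_model: int = 4096
    num_layers: int = 32
    num_heads: int = 32
    num_kv_heads: int = 8           # GQA: 8 (or 16) KV heads
    d_ff: int = 11008               # SwiGLU hidden size (~2/3 of 4*d_model, rounded)
    rope_base: float = 10000.0      # or NTK-adjusted for longer context
    norm: str = "pre-ln"
    activation: str = "swiglu"      # bias=False in the FFN projections
    label_smoothing: float = 0.1
    lr_schedule: str = "cosine"
    warmup_steps: int = 4000        # within the 3000-10000 range above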