Building a Standard Transformer Architecture in PyTorch

In 2026, the Transformer architecture remains one of the core building blocks of modern deep learning. From BERT and the GPT series to LLaMA, Qwen, and Grok, nearly every frontier model is built on the Transformer or one of its variants. In this post we implement a standard Encoder-Decoder Transformer from scratch in pure PyTorch, following the structure of the original "Attention is All You Need" paper.

Goals: clear code, modular structure, thorough comments, directly usable for sequence-to-sequence tasks such as machine translation and text summarization.

1. Architecture Overview

A standard Transformer consists of the following main components:

  • Input Embedding + Positional Encoding
  • Encoder (N layers)
    • Multi-Head Self-Attention
    • Feed-Forward Network
    • Add & Norm (residual connection + LayerNorm)
  • Decoder (N layers)
    • Masked Multi-Head Self-Attention (causal mask)
    • Multi-Head Cross-Attention (Encoder-Decoder Attention)
    • Feed-Forward Network
    • Add & Norm
  • Final linear layer + Softmax

We will implement these modules one by one.

2. Full Implementation

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class PositionalEncoding(nn.Module):
        """Classic sinusoidal positional encoding."""
        def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
            super().__init__()
            self.dropout = nn.Dropout(p=dropout)

            # Compute the positional encodings once
            position = torch.arange(max_len).unsqueeze(1)          # [max_len, 1]
            div_term = torch.exp(torch.arange(0, d_model, 2) *
                                 (-math.log(10000.0) / d_model))   # [d_model/2]

            pe = torch.zeros(max_len, d_model)                     # [max_len, d_model]
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)

            # Register as a buffer (no gradients, but saved/loaded with the model)
            self.register_buffer('pe', pe.unsqueeze(0))            # [1, max_len, d_model]

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: [batch_size, seq_len, d_model]
            x = x + self.pe[:, :x.size(1), :]    # broadcast add
            return self.dropout(x)

    class MultiHeadAttention(nn.Module):
        """Multi-head attention (covers self-attn, masked self-attn, and cross-attn)."""
        def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
            super().__init__()
            assert d_model % num_heads == 0

            self.d_model = d_model
            self.num_heads = num_heads
            self.d_k = d_model // num_heads

            # Linear projections for Q, K, V + output projection
            self.W_q = nn.Linear(d_model, d_model)
            self.W_k = nn.Linear(d_model, d_model)
            self.W_v = nn.Linear(d_model, d_model)
            self.W_o = nn.Linear(d_model, d_model)

            self.dropout = nn.Dropout(dropout)
            self.scale = math.sqrt(self.d_k)

        def forward(self,
                    query: torch.Tensor,
                    key: torch.Tensor,
                    value: torch.Tensor,
                    attn_mask: torch.Tensor = None) -> tuple:
            """
            query, key, value : [batch, seq_len, d_model]
            attn_mask         : [seq_len_q, seq_len_k] or [batch, seq_len_q, seq_len_k],
                                with 0 marking positions that must be masked out
            """
            batch_size = query.size(0)

            # Linear projections + split into heads
            Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
            K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
            V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
            # Q, K, V now: [batch, heads, seq_len, d_k]

            # Scaled dot-product attention
            scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale   # [b, h, q_len, k_len]

            if attn_mask is not None:
                # Make the mask broadcastable to [b, h, q_len, k_len]
                if attn_mask.dim() == 3:
                    attn_mask = attn_mask.unsqueeze(1)
                scores = scores.masked_fill(attn_mask == 0, -1e9)

            attn = F.softmax(scores, dim=-1)
            attn = self.dropout(attn)

            context = torch.matmul(attn, V)                              # [b, h, q_len, d_k]

            # Merge heads & output projection
            context = context.transpose(1, 2).contiguous() \
                             .view(batch_size, -1, self.d_model)
            output = self.W_o(context)

            return output, attn

    class PositionWiseFeedForward(nn.Module):
        """Position-wise feed-forward network (two linear layers + ReLU)."""
        def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
            super().__init__()
            self.linear1 = nn.Linear(d_model, d_ff)
            self.linear2 = nn.Linear(d_ff, d_model)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x):
            return self.linear2(self.dropout(F.relu(self.linear1(x))))

    class EncoderLayer(nn.Module):
        """A single Encoder layer."""
        def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
            super().__init__()
            self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
            self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)

            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, mask=None):
            # Self-attention + residual + norm
            attn_output, _ = self.self_attn(x, x, x, mask)
            x = self.norm1(x + self.dropout(attn_output))

            # Feed-forward + residual + norm
            ff_output = self.feed_forward(x)
            x = self.norm2(x + self.dropout(ff_output))

            return x

    class DecoderLayer(nn.Module):
        """A single Decoder layer."""
        def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
            super().__init__()
            self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
            self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
            self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)

            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.norm3 = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
            # Masked self-attention
            self_attn_out, _ = self.self_attn(x, x, x, tgt_mask)
            x = self.norm1(x + self.dropout(self_attn_out))

            # Cross-attention (query from the decoder, key/value from the encoder)
            cross_attn_out, _ = self.cross_attn(x, enc_output, enc_output, src_mask)
            x = self.norm2(x + self.dropout(cross_attn_out))

            # Feed-forward
            ff_out = self.feed_forward(x)
            x = self.norm3(x + self.dropout(ff_out))

            return x

    class Transformer(nn.Module):
        """
        Full Encoder-Decoder Transformer.
        Typical hyperparameters: d_model=512, num_heads=8, num_layers=6, d_ff=2048
        """
        def __init__(self,
                     src_vocab_size: int,
                     tgt_vocab_size: int,
                     d_model: int = 512,
                     num_heads: int = 8,
                     num_layers: int = 6,
                     d_ff: int = 2048,
                     max_seq_len: int = 5000,
                     dropout: float = 0.1):
            super().__init__()

            self.src_embed = nn.Embedding(src_vocab_size, d_model)
            self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model)
            self.pos_enc   = PositionalEncoding(d_model, max_seq_len, dropout)

            self.encoder_layers = nn.ModuleList([
                EncoderLayer(d_model, num_heads, d_ff, dropout)
                for _ in range(num_layers)
            ])

            self.decoder_layers = nn.ModuleList([
                DecoderLayer(d_model, num_heads, d_ff, dropout)
                for _ in range(num_layers)
            ])

            self.final_linear = nn.Linear(d_model, tgt_vocab_size)
            self.d_model = d_model

        def forward(self, src, tgt, src_mask=None, tgt_mask=None):
            # Embedding + positional encoding
            src = self.src_embed(src) * math.sqrt(self.d_model)
            src = self.pos_enc(src)

            tgt = self.tgt_embed(tgt) * math.sqrt(self.d_model)
            tgt = self.pos_enc(tgt)

            # Encoder
            memory = src
            for layer in self.encoder_layers:
                memory = layer(memory, src_mask)

            # Decoder
            output = tgt
            for layer in self.decoder_layers:
                output = layer(output, memory, src_mask, tgt_mask)

            # Project to the vocabulary
            output = self.final_linear(output)
            return output
3. Quick Usage Example (Machine-Translation Pseudocode)

    # Assume the data is already prepared
    model = Transformer(
        src_vocab_size=30000,
        tgt_vocab_size=30000,
        d_model=512,
        num_heads=8,
        num_layers=6,
        d_ff=2048,
        dropout=0.1
    ).cuda()

    # Generate the causal mask (important!)
    # Returns a 0/1 lower-triangular mask, matching MultiHeadAttention above,
    # which masks out positions where the mask value is 0.
    def generate_square_subsequent_mask(sz: int):
        return torch.tril(torch.ones(sz, sz))

    tgt_mask = generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

    output = model(src, tgt, src_mask=None, tgt_mask=tgt_mask)   # [b, tgt_len, vocab]
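
To make the example end-to-end, here is a minimal training-step sketch with dummy data. The batch size, sequence lengths, and the assumption that token id 0 is padding are all illustrative; the source-side padding mask follows the same "0 means masked" convention as MultiHeadAttention above.

    # Dummy batch: token ids in [1, vocab), with 0 reserved for padding (assumption)
    src = torch.randint(1, 30000, (8, 20)).cuda()     # [batch, src_len]
    tgt = torch.randint(1, 30000, (8, 22)).cuda()     # [batch, tgt_len]

    # Source padding mask: [batch, 1, src_len], 0 where src is a pad token
    src_mask = (src != 0).unsqueeze(1).float()

    # Causal mask for the shifted target (see generate_square_subsequent_mask above)
    tgt_mask = generate_square_subsequent_mask(tgt.size(1) - 1).to(tgt.device)

    # Teacher forcing: feed tgt[:, :-1], predict tgt[:, 1:]
    logits = model(src, tgt[:, :-1], src_mask=src_mask, tgt_mask=tgt_mask)  # [b, tgt_len-1, vocab]

    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),   # [b*(tgt_len-1), vocab]
        tgt[:, 1:].reshape(-1),                # [b*(tgt_len-1)]
        ignore_index=0,                        # skip padding targets (assumes pad id 0)
    )
    loss.backward()                            # then optimizer.step() / zero_grad() in a real loop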

4. A Few Recommendations

  • For real projects, add label smoothing, learning-rate warmup, and a Noam scheduler (see the sketch after this list)
  • Modern implementations usually also adopt pre-norm (LayerNorm before each sub-layer), the SwiGLU activation, RoPE positional encoding, and similar improvements
  • If you only need a Decoder-only model (GPT-style), simply delete the Encoder and remove all src-related code
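
As a concrete starting point for the first bullet, here is a minimal sketch of the Noam warmup schedule and label smoothing using standard PyTorch APIs; the warmup length, Adam hyperparameters, and the pad index of 0 are illustrative assumptions.

    # Noam schedule: lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    def noam_lambda(d_model: int, warmup_steps: int = 4000):
        def fn(step: int):
            step = max(step, 1)
            return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)
        return fn

    # Base lr of 1.0 so that the lambda value *is* the learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda(512))

    # Label smoothing is built into PyTorch's cross entropy
    criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)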

I hope this implementation helps you take another step forward on the road to understanding the Transformer from scratch.

If you want to keep exploring any of the following directions, feel free to leave a comment:

  • Adding RoPE positional encoding
  • Implementing Grouped-Query Attention / Multi-Query Attention
  • A Decoder-only version (GPT-like)
  • Performance optimization tricks (torch.compile, flash attention, etc.)

5. Evolving Toward a Modern Production-Grade Transformer: The Changes Most Worth Prioritizing (a 2026 Perspective)

The points in the previous section are only a starting point. When actually training a large model from scratch or fine-tuning one, the following changes usually yield the largest gains (ordered roughly by cost-effectiveness):

5.1 Pre-LayerNorm (almost mandatory)

Post-LN becomes very unstable to train once depth exceeds roughly 24 layers, whereas Pre-LN lets models with 100+ layers converge with relative ease.

How to change it (shown for EncoderLayer; the Decoder works the same way):

class EncoderLayer(nn.Module):
    def __init__(self, ...):
        ...
        # Keep both norms, but apply them before the sub-layers
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Pre-LN formulation (the most common modern style)
        residual = x
        x = self.norm1(x)                     # normalize first
        x = self.self_attn(x, x, x, mask)[0]  # attention
        x = residual + self.dropout(x)        # residual

        residual = x
        x = self.norm2(x)
        x = self.feed_forward(x)
        x = residual + self.dropout(x)
        return x

One extra tip: many teams also place dropout inside the sub-layers (inside attention and the FFN) rather than on the residual path, which further reduces variance.

5.2 Activation upgrade: SwiGLU (strongly recommended)

The ReLU → GELU → SwiGLU progression has been one of the most significant single-point wins of the past few years. At equal compute, SwiGLU typically brings roughly a 1--3% perplexity improvement with almost no added inference latency.

Two common formulations (the second is recommended; it is cleaner):

# Formulation 1: one Linear, then split + silu
class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff * 2)
        self.w2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = self.w1(x)
        x, gate = x.chunk(2, dim=-1)
        return self.w2(x * F.silu(gate))

# Formulation 2: closer to the LLaMA / Qwen convention (separate projections)
class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)   # gate
        self.w3 = nn.Linear(d_model, d_ff, bias=False)   # value
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

The latter is more common because w1 and w3 can be initialized independently, and bias=False is the current mainstream choice.
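
As a hypothetical drop-in, the EncoderLayer from Section 2 can simply have its feed-forward module swapped out; the subclass name below and the d_ff sizing note are illustrative, not part of the original code.

# Hypothetical: reuse the earlier EncoderLayer but with a SwiGLU FFN.
# LLaMA-style models usually shrink d_ff (to roughly 8/3 * d_model) to keep the
# parameter count comparable, since SwiGLU adds a third projection matrix.
class SwiGLUEncoderLayer(EncoderLayer):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__(d_model, num_heads, d_ff, dropout)
        self.feed_forward = SwiGLUFFN(d_model, d_ff)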

5.3 Positional encoding: switch to RoPE (essential for medium-to-long context)

Sinusoidal positional encoding is adequate up to roughly 4k context, but its extrapolation ability drops sharply beyond 8k--16k. RoPE (Rotary Position Embedding) has become standard in nearly every high-quality open-source model released since 2023.

The most concise usable 1D-RoPE implementation (works for both training and inference):

def get_rotary_freqs(dim: int, max_position: int = 2048, base: float = 10000.0):
    theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    seq_idx = torch.arange(max_position).float()
    freqs = torch.outer(seq_idx, theta)
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)  # cos + i*sin
    return freqs_complex  # [max_position, dim//2]

def apply_rotary(x: torch.Tensor, freqs_cis):
    # x: [bs, seq, heads, head_dim]
    # freqs_cis: [seq, head_dim//2], complex (slice it to the current seq length)
    x_ = x.float().reshape(*x.shape[:-1], -1, 2)               # split into pairs
    x_complex = torch.view_as_complex(x_)
    rotated = x_complex * freqs_cis.unsqueeze(0).unsqueeze(2)  # broadcast over batch and heads
    x_rot = torch.view_as_real(rotated).flatten(-2)
    return x_rot.type_as(x)

Hook it into MultiHeadAttention. Note that apply_rotary expects the layout [batch, seq, heads, head_dim], so apply it right after the .view(...) head split and before the .transpose(1, 2):

# Apply to Q and K after the head split, before transpose(1, 2) and the matmul
Q = apply_rotary(Q, freqs_cis)
K = apply_rotary(K, freqs_cis)

More advanced tricks (if you want to support even longer context):

  • NTK-aware scaling (simply change the base value; see the sketch after this list)
  • Extrapolation schemes such as YaRN / PI / LongRoPE (require extra code)
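
A minimal sketch of what "change the base value" means, reusing get_rotary_freqs from above; the function name is illustrative, and the scaling rule shown (base' = base * s^(dim/(dim-2)) for a context-length scale factor s) is one common heuristic rather than the only option.

# NTK-aware RoPE scaling sketch: stretch the base before building the frequencies
def ntk_scaled_freqs(dim: int, max_position: int, base: float = 10000.0, scale: float = 4.0):
    scaled_base = base * scale ** (dim / (dim - 2))   # common NTK-aware heuristic
    return get_rotary_freqs(dim, max_position, base=scaled_base)
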
5.4 Attention efficiency: GQA / MQA + torch SDPA

Two of the metrics modern inference cares about most are TTFT (time to first token) and tokens/s.

  • MQA: a single KV head; fastest inference, but with some quality loss
  • GQA: 8--16 KV heads; quality close to MHA with a large inference speedup (used by LLaMA-3, Qwen2, Mistral, Gemma, and others); the key projection trick is sketched right after this list
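
The core mechanical change behind GQA/MQA is projecting K and V onto fewer heads and then sharing (repeating) them across the query heads. Below is a minimal, self-contained sketch of just that projection step; it is not wired into the MultiHeadAttention class above, and the names (GQAProjection, num_kv_heads) are illustrative.

import torch
import torch.nn as nn

class GQAProjection(nn.Module):
    """Project K/V to num_kv_heads and repeat them to match num_heads (sketch)."""
    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, num_heads * self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k, bias=False)

    def forward(self, x: torch.Tensor):
        b, s, _ = x.shape
        q = self.W_q(x).view(b, s, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(b, s, self.num_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(b, s, self.num_kv_heads, self.d_k).transpose(1, 2)
        # Each KV head serves num_heads // num_kv_heads query heads
        repeat = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(repeat, dim=1)
        v = v.repeat_interleave(repeat, dim=1)
        return q, k, v   # all [b, num_heads, s, d_k], ready for attention

With num_kv_heads=1 this degenerates into MQA; at inference time the KV cache shrinks by the same factor, which is where most of the speedup comes from.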

PyTorch 2.1+ provides native support:

# In MultiHeadAttention.forward, replace the hand-written matmul + softmax.
# SDPA expects a boolean mask (True = may attend) or an additive float mask,
# so the 0/1 mask used above needs a .bool() conversion.
output = F.scaled_dot_product_attention(
    Q, K, V,
    attn_mask=attn_mask.bool() if attn_mask is not None else None,
    dropout_p=self.dropout.p if self.training else 0.0,
    # is_causal=True,   # alternative: let PyTorch build the causal mask (do not set both)
)

Enabling flash attention (requires torch>=2.2 and GPU support):

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    output = F.scaled_dot_product_attention(...)

6. Summary: A Recommended "2026 Starter Modern Transformer" Checklist

  • Overall architecture: Decoder-only (drop the Encoder and cross-attention)
  • Normalization: Pre-LN
  • Activation: SwiGLU (bias=False)
  • Positional encoding: RoPE (base=10000, or NTK-adjusted)
  • Attention: GQA (KV heads = 8 or 16)
  • Training hyperparameters: label smoothing 0.1, cosine decay, 3000--10000 warmup steps
  • Inference acceleration: torch.compile + F.scaled_dot_product_attention (flash backend); a config sketch collecting this checklist follows below
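
To keep these choices in one place, the checklist can be collected into a small config object; the class and field names below are purely illustrative, not an existing API.

from dataclasses import dataclass

@dataclass
class ModernTransformerConfig:
    # Illustrative container for the checklist above
    architecture: str = "decoder-only"
    norm: str = "pre-ln"
    activation: str = "swiglu"      # with bias=False projections
    pos_encoding: str = "rope"      # base=10000, or NTK-adjusted for long context
    num_kv_heads: int = 8           # GQA (8 or 16)
    label_smoothing: float = 0.1
    lr_schedule: str = "cosine"
    warmup_steps: int = 3000        # 3000-10000

# For inference, wrap the model once with torch.compile(model) and rely on
# F.scaled_dot_product_attention with the flash backend (see Section 5.4).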