A PyTorch Implementation of the Classic Transformer

A line-by-line walkthrough of a classic Transformer implementation in PyTorch, mapped onto the running example "The cat sat on the mat".

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Assume vocabulary size = 1000, embedding dimension = 512, batch size = 2
# Inputs: ["The cat sat on the mat", "The dog is running"]

1. Positional Encoding

python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # Create an all-zeros matrix of shape [max_len, d_model]
        pe = torch.zeros(max_len, d_model)
        
        # position: a column vector [0, 1, 2, ..., max_len-1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # div_term = 1 / 10000^(2i/d_model), the per-dimension frequency
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Even-indexed dimensions use sin, odd-indexed dimensions use cos
        pe[:, 0::2] = torch.sin(position * div_term)  # dimensions 0, 2, 4, ...
        pe[:, 1::2] = torch.cos(position * div_term)  # dimensions 1, 3, 5, ...
        
        pe = pe.unsqueeze(0)  # shape: [1, max_len, d_model], batch-first like the rest of the model
        self.register_buffer('pe', pe)  # a constant buffer, not updated during training

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        # Add the positional encodings for the first seq_len positions to the input
        x = x + self.pe[:, :x.size(1), :]  # "cat" is at position 0, "sat" at position 1, ...
        return x
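
A quick sanity check of the module above (a minimal sketch; the zero tensor is just a stand-in for real embeddings):

python
# Minimal sketch: positions 0..5 get distinct sin/cos patterns
pos_enc = PositionalEncoding(d_model=512)
dummy = torch.zeros(2, 6, 512)       # [batch=2, seq_len=6, d_model=512]
encoded = pos_enc(dummy)
print(encoded.shape)                 # torch.Size([2, 6, 512])
print(encoded[0, 0, :4])             # position 0: tensor([0., 1., 0., 1.])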

2. Scaled Dot-Product Attention

python
def attention(q, k, v, mask=None, dropout=None):
    # q, k, v: [batch_size, num_heads, seq_len, d_k]  (d_k = d_model / num_heads)
    d_k = q.size(-1)
    
    # 1) Q @ K^T / sqrt(d_k)
    # scores: [batch, heads, seq_len, seq_len]
    # Entry (i, j) is the attention score of word i attending to word j
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    
    # 2) Apply the mask (used in the decoder to block future words)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # fill with a very negative value
    
    # 3) Softmax gives the attention weights
    # attn_weights: [batch, heads, seq_len, seq_len]
    # Each row sums to 1: one word's attention distribution over all words
    attn_weights = F.softmax(scores, dim=-1)
    
    # 4) Dropout to reduce overfitting
    if dropout is not None:
        attn_weights = dropout(attn_weights)
    
    # 5) Weighted sum: weights @ V
    # output: [batch, heads, seq_len, d_k]
    # At this point the vector for "mat" has already absorbed information from "cat"
    output = torch.matmul(attn_weights, v)
    return output, attn_weights
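
The decoder's causal mask is used by this function but never constructed in the walkthrough; here is a minimal sketch of building one and passing it through attention (the tensor sizes are illustrative assumptions):

python
# Minimal sketch: a causal (subsequent) mask blocks attention to future positions
seq_len, d_k = 6, 64
q = k = v = torch.randn(2, 8, seq_len, d_k)          # [batch=2, heads=8, seq_len, d_k]

# Lower-triangular matrix: position i may attend only to positions <= i
causal_mask = torch.tril(torch.ones(1, 1, seq_len, seq_len))

out, weights = attention(q, k, v, mask=causal_mask)
print(out.shape)             # torch.Size([2, 8, 6, 64])
print(weights[0, 0, 0])      # row 0 attends only to position 0: ~[1, 0, 0, 0, 0, 0]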

3. Multi-Head Attention

python
class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0  # d_model must be divisible by the number of heads
        
        self.d_k = d_model // h  # dimension per head, e.g. 512 / 8 = 64
        self.h = h  # number of heads
        
        # Four linear layers: Q, K, V, and the output projection
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        
        # 1) Linear projections + split into heads
        # q: [batch, seq_len, d_model] -> [batch, seq_len, heads, d_k]
        q = self.linear_q(q).view(batch_size, -1, self.h, self.d_k)
        k = self.linear_k(k).view(batch_size, -1, self.h, self.d_k)
        v = self.linear_v(v).view(batch_size, -1, self.h, self.d_k)
        
        # 2) Transpose: [batch, seq_len, heads, d_k] -> [batch, heads, seq_len, d_k]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # 3) Compute attention
        # attn_output: [batch, heads, seq_len, d_k]
        # attn_weights: [batch, heads, seq_len, seq_len]
        attn_output, attn_weights = attention(q, k, v, mask, self.dropout)
        
        # 4) Concatenate the heads
        # First transpose back to [batch, seq_len, heads, d_k],
        # then view as [batch, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.h * self.d_k)
        
        # 5) Final output projection
        output = self.linear_out(attn_output)
        return output, attn_weights
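
A minimal self-attention sketch with the module above (random tensors, purely illustrative):

python
# Minimal sketch: self-attention, where Q = K = V = the same input
mha = MultiHeadAttention(h=8, d_model=512)
x = torch.randn(2, 6, 512)           # [batch, seq_len, d_model]
out, weights = mha(x, x, x)
print(out.shape)                     # torch.Size([2, 6, 512])
print(weights.shape)                 # torch.Size([2, 8, 6, 6])  one 6x6 map per head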

4. Position-wise Feed-Forward Network

python
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        # Two fully connected layers: d_model -> d_ff -> d_model
        # d_ff is usually 4 * d_model, e.g. 2048
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x: [batch, seq_len, d_model]
        
        # First layer + ReLU activation
        # intermediate: [batch, seq_len, d_ff]
        # This step learns non-linear features, e.g. "sat" + "mat" -> "resting"
        intermediate = F.relu(self.w1(x))
        
        # Dropout to reduce overfitting
        intermediate = self.dropout(intermediate)
        
        # Second layer projects back to the original dimension
        output = self.w2(intermediate)
        return output
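
The block is applied to every position independently and preserves the input shape; a minimal sketch (the parameter count is for d_model=512, d_ff=2048):

python
# Minimal sketch: the FFN keeps [batch, seq_len, d_model] and adds ~2.1M parameters
ffn = FeedForward(d_model=512, d_ff=2048)
x = torch.randn(2, 6, 512)
print(ffn(x).shape)                                      # torch.Size([2, 6, 512])
print(sum(p.numel() for p in ffn.parameters()))          # 2099712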

5. Encoder Layer

python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(num_heads, d_model, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # x: [batch, seq_len, d_model]
        
        # 1) Multi-head self-attention
        # In the encoder, Q, K and V all come from the same input
        # attn_output: [batch, seq_len, d_model]
        attn_output, _ = self.attention(x, x, x, mask)
        
        # 2) Residual connection + layer normalization
        # Apply dropout, then add the original input (skip connection);
        # this mitigates vanishing gradients and stabilizes training
        x = self.norm1(x + self.dropout(attn_output))
        
        # 3) Feed-forward network
        ff_output = self.feed_forward(x)
        
        # 4) Residual connection + layer normalization
        x = self.norm2(x + self.dropout(ff_output))
        
        return x
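
Because one encoder layer maps [batch, seq_len, d_model] back to the same shape, layers can be stacked freely; a minimal check (illustrative only):

python
# Minimal sketch: an encoder layer preserves the tensor shape, so layers stack cleanly
enc_layer = EncoderLayer(d_model=512, num_heads=8, d_ff=2048, dropout=0.1)
x = torch.randn(2, 6, 512)
print(enc_layer(x).shape)            # torch.Size([2, 6, 512])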

6. Encoder

python
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, dropout):
        super(Encoder, self).__init__()
        
        # Embedding layer: maps token ids to vectors
        # e.g. the id of "cat" is 5 -> a 512-dimensional vector
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model)
        
        # A stack of encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_layers)  # e.g. 6 layers
        ])
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # x: [batch, seq_len] token ids
        # e.g. [[5, 10, 2, 50, 3, 1], [8, 4, 2, 30, 1, 0]]
        
        # 1) Token embedding
        # x: [batch, seq_len, d_model]
        x = self.embedding(x)
        
        # 2) Scale by sqrt(d_model), as in the original Transformer paper
        x = x * math.sqrt(x.size(-1))
        
        # 3) Positional encoding
        x = self.pos_encoding(x)
        
        # 4) Dropout
        x = self.dropout(x)
        
        # 5) Pass through the encoder layers in order;
        # each layer deepens the interaction between tokens:
        # early layers: "mat" learns that there is a "cat"
        # middle layers: the scene "the cat sat on the mat" takes shape
        # final layers: the full semantics and implicit relations are captured
        for layer in self.layers:
            x = layer(x, mask)
        
        return x
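
The second sentence in the running batch is padded with id 0; a minimal sketch of building a padding mask and running the encoder (treating 0 as the pad id is an assumption of this walkthrough):

python
# Minimal sketch: padding mask (True = real token, False = padding) + encoder forward pass
encoder = Encoder(vocab_size=1000, d_model=512, num_heads=8,
                  d_ff=2048, num_layers=6, dropout=0.1)

src = torch.tensor([[5, 10, 2, 50, 3, 1], [8, 4, 2, 30, 1, 0]])   # [batch=2, src_len=6]
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)                   # [batch, 1, 1, src_len]

enc_output = encoder(src, src_mask)
print(enc_output.shape)              # torch.Size([2, 6, 512])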

7. Decoder Layer

python
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        
        # 1) Masked self-attention (prevents looking at future words)
        self.self_attn = MultiHeadAttention(num_heads, d_model, dropout)
        
        # 2) Encoder-decoder (cross) attention
        self.enc_attn = MultiHeadAttention(num_heads, d_model, dropout)
        
        # 3) Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        
        # Three layer-normalization modules
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # x: decoder input [batch, tgt_len, d_model]
        # enc_output: encoder output [batch, src_len, d_model]
        
        # 1) Masked self-attention
        # When generating word i, the decoder must not see words i+1 and beyond
        self_attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))
        
        # 2) Encoder-decoder attention
        # Q comes from the decoder; K and V come from the encoder,
        # letting the decoder attend over the encoder's output
        enc_attn_output, attn_weights = self.enc_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(enc_attn_output))
        
        # 3) Feed-forward network
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x, attn_weights
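
A minimal sketch showing the cross-attention shapes when source and target lengths differ (random tensors, purely illustrative):

python
# Minimal sketch: cross-attention weights have shape [batch, heads, tgt_len, src_len]
dec_layer = DecoderLayer(d_model=512, num_heads=8, d_ff=2048, dropout=0.1)
enc_output = torch.randn(2, 6, 512)  # source side: 6 tokens
dec_input = torch.randn(2, 4, 512)   # target side: 4 tokens generated so far

out, cross_weights = dec_layer(dec_input, enc_output)
print(out.shape)                     # torch.Size([2, 4, 512])
print(cross_weights.shape)           # torch.Size([2, 8, 4, 6])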

8. Decoder

python
class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, dropout):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Embed first, then scale by sqrt(d_model); the scaling must read the
        # embedded tensor's last dimension, not the raw token-id tensor's
        x = self.embedding(x)
        x = x * math.sqrt(x.size(-1))
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        for layer in self.layers:
            x, attn_weights = layer(x, enc_output, src_mask, tgt_mask)
        
        return x, attn_weights
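
The target-side mask combines a padding mask with the lower-triangular "subsequent" mask; here is a minimal sketch that follows the mask == 0 convention used in attention above (the helper name make_tgt_mask and the pad id 0 are assumptions, not part of the original code):

python
# Minimal sketch: combined padding + causal mask for the decoder's self-attention
def make_tgt_mask(tgt, pad_id=0):
    # padding mask: [batch, 1, 1, tgt_len]
    pad_mask = (tgt != pad_id).unsqueeze(1).unsqueeze(2)
    # causal mask: [1, 1, tgt_len, tgt_len]; position i sees only positions <= i
    tgt_len = tgt.size(1)
    causal = torch.tril(torch.ones(1, 1, tgt_len, tgt_len, dtype=torch.bool))
    return pad_mask & causal         # broadcasts to [batch, 1, tgt_len, tgt_len]

tgt = torch.tensor([[1, 15, 20, 4, 2, 0], [1, 12, 6, 7, 2, 0]])
print(make_tgt_mask(tgt).shape)      # torch.Size([2, 1, 6, 6])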

9. The Full Transformer

python
class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, 
                 num_heads=8, d_ff=2048, num_layers=6, dropout=0.1):
        super(Transformer, self).__init__()
        
        self.encoder = Encoder(src_vocab, d_model, num_heads, d_ff, num_layers, dropout)
        self.decoder = Decoder(tgt_vocab, d_model, num_heads, d_ff, num_layers, dropout)
        
        # Final linear layer projecting to vocabulary-size logits
        # (softmax is usually applied inside the loss, e.g. CrossEntropyLoss)
        self.linear_out = nn.Linear(d_model, tgt_vocab)
        
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # src: [batch, src_len] source-language token ids
        # tgt: [batch, tgt_len] target-language token ids
        
        # 1) Encode the source sentence
        # enc_output: [batch, src_len, d_model]
        enc_output = self.encoder(src, src_mask)
        
        # 2) Decode the target sentence
        # dec_output: [batch, tgt_len, d_model]
        dec_output, attn_weights = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        
        # 3) Project into vocabulary space
        # output: [batch, tgt_len, tgt_vocab]
        output = self.linear_out(dec_output)
        
        return output, attn_weights

10. Usage Example

python
# Hyperparameters
src_vocab = 1000  # source (e.g. Chinese) vocabulary size
tgt_vocab = 1000  # target (e.g. English) vocabulary size
d_model = 512
num_heads = 8
d_ff = 2048
num_layers = 6

# Build the model
model = Transformer(src_vocab, tgt_vocab, d_model, num_heads, d_ff, num_layers)

# Dummy inputs
src = torch.tensor([[5, 10, 2, 50, 3, 1], [8, 4, 2, 30, 1, 0]])  # [2, 6]
tgt = torch.tensor([[1, 15, 20, 4, 2, 0], [1, 12, 6, 7, 2, 0]])  # [2, 6]

# Forward pass
output, attn = model(src, tgt)

# output: [2, 6, 1000], a distribution over the target vocabulary at each position
# e.g. output[0, 0] scores the word that should follow "<BOS>" (say, "The"),
# output[0, 1] scores the word after that, and so on
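
To close the loop, here is a hedged sketch of a single training step, assuming pad id 0, teacher forcing, and the make_tgt_mask helper sketched earlier; none of this appears in the original walkthrough:

python
# Minimal sketch of one training step (teacher forcing, pad id 0 assumed)
criterion = nn.CrossEntropyLoss(ignore_index=0)          # skip padded target positions
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

src_mask = (src != 0).unsqueeze(1).unsqueeze(2)           # [batch, 1, 1, src_len]
tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]                 # shift: predict the next token
tgt_mask = make_tgt_mask(tgt_in)                          # padding + causal mask

optimizer.zero_grad()
logits, _ = model(src, tgt_in, src_mask, tgt_mask)        # [batch, tgt_len-1, tgt_vocab]
loss = criterion(logits.reshape(-1, tgt_vocab), tgt_out.reshape(-1))
loss.backward()
optimizer.step()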

Summary of Key Tensor Shape Transitions

Token ids:         [batch, seq_len]                e.g. [2, 6]
↓ embedding
Word vectors:      [batch, seq_len, d_model]       e.g. [2, 6, 512]
↓ multi-head attention (split into heads)
Per-head tensors:  [batch, heads, seq_len, d_k]    e.g. [2, 8, 6, 64]
↓ attention computation
Attention output:  [batch, heads, seq_len, d_k]    e.g. [2, 8, 6, 64]
↓ concatenate heads
Merged heads:      [batch, seq_len, d_model]       e.g. [2, 6, 512]
↓ feed-forward network
Final output:      [batch, seq_len, d_model]       e.g. [2, 6, 512]

Every layer operates in the same vector space, but the attention mechanism keeps exchanging information between tokens, so the model builds up a layered understanding step by step: from surface form → word meaning → syntax → sentence-level semantics.
