手撕Transformer编码器:从Self-Attention到Positional Encoding的PyTorch逐行实现


Transformer 编码器深度解读 + 代码实战


1. 编码器核心作用

Transformer 编码器的核心任务是将输入序列(如文本、语音)转换为富含上下文语义的高维特征表示。它通过多层自注意力(Self-Attention)和前馈网络(FFN),逐步建模全局依赖关系,解决传统RNN/CNN的长距离依赖缺陷。


2. 编码器单层结构详解

每层编码器包含以下模块(附 PyTorch 代码):

2.1 多头自注意力(Multi-Head Self-Attention)
python 复制代码
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # 线性变换层生成 Q, K, V
        self.to_qkv = nn.Linear(embed_size, embed_size * 3)  # 同时生成 Q/K/V
        self.scale = self.head_dim ** -0.5  # 缩放因子
        
        # 输出线性层
        self.to_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # 生成 Q, K, V 并分割多头
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # 拆分为 [Q, K, V]
        q, k, v = map(lambda t: t.view(batch_size, seq_len, self.heads, self.head_dim), qkv)
        
        # 计算注意力分数 (QK^T / sqrt(d_k))
        attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        
        # 掩码(编码器通常不需要,但保留接口)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)
        
        # Softmax 归一化
        attn = torch.softmax(attn, dim=-1)
        
        # 加权求和
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.reshape(batch_size, seq_len, self.embed_size)
        
        # 输出线性变换
        return self.to_out(out)

代码解析

  • nn.Linear 生成 Q/K/V 矩阵,通过 chunk 分割。
  • einsum 实现高效矩阵运算,计算注意力分数。
  • 支持掩码(虽编码器通常不用,但为兼容性保留)。

2.2 前馈网络(Feed-Forward Network)
python 复制代码
class FeedForward(nn.Module):
    def __init__(self, embed_size, expansion=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_size, embed_size * expansion),  # 扩展维度
            nn.GELU(),  # 更平滑的激活函数(比ReLU效果更好)
            nn.Linear(embed_size * expansion, embed_size)   # 压缩回原维度
        )
    
    def forward(self, x):
        return self.net(x)

代码解析

  • 典型结构:扩展维度(如512→2048)→激活→压缩回原维度。
  • 使用 GELU 替代 ReLU(现代Transformer的常见选择)。

2.3 残差连接 + 层归一化(Add & Norm)
python 复制代码
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_size, heads, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(embed_size, heads)
        self.ffn = FeedForward(embed_size)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # 自注意力子层
        attn_out = self.attn(x)
        x = x + self.dropout(attn_out)  # 残差连接
        x = self.norm1(x)
        
        # 前馈子层
        ffn_out = self.ffn(x)
        x = x + self.dropout(ffn_out)   # 残差连接
        x = self.norm2(x)
        return x

代码解析

  • 每个子层后执行 x = x + dropout(sublayer(x)),再层归一化。
  • 残差连接确保梯度稳定,层归一化加速收敛。

3. 位置编码(Positional Encoding)
python 复制代码
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2) * (-math.log(10000.0)/embed_size)
        
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数位置
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数位置
        self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, embed_size)
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]  # 自动广播到 (batch_size, seq_len, embed_size)

代码解析

  • 通过正弦/余弦函数编码绝对位置。
  • register_buffer 将位置编码注册为模型常量(不参与训练)。

4. 完整编码器实现
python 复制代码
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, embed_size, layers, heads, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = PositionalEncoding(embed_size)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_size, heads, dropout)
            for _ in range(layers)
        ])
    
    def forward(self, x):
        # 输入x形状: (batch_size, seq_len)
        x = self.embedding(x)  # (batch_size, seq_len, embed_size)
        x = self.pos_encoding(x)
        
        for layer in self.layers:
            x = layer(x)
        return x  # (batch_size, seq_len, embed_size)

5. 实战测试

python 复制代码
# 参数设置
vocab_size = 10000  # 假设词表大小
embed_size = 512    # 嵌入维度
layers = 6          # 编码器层数
heads = 8           # 注意力头数

# 初始化模型
encoder = TransformerEncoder(vocab_size, embed_size, layers, heads)

# 模拟输入(batch_size=32, seq_len=50)
x = torch.randint(0, vocab_size, (32, 50))  # 随机生成句子

# 前向传播
output = encoder(x)
print(output.shape)  # 预期输出: torch.Size([32, 50, 512])

相关推荐
编程小白_正在努力中11 小时前
神经网络深度解析:从神经元到深度学习的进化之路
人工智能·深度学习·神经网络·机器学习
无风听海11 小时前
神经网络之经验风险最小化
人工智能·深度学习·神经网络
H***997613 小时前
月之暗面公开强化学习训练加速方法:训练速度暴涨97%,长尾延迟狂降93%
人工智能·深度学习·机器学习
FL162386312914 小时前
无人机视角航拍河道漂浮物垃圾识别分割数据集labelme格式256张1类别
深度学习
audyxiao00115 小时前
期刊研究热点扫描|一文了解计算机视觉顶刊TIP的研究热点
人工智能·计算机视觉·transformer·图像分割·多模态
青瓷程序设计16 小时前
昆虫识别系统【最新版】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积神经网络算法
人工智能·python·深度学习
小殊小殊17 小时前
DeepSeek为什么这么慢?
人工智能·深度学习
Coding茶水间19 小时前
基于深度学习的路面坑洞检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·计算机视觉
哥布林学者19 小时前
吴恩达深度学习课程三: 结构化机器学习项目 第二周:误差分析与学习方法(一)误差分析与快速迭代
深度学习·ai
CoovallyAIHub20 小时前
如何在手机上轻松识别多种鸟类?我们发现了更简单的秘密……
深度学习·算法·计算机视觉