手撕Transformer编码器:从Self-Attention到Positional Encoding的PyTorch逐行实现


Transformer 编码器深度解读 + 代码实战


1. 编码器核心作用

Transformer 编码器的核心任务是将输入序列(如文本、语音)转换为富含上下文语义的高维特征表示。它通过多层自注意力(Self-Attention)和前馈网络(FFN),逐步建模全局依赖关系,解决传统RNN/CNN的长距离依赖缺陷。


2. 编码器单层结构详解

每层编码器包含以下模块(附 PyTorch 代码):

2.1 多头自注意力(Multi-Head Self-Attention)
python 复制代码
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # 线性变换层生成 Q, K, V
        self.to_qkv = nn.Linear(embed_size, embed_size * 3)  # 同时生成 Q/K/V
        self.scale = self.head_dim ** -0.5  # 缩放因子
        
        # 输出线性层
        self.to_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # 生成 Q, K, V 并分割多头
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # 拆分为 [Q, K, V]
        q, k, v = map(lambda t: t.view(batch_size, seq_len, self.heads, self.head_dim), qkv)
        
        # 计算注意力分数 (QK^T / sqrt(d_k))
        attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        
        # 掩码(编码器通常不需要,但保留接口)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)
        
        # Softmax 归一化
        attn = torch.softmax(attn, dim=-1)
        
        # 加权求和
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.reshape(batch_size, seq_len, self.embed_size)
        
        # 输出线性变换
        return self.to_out(out)

代码解析

  • nn.Linear 生成 Q/K/V 矩阵,通过 chunk 分割。
  • einsum 实现高效矩阵运算,计算注意力分数。
  • 支持掩码(虽编码器通常不用,但为兼容性保留)。

2.2 前馈网络(Feed-Forward Network)
python 复制代码
class FeedForward(nn.Module):
    def __init__(self, embed_size, expansion=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_size, embed_size * expansion),  # 扩展维度
            nn.GELU(),  # 更平滑的激活函数(比ReLU效果更好)
            nn.Linear(embed_size * expansion, embed_size)   # 压缩回原维度
        )
    
    def forward(self, x):
        return self.net(x)

代码解析

  • 典型结构:扩展维度(如512→2048)→激活→压缩回原维度。
  • 使用 GELU 替代 ReLU(现代Transformer的常见选择)。

2.3 残差连接 + 层归一化(Add & Norm)
python 复制代码
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_size, heads, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(embed_size, heads)
        self.ffn = FeedForward(embed_size)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # 自注意力子层
        attn_out = self.attn(x)
        x = x + self.dropout(attn_out)  # 残差连接
        x = self.norm1(x)
        
        # 前馈子层
        ffn_out = self.ffn(x)
        x = x + self.dropout(ffn_out)   # 残差连接
        x = self.norm2(x)
        return x

代码解析

  • 每个子层后执行 x = x + dropout(sublayer(x)),再层归一化。
  • 残差连接确保梯度稳定,层归一化加速收敛。

3. 位置编码(Positional Encoding)
python 复制代码
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2) * (-math.log(10000.0)/embed_size)
        
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数位置
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数位置
        self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, embed_size)
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]  # 自动广播到 (batch_size, seq_len, embed_size)

代码解析

  • 通过正弦/余弦函数编码绝对位置。
  • register_buffer 将位置编码注册为模型常量(不参与训练)。

4. 完整编码器实现
python 复制代码
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, embed_size, layers, heads, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = PositionalEncoding(embed_size)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_size, heads, dropout)
            for _ in range(layers)
        ])
    
    def forward(self, x):
        # 输入x形状: (batch_size, seq_len)
        x = self.embedding(x)  # (batch_size, seq_len, embed_size)
        x = self.pos_encoding(x)
        
        for layer in self.layers:
            x = layer(x)
        return x  # (batch_size, seq_len, embed_size)

5. 实战测试

python 复制代码
# 参数设置
vocab_size = 10000  # 假设词表大小
embed_size = 512    # 嵌入维度
layers = 6          # 编码器层数
heads = 8           # 注意力头数

# 初始化模型
encoder = TransformerEncoder(vocab_size, embed_size, layers, heads)

# 模拟输入(batch_size=32, seq_len=50)
x = torch.randint(0, vocab_size, (32, 50))  # 随机生成句子

# 前向传播
output = encoder(x)
print(output.shape)  # 预期输出: torch.Size([32, 50, 512])

相关推荐
代码79724 分钟前
【无标题】使用 Playwright 实现跨 Chromium、Firefox、WebKit 浏览器自动化操作
运维·前端·深度学习·华为·自动化
.银河系.37 分钟前
9.28 深度学习10
人工智能·深度学习
jie*38 分钟前
小杰深度学习(two)——全连接与链式求导
图像处理·人工智能·pytorch·python·深度学习·分类·回归
Bwcx_lzp42 分钟前
深度学习核心技术演进:从函数到 Transformer 架构
人工智能·深度学习·transformer
代码797215 小时前
使用会话存储时,处理存储信息加密问题
深度学习·算法·自动化·散列表·harmonyos
小毕超15 小时前
基于 PyTorch 完全从零手搓 GPT 混合专家 (MOE) 对话模型
pytorch·transformer·moe
Coovally AI模型快速验证15 小时前
华为发布开源超节点架构,以开放战略叩响AI算力生态变局
人工智能·深度学习·神经网络·计算机视觉·华为·架构·开源
ygyqinghuan15 小时前
Pytorch 数据处理
人工智能·pytorch·python
且慢.58916 小时前
机器学习/深度学习名词理解
人工智能·深度学习·机器学习
SkyXZ~17 小时前
AWS SageMaker SDK 完整教程:从零开始云端训练你的模型
人工智能·深度学习·云计算·aws·sagemaker·boto3