A Deep Dive into the Transformer Encoder + Hands-On Code
1. Core Role of the Encoder
The core task of the Transformer encoder is to turn an input sequence (e.g., text or speech) into high-dimensional feature representations rich in contextual semantics. Through stacked self-attention and feed-forward network (FFN) layers it progressively models global dependencies, addressing the long-range dependency weaknesses of traditional RNNs/CNNs.
2. Anatomy of a Single Encoder Layer
Each encoder layer contains the following modules (with PyTorch code):
2.1 Multi-Head Self-Attention
```python
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        # A single linear layer produces Q, K, V at once
        self.to_qkv = nn.Linear(embed_size, embed_size * 3)
        self.scale = self.head_dim ** -0.5  # scaling factor 1/sqrt(d_k)
        # Output projection
        self.to_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        # Produce Q, K, V and split them into heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # split into [Q, K, V]
        q, k, v = map(
            lambda t: t.view(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2),
            qkv,
        )  # each: (batch, heads, seq_len, head_dim)
        # Attention scores: QK^T / sqrt(d_k)
        attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        # Masking (usually not needed in the encoder, but the interface is kept)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)
        # Softmax normalization
        attn = torch.softmax(attn, dim=-1)
        # Weighted sum of the values
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        # Merge the heads back: (batch, seq_len, embed_size)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, self.embed_size)
        # Output projection
        return self.to_out(out)
```
Code walkthrough:
- A single `nn.Linear` produces the Q/K/V matrices, which are then split apart with `chunk`.
- `einsum` performs the batched matrix products that compute the attention scores efficiently.
- A `mask` argument is supported (the encoder usually does not need it, but the interface is kept for compatibility).
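A quick shape check (a minimal sketch; the sizes below are arbitrary example values) shows that the module maps a `(batch, seq_len, embed_size)` tensor to a tensor of the same shape:

```python
# Shape check for the MultiHeadAttention module defined above
mha = MultiHeadAttention(embed_size=512, heads=8)
x = torch.randn(2, 10, 512)   # (batch, seq_len, embed_size)
y = mha(x)                    # self-attention over the sequence
print(y.shape)                # torch.Size([2, 10, 512])
```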
2.2 Feed-Forward Network (FFN)
```python
class FeedForward(nn.Module):
    def __init__(self, embed_size, expansion=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_size, embed_size * expansion),  # expand the dimension
            nn.GELU(),                                      # smoother activation than ReLU
            nn.Linear(embed_size * expansion, embed_size),  # project back to the original dimension
        )

    def forward(self, x):
        return self.net(x)
```
Code walkthrough:
- Typical structure: expand the dimension (e.g., 512 → 2048) → activation → project back to the original dimension.
- `GELU` replaces `ReLU` (a common choice in modern Transformers).
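A similar sketch (again with arbitrary example sizes) confirms that the FFN acts on each position independently and preserves the input shape:

```python
# The FFN keeps the shape unchanged; only the hidden width expands internally
ffn = FeedForward(embed_size=512)   # hidden size = 512 * 4 = 2048
x = torch.randn(2, 10, 512)
print(ffn(x).shape)                 # torch.Size([2, 10, 512])
```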
2.3 Residual Connection + Layer Normalization (Add & Norm)
```python
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_size, heads, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(embed_size, heads)
        self.ffn = FeedForward(embed_size)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention sub-layer
        attn_out = self.attn(x)
        x = x + self.dropout(attn_out)  # residual connection
        x = self.norm1(x)
        # Feed-forward sub-layer
        ffn_out = self.ffn(x)
        x = x + self.dropout(ffn_out)   # residual connection
        x = self.norm2(x)
        return x
```
Code walkthrough:
- Each sub-layer computes x = x + dropout(sublayer(x)) and then applies layer normalization (the post-norm arrangement of the original Transformer).
- Residual connections keep gradients stable, and layer normalization speeds up convergence.
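For comparison, PyTorch ships a built-in `nn.TransformerEncoderLayer` with the same Add & Norm structure. The sketch below assumes a reasonably recent PyTorch release where `batch_first` and the `'gelu'` string activation are available:

```python
# Built-in encoder layer with settings roughly matching the hand-written one
builtin_layer = nn.TransformerEncoderLayer(
    d_model=512, nhead=8, dim_feedforward=2048,
    dropout=0.1, activation='gelu', batch_first=True,
)
x = torch.randn(2, 10, 512)
print(builtin_layer(x).shape)   # torch.Size([2, 10, 512])
```

The hand-written layer above is for learning purposes; in practice the built-in layer is usually the more convenient choice.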
3. Positional Encoding
```python
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2) * (-math.log(10000.0) / embed_size))
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices
        self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, embed_size)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]  # broadcasts to (batch_size, seq_len, embed_size)
```
Code walkthrough:
- Absolute positions are encoded with sine/cosine functions of different frequencies.
- `register_buffer` registers the encoding as a non-trainable constant of the model (it is not updated during training, but it moves with the model, e.g., onto the GPU).
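A short sketch (arbitrary example sizes) of how the encoding is added to token embeddings, and what the very first position looks like:

```python
# Positional encoding is simply added element-wise to the embeddings
pos_enc = PositionalEncoding(embed_size=512)
emb = torch.randn(2, 10, 512)   # stand-in for token embeddings
print(pos_enc(emb).shape)       # torch.Size([2, 10, 512])
print(pos_enc.pe[0, 0, :4])     # position 0: sin(0)=0, cos(0)=1 -> tensor([0., 1., 0., 1.])
```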
4. Complete Encoder Implementation
```python
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, embed_size, layers, heads, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = PositionalEncoding(embed_size)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_size, heads, dropout)
            for _ in range(layers)
        ])

    def forward(self, x):
        # Input x shape: (batch_size, seq_len)
        x = self.embedding(x)    # (batch_size, seq_len, embed_size)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x)
        return x                 # (batch_size, seq_len, embed_size)
```
5. Hands-On Test
```python
# Hyperparameters
vocab_size = 10000   # assumed vocabulary size
embed_size = 512     # embedding dimension
layers = 6           # number of encoder layers
heads = 8            # number of attention heads

# Initialize the model
encoder = TransformerEncoder(vocab_size, embed_size, layers, heads)

# Dummy input (batch_size=32, seq_len=50)
x = torch.randint(0, vocab_size, (32, 50))  # randomly generated "sentences"

# Forward pass
output = encoder(x)
print(output.shape)  # expected output: torch.Size([32, 50, 512])
```
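In real data, sequences in a batch have different lengths and are padded to a common length. The `MultiHeadAttention` above already accepts a `mask` argument; the sketch below is hypothetical (the encoder layers here do not yet forward a mask) and shows how a padding mask would typically be built from the token IDs, assuming the pad token ID is 0:

```python
# Hypothetical padding mask (pad token ID assumed to be 0).
# Shape (batch, 1, 1, seq_len) broadcasts against the attention scores of
# shape (batch, heads, seq_len, seq_len); positions where mask == 0 are
# filled with -1e10 before the softmax, so padded tokens get no attention.
pad_id = 0
tokens = torch.tensor([[5, 8, 3, pad_id, pad_id],
                       [7, 2, 9, 4, pad_id]])
mask = (tokens != pad_id).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)

mha = MultiHeadAttention(embed_size=512, heads=8)
x = torch.randn(2, 5, 512)
out = mha(x, mask=mask)   # padded positions are ignored in the attention weights
print(out.shape)          # torch.Size([2, 5, 512])
```

Threading the mask through `TransformerEncoderLayer.forward` and `TransformerEncoder.forward` is a straightforward extension: pass it down to `self.attn`.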