Seq2Seq 架构
Seq2Seq 是序列到序列任务模型,可以实现将一个序列转换为另一个序列,我们遇到的很多场景,其实就是 Seq2Seq 任务,比如:
- 翻译:中文序列翻译成英文序列
- 会议纪要:会议内容序列到纪要序列
- 图书分类:一本书分类是个人成长,书(序列)到类别(序列)

Seq2Seq 常用的架构是编码器-解码器架构,即 Encoder-Decoder 架构。
Seq1 先经过编码器处理,编码器可以看到 Seq1 的所有内容,所以更好的理解了 Seq1 的内容,然后编码成上下文向量,再输入到解码器中,解码器基于当前可获得的信息,逐步生成文本。
plain
输入文本 "我爱学习AI"
↓ 编码器逐步处理
token1 "我" → 向量表示
token2 "爱" → 向量表示
token3 "学习" → 向量表示
token4 "AI" → 向量表示
↓ 整合所有信息
上下文向量 (d_model) - 包含整个句子的理解
plain
上下文向量 (d_model) - 来自编码器的理解
↓ 解码器逐步生成
生成第1个词 "I"
↓ 基于已生成内容继续
生成第2个词 "love"
↓ 继续生成
生成第3个词 "learning"
↓ 继续生成
生成第4个词 "AI"
↓
输出文本 "I love learning AI"
根据上述,我们可以看到,编码器擅长理解,是阅读理解专家;解码器擅长生成,是写作专家。为什么会这样?
编码器
我们将输入序列传递给编码器时,编码器处理当前词时,也可以看到当前词左边和右边的词,可以理解为有双向注意力机制,因此可以更好的理解完整上下文含义;类比我们读完一整段话后,可以理解整体意思。
plain
编码器能同时看到:
- 左边的词
- 当前的词
- 右边的词
例如:处理"学习"这个词时,编码器能看到:
"我" + "爱" + "学习" + "AI"
所以,编码器适用的场景包括:
- 文本分类
- 命名实体识别
- 提取式生成摘要
解码器
解码器是基于当前序列,预测下一个最可能的词,逐步生成一段完整的内容。所以,解码器在处理当前词时,只能看到当前词和其左边的词(已生成),右边的词还没有生成。
plain
解码器只能看到:
- 左边的词(已生成的)
- 当前的词
不能看到右边的词(还没生成的)
例如:生成"love"时,只能看到:
"I" + "love"
不能看到后面的词
所以,解码器适用的场景包括:
- 文本生成
- 生成式对话系统
- 代码生成
实现
python
"""
完整的Seq2Seq模型实现
"""
import torch
import torch.nn as nn
class Seq2Seq(nn.Module):
def __init__(self, vocab_size, d_model=512, n_heads=8,
num_encoder_layers=6, num_decoder_layers=6):
super().__init__()
# Embedding层
self.embedding = nn.Embedding(vocab_size, d_model)
# 编码器
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
# 解码器
decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_heads)
self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
# 输出投影层
self.output_projection = nn.Linear(d_model, vocab_size)
def encode(self, src):
"""
编码器阶段
Args:
src: 输入序列 (batch_size, seq_len)
Returns:
context_vector: 上下文向量 (batch_size, d_model)
"""
# Embedding
src_embeddings = self.embedding(src)
print(f" Embedding形状: {src_embeddings.shape}") # (batch_size, seq_len, d_model)
# 编码器处理
encoder_output = self.encoder(src_embeddings)
print(f" 编码器输出形状: {encoder_output.shape}") # (batch_size, seq_len, d_model)
# 提取上下文向量(平均池化)
context_vector = encoder_output.mean(dim=1)
print(f" 上下文向量形状: {context_vector.shape}") # (batch_size, d_model)
return context_vector
def decode(self, context_vector, max_length=10):
"""
解码器阶段
Args:
context_vector: 上下文向量 (batch_size, d_model)
max_length: 最大生成长度
Returns:
output_ids: 输出的token IDs (batch_size, max_length)
"""
batch_size = context_vector.size(0)
# 初始化输入(开始标记)
current_input = torch.zeros(batch_size, 1, context_vector.size(1))
output_ids = []
for i in range(max_length):
# 解码器处理
decoder_output = self.decoder(current_input, context_vector.unsqueeze(1))
print(f" 步骤{i+1}: 解码器输出形状: {decoder_output.shape}") # (batch_size, i+1, d_model)
# 投影到词汇表
logits = self.output_projection(decoder_output[:, -1, :])
print(f" 步骤{i+1}: Logits形状: {logits.shape}") # (batch_size, vocab_size)
# 选择最可能的token
next_token = logits.argmax(dim=-1, keepdim=True)
print(f" 步骤{i+1}: 生成的Token ID: {next_token.item()}")
output_ids.append(next_token)
# 获取token的embedding并拼接
next_token_embedding = self.embedding(next_token)
current_input = torch.cat([current_input, next_token_embedding], dim=1)
print(f" 步骤{i+1}: 当前输入形状: {current_input.shape}") # (batch_size, i+2, d_model)
# 拼接所有输出
output_ids = torch.cat(output_ids, dim=1)
print(f" 最终输出形状: {output_ids.shape}") # (batch_size, max_length)
return output_ids
def forward(self, src, max_length=10):
"""
前向传播
Args:
src: 输入序列 (batch_size, seq_len)
max_length: 最大生成长度
Returns:
output_ids: 输出的token IDs (batch_size, max_length)
"""
print("\n=== 编码器阶段 ===")
context_vector = self.encode(src)
print("\n=== 解码器阶段 ===")
output_ids = self.decode(context_vector, max_length)
return output_ids
# 演示
print("=== Seq2Seq模型演示 ===\n")
# 创建模型
vocab_size = 10000
model = Seq2Seq(vocab_size=vocab_size, d_model=512, n_heads=8,
num_encoder_layers=2, num_decoder_layers=2)
# 输入序列
src = torch.randint(0, vocab_size, (1, 4)) # (batch_size=1, seq_len=4)
print(f"输入序列: {src.tolist()}")
print(f"输入形状: {src.shape}\n")
# 前向传播
output = model(src, max_length=5)
print(f"\n输出序列: {output.tolist()}")
print(f"输出形状: {output.shape}")