Seq2Seq - 编码器(Encoder)和解码器(Decoder)

本节实现一个简单的 Seq2Seq(Sequence to Sequence)模型 的编码器(Encoder)和解码器(Decoder)部分。

重点把握Seq2Seq 模型的整体工作流程

理解编码器(Encoder)和解码器(Decoder)代码

本小节引入了nn.GRU API的调用,nn.GRU具体参数将在下一小节进行补充讲解

1. 编码器(Encoder

类定义
复制代码
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, hidden_size, batch_first=True)
  • vocab_size:输入词汇表的大小,即输入序列中可能出现的不同单词或标记的数量。

  • embedding_dim:嵌入层的维度,即每个单词或标记被映射到的向量空间的维度。

  • hidden_size:GRU(门控循环单元)的隐藏状态维度,决定了模型的内部状态大小。

主要组件
  1. 嵌入层(nn.Embedding

    • 嵌入层会将输入序列形状转换为 [batch_size, seq_len, embedding_dim] 的张量。

    • 这种映射是通过学习嵌入矩阵实现的,每个单词索引对应嵌入矩阵中的一行。

  2. GRU(nn.GRU

    • embedding_dim 是 GRU 的输入维度,hidden_size 是隐藏状态的维度。

    • batch_first=True 表示输入和输出的张量的第一个维度是批量大小(batch_size),而不是序列长度(seq_len)。

前向传播(forward
复制代码
def forward(self, x):
    embs = self.emb(x) #batch * token * embedding_dim
    gru_out, hidden = self.rnn(embs) #batch * token * hidden_size

    return gru_out, hidden
  • 输入 x 是一个形状为 [batch_size, seq_len] 的张量,表示一个批次的输入序列。

  • embs 是嵌入层的输出,形状为 [batch_size, seq_len, embedding_dim]

  • gru_out 是 GRU 的输出,形状为 [batch_size, seq_len, hidden_size],表示每个时间步的隐藏状态。

  • hidden 是 GRU 的最终隐藏状态,形状为 [1, batch_size, hidden_size],用于传递给解码器。

2. 解码器(Decoder)

类定义
复制代码
class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, hidden_size, batch_first=True)
  • 解码器的结构与编码器类似,但它的作用是将编码器生成的上下文向量(hidden)解码为目标序列。
主要组件
  1. 嵌入层(nn.Embedding

    • 与编码器类似,将目标序列的单词索引映射到嵌入向量。
  2. GRU(nn.GRU

    • 与编码器中的 GRU 类似,但其输入是目标序列的嵌入向量,初始隐藏状态是编码器的最终隐藏状态。
前向传播(forward
复制代码
def forward(self, x, hx):
    embs = self.emb(x)
    gru_out, hidden = self.rnn(embs, hx=hx) #batch * token * hidden_size
    # batch * token * hidden_size
    # 1 * token * hidden_size

    return gru_out, hidden
  • 输入 x 是目标序列的单词索引,形状为 [batch_size, seq_len]

  • hx 是编码器的最终隐藏状态,形状为 [1, batch_size, hidden_size],作为解码器的初始隐藏状态。

  • embs 是目标序列的嵌入向量,形状为 [batch_size, seq_len, embedding_dim]

  • gru_out 是解码器 GRU 的输出,形状为 [batch_size, seq_len, hidden_size]

  • hidden 是解码器 GRU 的最终隐藏状态,形状为 [1, batch_size, hidden_size]

3. Seq2Seq 模型的整体工作流程⭐

  1. 编码阶段

    • 输入序列通过编码器的嵌入层,将单词索引映射为嵌入向量。

    • 嵌入向量通过 GRU,生成每个时间步的隐藏状态和最终的隐藏状态(上下文向量)。

    • 最终隐藏状态(hidden)作为编码器的输出,传递给解码器。

  2. 解码阶段

    • 解码器的初始隐藏状态是编码器的最终隐藏状态。

    • 解码器逐个生成目标序列的单词,每次生成一个单词后,将该单词的嵌入向量作为下一次输入,同时更新隐藏状态。

    • 通过这种方式,解码器逐步生成目标序列。

相关推荐
每天都要写算法(努力版)23 分钟前
【神经网络与深度学习】深度学习中的生成模型简介
人工智能·深度学习·神经网络·生成模型
进来有惊喜24 分钟前
过采样处理
人工智能
shao9185161 小时前
Gradio全解20——Streaming:流式传输的多媒体应用(5)——基于WebRTC的摄像头实时目标检测
人工智能·目标检测·webrtc·yolov10·twilio·yoloe·turn服务器
蹦蹦跳跳真可爱5892 小时前
Python----机器学习(模型评估:准确率、损失函数值、精确度、召回率、F1分数、混淆矩阵、ROC曲线和AUC值、Top-k精度)
人工智能·python·机器学习
江鸟19984 小时前
AI 编程日报 · 2025 年 5 月 04 日|GitHub Copilot Agent 模式发布,Ultralytics 优化训练效率
人工智能·github·copilot
liaokailin6 小时前
Spring AI 实战:第十一章、Spring AI Agent之知行合一
java·人工智能·spring
Bruce_Liuxiaowei6 小时前
从零开发一个B站视频数据统计Chrome插件
人工智能·visualstudio·html
乌恩大侠6 小时前
【AI科技】ROCm 6.4:打破 AI、HPC 和模块化 GPU 软件的障碍
人工智能·科技
nuise_8 小时前
李沐《动手学深度学习》 | Softmax回归 - 分类问题
深度学习·分类·回归
CHNMSCS8 小时前
PyTorch_张量基本运算
人工智能·pytorch·python