现代循环神经网络6:编码器-解码器架构

一、编码器-解码器架构概述

在许多实际应用中(例如机器翻译),我们的输入和输出都是长度可变的序列。传统的神经网络难以直接处理这种变长输入输出问题。编码器-解码器(Encoder-Decoder)架构正是为此而设计的。它由两个主要组件构成:

  1. 编码器(Encoder): 接收一个长度可变的输入序列,并将其转换为一个固定形状的"编码状态"。
  2. 解码器(Decoder): 接收编码状态,并逐步生成长度可变的输出序列。

举个简单例子,对于英语到法语的机器翻译,输入序列可能是:

  • 英文输入:
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = ( They , are , watching , . ) X = (\text{They}, \text{are}, \text{watching}, \text{.}) </math>X=(They,are,watching,.)

编码器将这个序列编码为一个固定的状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s(可以理解为一个向量),表示为:

  • 编码状态:
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s = f ( X ) s = f(X) </math>s=f(X)

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f 表示编码器内部的映射函数。然后,解码器基于这个状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s,逐个生成输出词元,得到翻译后的序列:

  • 法文输出:
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Y = ( Ils , regardent , . ) Y = (\text{Ils}, \text{regardent}, \text{.}) </math>Y=(Ils,regardent,.)

如图所示,整个过程可以直观地理解为将一个变长序列"压缩"为一个定长的状态,再由这个状态"解压缩"出另一个变长序列。


二、编码器

编码器的核心任务是将输入序列转换为一个固定形状的编码状态。假设输入序列为
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = ( x 1 , x 2 , ... , x T ) X = (x_1, x_2, \ldots, x_T) </math>X=(x1,x2,...,xT)

那么编码器可以看作是一个函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f,输出一个状态向量
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s = f ( X ) s = f(X) </math>s=f(X)

在深度学习中,常用循环神经网络(RNN)、长短时记忆网络(LSTM)或门控循环单元(GRU)来实现编码器,因为它们能够处理序列数据。

下面是一个基于 PyTorch 框架的编码器接口示例代码:

python 复制代码
class Encoder(nn.Module):
    """编码器-解码器架构的基本编码器接口"""

    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *rgs):
        # X 为输入序列(长度可变)
        # 此处应实现编码逻辑,将 X 转换为固定形状的编码状态
        raise NotImplementedError

在这个接口中,任何继承 Encoder 的模型都需要实现 forward 方法。数学上可以认为这个过程为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s = f ( X ) s = f(X) </math>s=f(X)

三、解码器

解码器负责将固定形状的编码状态转换为长度可变的输出序列。其核心在于:

  1. 初始化状态:将编码器输出转换为解码器的初始状态。
  2. 逐步生成输出:在每个时间步,解码器根据当前状态和上一步输出生成下一个词元。

例如,在序列生成过程中,我们可以在时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 用下面两个公式描述状态更新与输出生成:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h t = tanh ⁡ ( W   y t − 1 + U   h t − 1 + b ) h_t = \tanh(W\, y_{t-1} + U\, h_{t-1} + b) </math>ht=tanh(Wyt−1+Uht−1+b)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y t = s o f t m a x ( V   h t + c ) y_t = \mathrm{softmax}(V\, h_t + c) </math>yt=softmax(Vht+c)

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht 表示当前时刻的隐藏状态,
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> y t − 1 y_{t-1} </math>yt−1 是上一步生成的词元,
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> W , U , V W, U, V </math>W,U,V 为参数矩阵,
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> b , c b, c </math>b,c 为偏置项。

下面给出一个 PyTorch 实现的解码器接口示例代码:

python 复制代码
class Decoder(nn.Module):
    """编码器-解码器架构的基本解码器接口"""

    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError

四、合并编码器和解码器

编码器与解码器虽然是两个独立的组件,但在整个序列转换模型中必须协同工作。整体流程如下:

  1. 编码阶段 :输入序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 经过编码器处理,生成编码输出 enc_outputs
  2. 状态初始化 :将 enc_outputs 通过解码器的 init_state 方法转换为初始解码状态 dec_state
  3. 解码阶段 :解码器根据输入(如目标序列的部分信息)和 dec_state 逐步生成输出序列。

整个过程可以用下面的公式来描述:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y = D e c o d e r ( d e c _ X , i n i t _ s t a t e ( E n c o d e r ( e n c _ X ) ) ) y = \mathrm{Decoder}(dec\_X, \mathrm{init\_state}(\mathrm{Encoder}(enc\_X))) </math>y=Decoder(dec_X,init_state(Encoder(enc_X)))

下面是一个整合编码器和解码器的 PyTorch 接口示例代码:

python 复制代码
class EncoderDecoder(nn.Module):
    """编码器-解码器架构的基类"""

    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        # enc_X:输入序列
        # dec_X:解码器接收的输入(例如目标序列的部分)
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

五、小结

本文介绍了编码器-解码器架构,这种架构在处理序列转换问题(如机器翻译)中具有广泛的应用。主要内容回顾如下:

  • 编码器:将变长输入序列

    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = ( x 1 , x 2 , ... , x T ) X = (x_1, x_2, \ldots, x_T) </math>X=(x1,x2,...,xT)

    转换为固定形状的编码状态
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s = f ( X ) s = f(X) </math>s=f(X)

  • 解码器:接收编码状态,利用状态更新公式

    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h t = tanh ⁡ ( W   y t − 1 + U   h t − 1 + b ) h_t = \tanh(W\, y_{t-1} + U\, h_{t-1} + b) </math>ht=tanh(Wyt−1+Uht−1+b)

    和输出生成公式
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y t = s o f t m a x ( V   h t + c ) y_t = \mathrm{softmax}(V\, h_t + c) </math>yt=softmax(Vht+c)

    逐步生成输出序列。

  • 整合:整个模型通过编码器输出与解码器初始状态的衔接实现完整的序列转换过程,即

    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y = D e c o d e r ( d e c _ X , i n i t _ s t a t e ( E n c o d e r ( e n c _ X ) ) ) y = \mathrm{Decoder}(dec\_X, \mathrm{init\_state}(\mathrm{Encoder}(enc\_X))) </math>y=Decoder(dec_X,init_state(Encoder(enc_X)))

这种架构不仅为后续更复杂的循环神经网络(例如带注意力机制的模型)奠定了基础,也为实现机器翻译等任务提供了清晰的模块化设计思路。希望通过本文的讲解,大家能对深度学习中的序列转换模型有更直观的理解,并尝试阅读和实现相关代码。

相关推荐
Coovally AI模型快速验证31 分钟前
何恺明团队新突破:用“物理直觉“重构AI视觉系统,去噪神经网络让机器看懂世界规律
人工智能·深度学习·神经网络·机器学习·计算机视觉·目标跟踪·重构
AndrewHZ1 小时前
DeepSeek模型本地化部署方案及Python实现
人工智能·深度学习·算法·语言模型·ai助理·deepseek·本地化部署
大知闲闲哟1 小时前
深度学习Y1周:调用官方权重进行检测
人工智能·深度学习
幻风_huanfeng2 小时前
神经网络完成训练的详细过程
人工智能·pytorch·深度学习·神经网络·机器学习·优化算法包括梯度下降法
邪恶的贝利亚2 小时前
神经网络常用库-torch(基础操作张量)
人工智能·深度学习·神经网络
_zwy3 小时前
【C++ 函数模板】—— 模板参数推导、实例化策略与编译优化
c语言·c++·人工智能·深度学习·机器学习
凡人的AI工具箱10 小时前
PyTorch深度学习框架进阶学习计划 - 第20天:端到端图像生成系统
人工智能·pytorch·python·深度学习·学习·aigc·ai编程
ZhuBin36516 小时前
概率论与数理统计
人工智能·深度学习·机器学习·自动化·概率论
yutianzuijin16 小时前
Scaled_dot_product_attention(SDPA)使用详解
人工智能·深度学习·llm·大模型推理
WBingJ16 小时前
深度学习基础:线性代数本质4——矩阵乘法
深度学习·线性代数·矩阵