现代循环神经网络6:编码器-解码器架构

一、编码器-解码器架构概述

在许多实际应用中(例如机器翻译),我们的输入和输出都是长度可变的序列。传统的神经网络难以直接处理这种变长输入输出问题。编码器-解码器(Encoder-Decoder)架构正是为此而设计的。它由两个主要组件构成:

  1. 编码器(Encoder): 接收一个长度可变的输入序列,并将其转换为一个固定形状的"编码状态"。
  2. 解码器(Decoder): 接收编码状态,并逐步生成长度可变的输出序列。

举个简单例子,对于英语到法语的机器翻译,输入序列可能是:

  • 英文输入:
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = ( They , are , watching , . ) X = (\text{They}, \text{are}, \text{watching}, \text{.}) </math>X=(They,are,watching,.)

编码器将这个序列编码为一个固定的状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s(可以理解为一个向量),表示为:

  • 编码状态:
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s = f ( X ) s = f(X) </math>s=f(X)

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f 表示编码器内部的映射函数。然后,解码器基于这个状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s,逐个生成输出词元,得到翻译后的序列:

  • 法文输出:
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Y = ( Ils , regardent , . ) Y = (\text{Ils}, \text{regardent}, \text{.}) </math>Y=(Ils,regardent,.)

如图所示,整个过程可以直观地理解为将一个变长序列"压缩"为一个定长的状态,再由这个状态"解压缩"出另一个变长序列。


二、编码器

编码器的核心任务是将输入序列转换为一个固定形状的编码状态。假设输入序列为
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = ( x 1 , x 2 , ... , x T ) X = (x_1, x_2, \ldots, x_T) </math>X=(x1,x2,...,xT)

那么编码器可以看作是一个函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f,输出一个状态向量
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s = f ( X ) s = f(X) </math>s=f(X)

在深度学习中,常用循环神经网络(RNN)、长短时记忆网络(LSTM)或门控循环单元(GRU)来实现编码器,因为它们能够处理序列数据。

下面是一个基于 PyTorch 框架的编码器接口示例代码:

python 复制代码
class Encoder(nn.Module):
    """编码器-解码器架构的基本编码器接口"""

    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *rgs):
        # X 为输入序列(长度可变)
        # 此处应实现编码逻辑,将 X 转换为固定形状的编码状态
        raise NotImplementedError

在这个接口中,任何继承 Encoder 的模型都需要实现 forward 方法。数学上可以认为这个过程为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s = f ( X ) s = f(X) </math>s=f(X)

三、解码器

解码器负责将固定形状的编码状态转换为长度可变的输出序列。其核心在于:

  1. 初始化状态:将编码器输出转换为解码器的初始状态。
  2. 逐步生成输出:在每个时间步,解码器根据当前状态和上一步输出生成下一个词元。

例如,在序列生成过程中,我们可以在时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 用下面两个公式描述状态更新与输出生成:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h t = tanh ⁡ ( W   y t − 1 + U   h t − 1 + b ) h_t = \tanh(W\, y_{t-1} + U\, h_{t-1} + b) </math>ht=tanh(Wyt−1+Uht−1+b)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y t = s o f t m a x ( V   h t + c ) y_t = \mathrm{softmax}(V\, h_t + c) </math>yt=softmax(Vht+c)

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht 表示当前时刻的隐藏状态,
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> y t − 1 y_{t-1} </math>yt−1 是上一步生成的词元,
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> W , U , V W, U, V </math>W,U,V 为参数矩阵,
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> b , c b, c </math>b,c 为偏置项。

下面给出一个 PyTorch 实现的解码器接口示例代码:

python 复制代码
class Decoder(nn.Module):
    """编码器-解码器架构的基本解码器接口"""

    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError

四、合并编码器和解码器

编码器与解码器虽然是两个独立的组件,但在整个序列转换模型中必须协同工作。整体流程如下:

  1. 编码阶段 :输入序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 经过编码器处理,生成编码输出 enc_outputs
  2. 状态初始化 :将 enc_outputs 通过解码器的 init_state 方法转换为初始解码状态 dec_state
  3. 解码阶段 :解码器根据输入(如目标序列的部分信息)和 dec_state 逐步生成输出序列。

整个过程可以用下面的公式来描述:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y = D e c o d e r ( d e c _ X , i n i t _ s t a t e ( E n c o d e r ( e n c _ X ) ) ) y = \mathrm{Decoder}(dec\_X, \mathrm{init\_state}(\mathrm{Encoder}(enc\_X))) </math>y=Decoder(dec_X,init_state(Encoder(enc_X)))

下面是一个整合编码器和解码器的 PyTorch 接口示例代码:

python 复制代码
class EncoderDecoder(nn.Module):
    """编码器-解码器架构的基类"""

    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        # enc_X:输入序列
        # dec_X:解码器接收的输入(例如目标序列的部分)
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

五、小结

本文介绍了编码器-解码器架构,这种架构在处理序列转换问题(如机器翻译)中具有广泛的应用。主要内容回顾如下:

  • 编码器:将变长输入序列

    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X = ( x 1 , x 2 , ... , x T ) X = (x_1, x_2, \ldots, x_T) </math>X=(x1,x2,...,xT)

    转换为固定形状的编码状态
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s = f ( X ) s = f(X) </math>s=f(X)

  • 解码器:接收编码状态,利用状态更新公式

    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h t = tanh ⁡ ( W   y t − 1 + U   h t − 1 + b ) h_t = \tanh(W\, y_{t-1} + U\, h_{t-1} + b) </math>ht=tanh(Wyt−1+Uht−1+b)

    和输出生成公式
    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y t = s o f t m a x ( V   h t + c ) y_t = \mathrm{softmax}(V\, h_t + c) </math>yt=softmax(Vht+c)

    逐步生成输出序列。

  • 整合:整个模型通过编码器输出与解码器初始状态的衔接实现完整的序列转换过程,即

    <math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y = D e c o d e r ( d e c _ X , i n i t _ s t a t e ( E n c o d e r ( e n c _ X ) ) ) y = \mathrm{Decoder}(dec\_X, \mathrm{init\_state}(\mathrm{Encoder}(enc\_X))) </math>y=Decoder(dec_X,init_state(Encoder(enc_X)))

这种架构不仅为后续更复杂的循环神经网络(例如带注意力机制的模型)奠定了基础,也为实现机器翻译等任务提供了清晰的模块化设计思路。希望通过本文的讲解,大家能对深度学习中的序列转换模型有更直观的理解,并尝试阅读和实现相关代码。

相关推荐
yvestine5 小时前
自然语言处理——Transformer
人工智能·深度学习·自然语言处理·transformer
码上地球11 小时前
卷积神经网络设计指南:从理论到实践的经验总结
人工智能·深度学习·cnn
MYH51611 小时前
神经网络 隐藏层
人工智能·深度学习·神经网络
king of code porter14 小时前
深度学习之模型压缩三驾马车:模型剪枝、模型量化、知识蒸馏
人工智能·深度学习·剪枝
聚客AI16 小时前
PyTorch进阶:从自定义损失函数到生产部署全栈指南
人工智能·pytorch·深度学习
寻丶幽风19 小时前
论文阅读笔记——Muffin: Testing Deep Learning Libraries via Neural Architecture Fuzzing
论文阅读·笔记·深度学习·网络安全·差分测试
强盛小灵通专卖员20 小时前
DL00871-基于深度学习YOLOv11的盲人障碍物目标检测含完整数据集
人工智能·深度学习·yolo·目标检测·计算机视觉·无人机·核心期刊
Morpheon20 小时前
循环神经网络(RNN):从理论到翻译
人工智能·rnn·深度学习·循环神经网络
Blossom.11821 小时前
基于机器学习的智能故障预测系统:构建与优化
人工智能·python·深度学习·神经网络·机器学习·分类·tensorflow
DartistCode21 小时前
动手学深度学习pytorch(第一版)学习笔记汇总
pytorch·深度学习·学习