【动手学深度学习】9.6 编码器和解码器

正如我们在 9.5节中所讨论的, 机器翻译是序列转换模型的一个核心问题, 其输入和输出都是长度可变的序列。 为了处理这种类型的输入和输出, 我们可以设计一个包含两个主要组件的架构: 第一个组件是一个编码器(encoder): 它接受一个长度可变的序列作为输入, 并将其转换为具有固定形状的编码状态。 第二个组件是解码器(decoder): 它将固定形状的编码状态映射到长度可变的序列。 这被称为编码器-解码器(encoder-decoder)架构, 如 图9.6.1 所示。

我们以英语到法语的机器翻译为例: 给定一个英文的输入序列:"They""are""watching""."。 首先,这种"编码器-解码器"架构将长度可变的输入序列编码成一个"状态", 然后对该状态进行解码, 一个词元接着一个词元地生成翻译后的序列作为输出: "Ils""regordent""."。 由于"编码器-解码器"架构是形成后续章节中不同序列转换模型的基础, 因此本节将把这个架构转换为接口方便后面的代码实现。

编码器

在编码器接口中,我们只指定长度可变的序列作为编码器的输入X。 任何继承这个Encoder基类的模型将完成代码实现。

python 复制代码
from torch import nn


#@save
class Encoder(nn.Module):
    """编码器-解码器架构的基本编码器接口"""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError

解码器

在下面的解码器接口中,我们新增一个init_state函数, 用于将编码器的输出(enc_outputs)转换为编码后的状态。 注意,此步骤可能需要额外的输入,例如:输入序列的有效长度, 这在 9.5.4节中进行了解释。 为了逐个地生成长度可变的词元序列, 解码器在每个时间步都会将输入 (例如:在前一时间步生成的词元)和编码后的状态 映射成当前时间步的输出词元。

python 复制代码
#@save
class Decoder(nn.Module):
    """编码器-解码器架构的基本解码器接口"""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
python 复制代码
# 合并编码器和解码器
#@save
class EncoderDecoder(nn.Module):
    """编码器-解码器架构的基类"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

小结

  • "编码器-解码器"架构可以将长度可变的序列作为输入和输出,因此适用于机器翻译等序列转换问题。

  • 编码器将长度可变的序列作为输入,并将其转换为具有固定形状的编码状态。

  • 解码器将具有固定形状的编码状态映射为长度可变的序列。

相关推荐
Mintopia26 分钟前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮1 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬1 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia2 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区2 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两5 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪5 小时前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
strayCat232555 小时前
Clawdbot 源码解读 7: 扩展机制
人工智能·开源
程序员打怪兽5 小时前
详解Visual Transformer (ViT)网络模型
深度学习