【Python机器学习】序列到序列建模和注意力机制——训练序列到序列网络

在Keras模型中,创建序列到序列模型的最后一个步骤是编译(compile)和拟合(fit)。与其他神经网络模型相比,唯一的区别在于,之前预测的是二元分类:是或不是。这里有一个单分类或多分类的问题。在每个时刻,必须确定许多"类别"中的哪一个是正确的,这里有很多类别。模型必须在所有可能的词条之间进行选择。因为预测是使字符或词,而不是二进制状态,所以将基于categorical_crossentropy损失函数进行优化,而不是基于binary_crossentropy。因此,这是需要对Keras代码中model.compile步骤进行的唯一更改:

python 复制代码
model.compile(optimizer='rmsprop',loss='categorical_crossentropy')
model.fit([encoder_input_data,decoder_input_data],
          decoder_target_data,
          batch_size=batch_size,
          epochs=epochs)

通过调用model.fit函数,这里正在训练序列到序列的端到端网络。

生成输出序列

在生成序列之前,需要获取训练层的结构,并将其重新组装以用于生成序列。首先,定义特定编码器的模型,这个模型将被用来生成思想向量:

python 复制代码
encoder_model=Model(inputs=encoder_inputs,outputs=encoder_states)

解码器的定义看起来不易理解。首先,我们将定义解码器的输入,这里使用Keras输入层,但是传递的是编码器网络生成的思想向量,而不是传递独热向量、字符或词嵌入。要注意的是编码器返回一个包含两种状态的列表,在调用之前定义的decoder_lstm时,需要将该列表传递给之前也定义过的稠密层。该层的输出将提供所有解码器输出词条的概率。

在每个时刻,预测概率最高的词条接下来将作为最有可能的词条返回给解码器网络,并作为新输入继续传递到解码器的下一个迭代步骤:

python 复制代码
#定义一个输入层以获取编码器状态
thought_input=[Input(shape=(num_neurons,)),Input(shape=(num_neurons,))]
#将编码器状态作为初始状态传递给LSTM层
decoder_outputs,state_h,state_c=decoder_lstm(decoder_inputs,initial_state=thought_input)
#更新后的LSTM状态将成为下一次迭代的新细胞状态
decoder_states=[state_h,state_c]
#将输出从LSTM传递到稠密层,以预测下一个词条
decoder_outputs=decoder_dense(decoder_outputs)

#最后一步是将解码器模型绑定在一起
decoder_model=Model(
    inputs=[decoder_inputs]+thought_input,
    output=[decoder_outputs]+decoder_states
)

一旦建立了模型,就可以根据一个独热编码的输入序列和最后生成的词条来预测思想向量,从而生成整个序列。在第一次迭代期间,target_seq被设置成初始词条。在接下来的所有迭代中,target_seq将使用最后生成的词条进行更新。这个循环会一直进行下去,直到达到序列元素的最大数量或者解码器生成一个终止词条,此时生成过程停止:

python 复制代码
thought=encoder_model.predict(input_seq)
while not stop_condition:
    output_token,h,c=decoder_model.predict(
        [target_seq]+thought
    )
相关推荐
代码游侠13 分钟前
学习笔记——数据结构学习
linux·开发语言·数据结构·笔记·学习
vvoennvv15 分钟前
【Python TensorFlow】 TCN-GRU时间序列卷积门控循环神经网络时序预测算法(附代码)
python·rnn·神经网络·机器学习·gru·tensorflow·tcn
沐知全栈开发22 分钟前
XML 验证器
开发语言
YJlio26 分钟前
[编程达人挑战赛] 用 PowerShell 写了一个“电脑一键初始化脚本”:从混乱到可复制的开发环境
数据库·人工智能·电脑
玦尘、39 分钟前
《统计学习方法》第4章——朴素贝叶斯法【学习笔记】
笔记·机器学习
RoboWizard1 小时前
PCIe 5.0 SSD有无独立缓存对性能影响大吗?Kingston FURY Renegade G5!
人工智能·缓存·电脑·金士顿
自学互联网1 小时前
使用Python构建钢铁行业生产监控系统:从理论到实践
开发语言·python
合作小小程序员小小店1 小时前
桌面开发,在线%医院管理%系统,基于vs2022,c#,winform,sql server数据
开发语言·数据库·sql·microsoft·c#
无心水1 小时前
【Python实战进阶】7、Python条件与循环实战详解:从基础语法到高级技巧
android·java·python·python列表推导式·python条件语句·python循环语句·python实战案例
一点★1 小时前
“equals”与“==”、“hashCode”的区别和使用场景
java·开发语言