pytorch之诗词生成5--train

先上代码:

复制代码
import tensorflow as tf
from dataset import PoetryDataGenerator, poetry, tokenizer
from model import model
import settings
import utils


class Evaluate(tf.keras.callbacks.Callback):
    """
    在每个epoch训练完成后,保留最优权重,并随机生成settings.SHOW_NUM首古诗展示
    """

    def __init__(self):
        super().__init__()
        # 给loss赋一个较大的初始值
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        # 在每个epoch训练完成后调用
        # 如果当前loss更低,就保存当前模型参数
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save(settings.BEST_MODEL_PATH)
        # 随机生成几首古体诗测试,查看训练效果
        print(tokenizer.id_to_token((3)))
        print(tokenizer.id_to_token(2))
        print(tokenizer.id_to_token(1))
        print(tokenizer.id_to_token(0))
        for i in range(settings.SHOW_NUM):
            print(utils.generate_random_poetry(tokenizer, model))


# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=settings.TRAIN_EPOCHS,
                    callbacks=[Evaluate()])

接下来我们开始分析代码:

复制代码
class Evaluate(tf.keras.callbacks.Callback):

定义了一个回调类Evaluate,它是tf.keras.callbacks.Callback类的子类,回调函数是在训练的不同阶段调用的函数,用于执行额外的操作或监控模型的性能。

复制代码
def __init__(self):
    super().__init__()
    # 给loss赋一个较大的初始值
    self.lowest = 1e10

这是Evaluate类的构造函数__init__(self),在这个构造函数中,有以下操作:

  • super().init():调用父类tf.keras.callbacks.Callback的构造函数,确保父类的初始化操作得到执行。
  • self.lwest=1e10:将lowest属性初始化为一个较大的值1e10。这个属性用于跟踪最低的损失值。通常将其初始化为一个较大的值,确保在训练过程的初始阶段,任何较小的损失值都可以成为新的最新值。
复制代码
def on_epoch_end(self, epoch, logs=None):
    # 在每个epoch训练完成后调用
    # 如果当前loss更低,就保存当前模型参数
    if logs['loss'] <= self.lowest:
        self.lowest = logs['loss']
        model.save(settings.BEST_MODEL_PATH)
    # 随机生成几首古体诗测试,查看训练效果
    print(tokenizer.id_to_token((3)))
    print(tokenizer.id_to_token(2))
    print(tokenizer.id_to_token(1))
    print(tokenizer.id_to_token(0))
    for i in range(settings.SHOW_NUM):
        print(utils.generate_random_poetry(tokenizer, model))

在TensorFlow的回调函数中,logs是一个字典,其中包含了训练过程中的各种指标和损失值。它提供了一些有关模型的信息,可以用于监控和记录训练的进程。我们初始化我们的logs为空值,也就是没有记录任何信息。

logs['loss']表示访问logs字典的loss键。(由于函数是在每个epoch训练完成之后使用,训练之后logs就保存了模型的信息)。同时,如果损失值低于我们的预设值(第一轮),就将最低损失值进行更新。

然后使用模型的save方法,将模型的各种参数都保存到我们给定的路径中去。

然后我们就输出SHOW_NUM首我们生成的古诗。

复制代码
# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=settings.TRAIN_EPOCHS,
                    callbacks=[Evaluate()])

这段代码调用PoetryDataGenerator类,我们将poetry传入模型,并进行随机打乱。

开始训练,使用fit_generator方法来训练模型。是模型对象的方法,用于使用生成器进行模型训练。它适用于数据较大无法一次加载到内存的情况,可以按照批次从生成器中获取数据进行训练。

steps_pre_epoch表示每个时钟周期加载多少个批次的数据进行训练。

callbacks=[Evaluate()]表示在训练模型过程中使用Evalute()函数,并将其作为一个回调函数传递给callbacks函数。回调函数是在训练的过程中特定时间被调用的函数,用于执行一些额外的操作,回调函数是在每个训练周期结束后被调用,在每轮训练的on_epoch_end事件中,回调函数会被触发并执行相应的操作。这意味着在每个训练周期结束时,回调函数会被调用用以执行自定义的评估操作。

相关推荐
鹏多多35 分钟前
纯前端人脸识别利器:face-api.js手把手深入解析教学
前端·javascript·人工智能
aneasystone本尊1 小时前
盘点 Chat2Graph 中的专家和工具
人工智能
这里有鱼汤1 小时前
小白必看:QMT里的miniQMT入门教程
后端·python
Baihai_IDP1 小时前
AI Agents 能自己开发工具自己使用吗?一项智能体自迭代能力研究
人工智能·面试·llm
大模型真好玩2 小时前
大模型工程面试经典(七)—如何评估大模型微调效果?
人工智能·面试·deepseek
黎燃10 小时前
短视频平台内容推荐算法优化:从协同过滤到多模态深度学习
人工智能
TF男孩11 小时前
ARQ:一款低成本的消息队列,实现每秒万级吞吐
后端·python·消息队列
飞哥数智坊12 小时前
多次尝试用 CodeBuddy 做小程序,最终我放弃了
人工智能·ai编程
后端小肥肠12 小时前
别再眼馋 10w + 治愈漫画!Coze 工作流 3 分钟出成品,小白可学
人工智能·aigc·coze
唐某人丶15 小时前
教你如何用 JS 实现 Agent 系统(2)—— 开发 ReAct 版本的“深度搜索”
前端·人工智能·aigc