pytorch之诗词生成5--train

先上代码:

复制代码
import tensorflow as tf
from dataset import PoetryDataGenerator, poetry, tokenizer
from model import model
import settings
import utils


class Evaluate(tf.keras.callbacks.Callback):
    """
    在每个epoch训练完成后,保留最优权重,并随机生成settings.SHOW_NUM首古诗展示
    """

    def __init__(self):
        super().__init__()
        # 给loss赋一个较大的初始值
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        # 在每个epoch训练完成后调用
        # 如果当前loss更低,就保存当前模型参数
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save(settings.BEST_MODEL_PATH)
        # 随机生成几首古体诗测试,查看训练效果
        print(tokenizer.id_to_token((3)))
        print(tokenizer.id_to_token(2))
        print(tokenizer.id_to_token(1))
        print(tokenizer.id_to_token(0))
        for i in range(settings.SHOW_NUM):
            print(utils.generate_random_poetry(tokenizer, model))


# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=settings.TRAIN_EPOCHS,
                    callbacks=[Evaluate()])

接下来我们开始分析代码:

复制代码
class Evaluate(tf.keras.callbacks.Callback):

定义了一个回调类Evaluate,它是tf.keras.callbacks.Callback类的子类,回调函数是在训练的不同阶段调用的函数,用于执行额外的操作或监控模型的性能。

复制代码
def __init__(self):
    super().__init__()
    # 给loss赋一个较大的初始值
    self.lowest = 1e10

这是Evaluate类的构造函数__init__(self),在这个构造函数中,有以下操作:

  • super().init():调用父类tf.keras.callbacks.Callback的构造函数,确保父类的初始化操作得到执行。
  • self.lwest=1e10:将lowest属性初始化为一个较大的值1e10。这个属性用于跟踪最低的损失值。通常将其初始化为一个较大的值,确保在训练过程的初始阶段,任何较小的损失值都可以成为新的最新值。
复制代码
def on_epoch_end(self, epoch, logs=None):
    # 在每个epoch训练完成后调用
    # 如果当前loss更低,就保存当前模型参数
    if logs['loss'] <= self.lowest:
        self.lowest = logs['loss']
        model.save(settings.BEST_MODEL_PATH)
    # 随机生成几首古体诗测试,查看训练效果
    print(tokenizer.id_to_token((3)))
    print(tokenizer.id_to_token(2))
    print(tokenizer.id_to_token(1))
    print(tokenizer.id_to_token(0))
    for i in range(settings.SHOW_NUM):
        print(utils.generate_random_poetry(tokenizer, model))

在TensorFlow的回调函数中,logs是一个字典,其中包含了训练过程中的各种指标和损失值。它提供了一些有关模型的信息,可以用于监控和记录训练的进程。我们初始化我们的logs为空值,也就是没有记录任何信息。

logs'loss'表示访问logs字典的loss键。(由于函数是在每个epoch训练完成之后使用,训练之后logs就保存了模型的信息)。同时,如果损失值低于我们的预设值(第一轮),就将最低损失值进行更新。

然后使用模型的save方法,将模型的各种参数都保存到我们给定的路径中去。

然后我们就输出SHOW_NUM首我们生成的古诗。

复制代码
# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=settings.TRAIN_EPOCHS,
                    callbacks=[Evaluate()])

这段代码调用PoetryDataGenerator类,我们将poetry传入模型,并进行随机打乱。

开始训练,使用fit_generator方法来训练模型。是模型对象的方法,用于使用生成器进行模型训练。它适用于数据较大无法一次加载到内存的情况,可以按照批次从生成器中获取数据进行训练。

steps_pre_epoch表示每个时钟周期加载多少个批次的数据进行训练。

callbacks=Evaluate()表示在训练模型过程中使用Evalute()函数,并将其作为一个回调函数传递给callbacks函数。回调函数是在训练的过程中特定时间被调用的函数,用于执行一些额外的操作,回调函数是在每个训练周期结束后被调用,在每轮训练的on_epoch_end事件中,回调函数会被触发并执行相应的操作。这意味着在每个训练周期结束时,回调函数会被调用用以执行自定义的评估操作。

相关推荐
AI的探索之旅几秒前
AI Agent替我做原理图:立创EDA + CubeMX + 知识库的三合一工作流
人工智能
大气的小蜜蜂1 分钟前
基于Python+PyQt5+SQLite的药房管理系统实现:事务一致性与界面解耦全流程解析
python·qt·sqlite
阿拉斯攀登5 分钟前
Agent 框架对比:LangChain / AutoGPT / CrewAI
人工智能·langchain·agent·rag·function
丹宇码农7 分钟前
基于 Top-K Logits 的 LLM 知识蒸馏实战
人工智能·ai·ai编程
lkshop11 分钟前
自研 GEO 系统实战:从架构设计到“一键投喂”多平台 AI 大模型
人工智能·geo
维基框架13 分钟前
Claude Mythos Preview 发布后严重漏洞激增:安全还是营销?
人工智能·安全
Csvn14 分钟前
AI Prompt 炼金术:让 AI 写代码 一次过
人工智能
Csvn18 分钟前
AI 编程提效核心技巧(直接复制套用,大幅减少手写代码时间)
人工智能
delishcomcn19 分钟前
预见性切割:机器学习如何提前预警碳带分切机的报废风险
人工智能·机器学习
拧AI螺丝20 分钟前
你往 AI 里装的那些 skill,打开看过一眼吗?
人工智能·agent