pytorch之诗词生成5--train

先上代码:

import tensorflow as tf
from dataset import PoetryDataGenerator, poetry, tokenizer
from model import model
import settings
import utils


class Evaluate(tf.keras.callbacks.Callback):
    """
    在每个epoch训练完成后,保留最优权重,并随机生成settings.SHOW_NUM首古诗展示
    """

    def __init__(self):
        super().__init__()
        # 给loss赋一个较大的初始值
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        # 在每个epoch训练完成后调用
        # 如果当前loss更低,就保存当前模型参数
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save(settings.BEST_MODEL_PATH)
        # 随机生成几首古体诗测试,查看训练效果
        print(tokenizer.id_to_token((3)))
        print(tokenizer.id_to_token(2))
        print(tokenizer.id_to_token(1))
        print(tokenizer.id_to_token(0))
        for i in range(settings.SHOW_NUM):
            print(utils.generate_random_poetry(tokenizer, model))


# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=settings.TRAIN_EPOCHS,
                    callbacks=[Evaluate()])

接下来我们开始分析代码:

复制代码
class Evaluate(tf.keras.callbacks.Callback):

定义了一个回调类Evaluate,它是tf.keras.callbacks.Callback类的子类,回调函数是在训练的不同阶段调用的函数,用于执行额外的操作或监控模型的性能。

复制代码
def __init__(self):
    super().__init__()
    # 给loss赋一个较大的初始值
    self.lowest = 1e10

这是Evaluate类的构造函数__init__(self),在这个构造函数中,有以下操作:

  • super().init():调用父类tf.keras.callbacks.Callback的构造函数,确保父类的初始化操作得到执行。
  • self.lwest=1e10:将lowest属性初始化为一个较大的值1e10。这个属性用于跟踪最低的损失值。通常将其初始化为一个较大的值,确保在训练过程的初始阶段,任何较小的损失值都可以成为新的最新值。
复制代码
def on_epoch_end(self, epoch, logs=None):
    # 在每个epoch训练完成后调用
    # 如果当前loss更低,就保存当前模型参数
    if logs['loss'] <= self.lowest:
        self.lowest = logs['loss']
        model.save(settings.BEST_MODEL_PATH)
    # 随机生成几首古体诗测试,查看训练效果
    print(tokenizer.id_to_token((3)))
    print(tokenizer.id_to_token(2))
    print(tokenizer.id_to_token(1))
    print(tokenizer.id_to_token(0))
    for i in range(settings.SHOW_NUM):
        print(utils.generate_random_poetry(tokenizer, model))

在TensorFlow的回调函数中,logs是一个字典,其中包含了训练过程中的各种指标和损失值。它提供了一些有关模型的信息,可以用于监控和记录训练的进程。我们初始化我们的logs为空值,也就是没有记录任何信息。

logs['loss']表示访问logs字典的loss键。(由于函数是在每个epoch训练完成之后使用,训练之后logs就保存了模型的信息)。同时,如果损失值低于我们的预设值(第一轮),就将最低损失值进行更新。

然后使用模型的save方法,将模型的各种参数都保存到我们给定的路径中去。

然后我们就输出SHOW_NUM首我们生成的古诗。

复制代码
# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=settings.TRAIN_EPOCHS,
                    callbacks=[Evaluate()])

这段代码调用PoetryDataGenerator类,我们将poetry传入模型,并进行随机打乱。

开始训练,使用fit_generator方法来训练模型。是模型对象的方法,用于使用生成器进行模型训练。它适用于数据较大无法一次加载到内存的情况,可以按照批次从生成器中获取数据进行训练。

steps_pre_epoch表示每个时钟周期加载多少个批次的数据进行训练。

callbacks=[Evaluate()]表示在训练模型过程中使用Evalute()函数,并将其作为一个回调函数传递给callbacks函数。回调函数是在训练的过程中特定时间被调用的函数,用于执行一些额外的操作,回调函数是在每个训练周期结束后被调用,在每轮训练的on_epoch_end事件中,回调函数会被触发并执行相应的操作。这意味着在每个训练周期结束时,回调函数会被调用用以执行自定义的评估操作。

相关推荐
陈苏同学7 分钟前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
唐家小妹10 分钟前
介绍一款开源的 Modern GUI PySide6 / PyQt6的使用
python·pyqt
吾名招财25 分钟前
yolov5-7.0模型DNN加载函数及参数详解(重要)
c++·人工智能·yolo·dnn
羊小猪~~42 分钟前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm
鼠鼠龙年发大财1 小时前
【鼠鼠学AI代码合集#7】概率
人工智能
Marst Code1 小时前
(Django)初步使用
后端·python·django
龙的爹23331 小时前
论文 | Model-tuning Via Prompts Makes NLP Models Adversarially Robust
人工智能·gpt·深度学习·语言模型·自然语言处理·prompt
工业机器视觉设计和实现1 小时前
cnn突破四(生成卷积核与固定核对比)
人工智能·深度学习·cnn
醒了就刷牙1 小时前
58 深层循环神经网络_by《李沐:动手学深度学习v2》pytorch版
pytorch·rnn·深度学习
985小水博一枚呀1 小时前
【对于Python爬虫的理解】数据挖掘、信息聚合、价格监控、新闻爬取等,附代码。
爬虫·python·深度学习·数据挖掘