先上代码:
import tensorflow as tf
from dataset import PoetryDataGenerator, poetry, tokenizer
from model import model
import settings
import utils
class Evaluate(tf.keras.callbacks.Callback):
"""
在每个epoch训练完成后,保留最优权重,并随机生成settings.SHOW_NUM首古诗展示
"""
def __init__(self):
super().__init__()
# 给loss赋一个较大的初始值
self.lowest = 1e10
def on_epoch_end(self, epoch, logs=None):
# 在每个epoch训练完成后调用
# 如果当前loss更低,就保存当前模型参数
if logs['loss'] <= self.lowest:
self.lowest = logs['loss']
model.save(settings.BEST_MODEL_PATH)
# 随机生成几首古体诗测试,查看训练效果
print(tokenizer.id_to_token((3)))
print(tokenizer.id_to_token(2))
print(tokenizer.id_to_token(1))
print(tokenizer.id_to_token(0))
for i in range(settings.SHOW_NUM):
print(utils.generate_random_poetry(tokenizer, model))
# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=settings.TRAIN_EPOCHS,
callbacks=[Evaluate()])
接下来我们开始分析代码:
class Evaluate(tf.keras.callbacks.Callback):
定义了一个回调类Evaluate,它是tf.keras.callbacks.Callback类的子类,回调函数是在训练的不同阶段调用的函数,用于执行额外的操作或监控模型的性能。
def __init__(self):
super().__init__()
# 给loss赋一个较大的初始值
self.lowest = 1e10
这是Evaluate类的构造函数__init__(self),在这个构造函数中,有以下操作:
- super().init():调用父类tf.keras.callbacks.Callback的构造函数,确保父类的初始化操作得到执行。
- self.lwest=1e10:将lowest属性初始化为一个较大的值1e10。这个属性用于跟踪最低的损失值。通常将其初始化为一个较大的值,确保在训练过程的初始阶段,任何较小的损失值都可以成为新的最新值。
def on_epoch_end(self, epoch, logs=None):
# 在每个epoch训练完成后调用
# 如果当前loss更低,就保存当前模型参数
if logs['loss'] <= self.lowest:
self.lowest = logs['loss']
model.save(settings.BEST_MODEL_PATH)
# 随机生成几首古体诗测试,查看训练效果
print(tokenizer.id_to_token((3)))
print(tokenizer.id_to_token(2))
print(tokenizer.id_to_token(1))
print(tokenizer.id_to_token(0))
for i in range(settings.SHOW_NUM):
print(utils.generate_random_poetry(tokenizer, model))
在TensorFlow的回调函数中,logs是一个字典,其中包含了训练过程中的各种指标和损失值。它提供了一些有关模型的信息,可以用于监控和记录训练的进程。我们初始化我们的logs为空值,也就是没有记录任何信息。
logs['loss']表示访问logs字典的loss键。(由于函数是在每个epoch训练完成之后使用,训练之后logs就保存了模型的信息)。同时,如果损失值低于我们的预设值(第一轮),就将最低损失值进行更新。
然后使用模型的save方法,将模型的各种参数都保存到我们给定的路径中去。
然后我们就输出SHOW_NUM首我们生成的古诗。
# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=settings.TRAIN_EPOCHS,
callbacks=[Evaluate()])
这段代码调用PoetryDataGenerator类,我们将poetry传入模型,并进行随机打乱。
开始训练,使用fit_generator方法来训练模型。是模型对象的方法,用于使用生成器进行模型训练。它适用于数据较大无法一次加载到内存的情况,可以按照批次从生成器中获取数据进行训练。
steps_pre_epoch表示每个时钟周期加载多少个批次的数据进行训练。
callbacks=[Evaluate()]表示在训练模型过程中使用Evalute()函数,并将其作为一个回调函数传递给callbacks函数。回调函数是在训练的过程中特定时间被调用的函数,用于执行一些额外的操作,回调函数是在每个训练周期结束后被调用,在每轮训练的on_epoch_end事件中,回调函数会被触发并执行相应的操作。这意味着在每个训练周期结束时,回调函数会被调用用以执行自定义的评估操作。