RNN Model Text Training and Inference

This article picks up where the previous one left off: last time we still had model training and text inference (generation) left to cover, and both are covered here.

We will train and run inference with the model built in the previous article. Here is the RNN model from that article:

python
import torch
from torch import nn

# `device`, the vocabulary, and the data pipeline come from the previous article


class PoemRNN(nn.Module):

    def __init__(self,
                 vocab_size,
                 embedding_dim,
                 hidden_dim,
                 num_layers,
                 batch_first=True,
                 dropout=0.3) -> None:
        super().__init__()

        # Character embedding, RNN stack, and a linear head that maps the
        # hidden state back onto the vocabulary
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim,
                          hidden_dim,
                          num_layers,
                          batch_first=batch_first,
                          dropout=dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

    def forward(self, x, hidden):
        embedded = self.embedding(x)                 # (batch, seq_len, embedding_dim)
        output, hidden = self.rnn(embedded, hidden)
        output = self.fc(output)                     # (batch, seq_len, vocab_size)
        return output, hidden

    def init_hidden(self, batch_size):
        # Zero-initialised hidden state for every layer: (num_layers, batch, hidden_dim)
        return torch.zeros(self.num_layers, batch_size,
                           self.hidden_dim).to(device)
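
For context, here is a minimal sketch of how the model might be instantiated. The hyperparameter values below are placeholders (only vocab_size matches the vocabulary size visible in the training log later); the real values, together with the dataset and dataloader, come from the previous article.

python
import torch

# Placeholder hyperparameters -- the actual values come from the previous article
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vocab_size = 7487      # size of the character vocabulary (matches the shapes printed during training)
embedding_dim = 256
hidden_dim = 512
num_layers = 2
learning_rate = 1e-3

model = PoemRNN(vocab_size, embedding_dim, hidden_dim, num_layers).to(device)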

Next we set up the loss and the optimizer. Since predicting the next character is a multi-class classification problem, we use the cross-entropy loss:

python
from torch import optim

criterion = nn.CrossEntropyLoss()  # multi-class cross-entropy loss
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam optimizer

How the multi-class cross-entropy loss works:
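
For each position in the sequence, the model outputs a vector of logits over the whole vocabulary. Softmax turns these logits into a probability distribution, and the loss is the negative log-probability the model assigned to the true next character, averaged over all positions. nn.CrossEntropyLoss fuses the log-softmax and the negative log-likelihood into one call; the small sketch below (with made-up shapes) verifies that equivalence.

python
import torch
from torch import nn

logits = torch.randn(4, 10)           # 4 positions, a 10-character "vocabulary"
targets = torch.tensor([1, 3, 0, 7])  # index of the true character at each position

# Built-in: log-softmax + negative log-likelihood in a single call
loss_builtin = nn.CrossEntropyLoss()(logits, targets)

# Manual: pick out -log p(true character) at each position and average
log_probs = torch.log_softmax(logits, dim=1)
loss_manual = -log_probs[torch.arange(4), targets].mean()

print(loss_builtin.item(), loss_manual.item())  # the two values are identical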

Next we write the function that generates a poem, i.e. the inference step:

python
def generate_poem(model,
                  char_to_idx,
                  idx_to_char,
                  start_char,
                  max_seq_length=32,
                  temperature=1.0):
    """
    Generate a poem from a given starting character (inference).

    Args:
        model: the trained model
        char_to_idx: mapping from character to index
        idx_to_char: mapping from index to character
        start_char: the starting character
        max_seq_length: maximum sequence length to generate
        temperature: temperature parameter for sampling
    Returns:
        the generated poem
    """
    model.eval()

    # Prepend the BOS token and convert the start character into an index sequence
    current_seq = torch.tensor(
        [[BOS_IDX, char_to_idx.get(start_char, UNK_IDX)]]).to(device)

    # Initialise the hidden state h_0 for every layer
    hidden = model.init_hidden(batch_size=1)

    # Text generated so far, starting from the given character
    generate = start_char

    # Generation loop
    with torch.no_grad():
        for _ in range(max_seq_length):
            # Forward pass
            output, hidden = model(current_seq, hidden)

            # Take the output of the last time step and turn it into probabilities
            probs = torch.softmax(output[0, -1, :] / temperature, dim=0)
            # Sample one index from this distribution; indices with higher
            # probability are more likely to be chosen
            next_idx = torch.multinomial(probs, num_samples=1).item()
            # Map the sampled index back to a character
            next_char = idx_to_char[next_idx]

            if next_char == EOS_TOKEN:
                break

            generate += next_char

            # Next input: BOS plus the token we just sampled; the hidden state
            # carries the earlier context forward
            current_seq = torch.tensor([[BOS_IDX, next_idx]]).to(device)

    return generate

The function's parameters:

  • model: the trained model

  • char_to_idx: mapping from character to index

  • idx_to_char: mapping from index to character

  • start_char: the starting character

  • max_seq_length: the maximum sequence length to generate

  • temperature: the sampling temperature

Since generating a poem happens after the model has been trained, we call model.eval() to put the model into evaluation mode (which, among other things, disables dropout during inference).

One more word on what the temperature parameter means (the sketch after this list illustrates the effect):

  • temperature = 1.0: the standard softmax, no temperature effect

  • temperature = 5.0: high temperature; the results are more random and more diverse, but possibly less coherent

  • temperature = 0.1: low temperature; the results are more deterministic and less diverse, more sensible but possibly monotonous
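
Here is a small sketch, with made-up logits for four candidate characters, showing how temperature reshapes the sampling distribution:

python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])  # toy scores for 4 candidate characters

for temperature in (0.1, 1.0, 5.0):
    probs = torch.softmax(logits / temperature, dim=0)
    print(temperature, probs.tolist())

# Low temperature sharpens the distribution (the top character is picked almost every time);
# high temperature flattens it, so sampling gets closer to uniform.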

Before training, we set the early-stopping patience, initialise the best loss to infinity, pick a path for saving the best model's weights, and keep a history of the losses:

python
patience = 10  # early-stopping patience: stop if the loss does not improve for 10 consecutive epochs
best_val_loss = float('inf')  # initialise the best loss to infinity
counter = 0  # number of consecutive epochs without improvement

best_model_path = r'./models/best_model01.pth'
history = {'loss': []}

Now we train the model and have the AI try to write poems:

python
from tqdm import tqdm

first = True
for epoch in range(epochs):
    model.train()
    epoch_loss = 0  # accumulated loss over all batches in this epoch

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, (inputs, targets) in progress_bar:
        inputs = inputs.to(device)
        targets = targets.to(device)

        batch_size = inputs.size(0)
        hidden = model.init_hidden(batch_size)

        # Train on this batch
        optimizer.zero_grad()
        outputs, hidden = model(inputs, hidden)

        if first:
            print('Raw predictions and targets', outputs.shape, targets.shape)
            print('After reshaping',
                  outputs.view(-1, vocab_size).shape,
                  targets.view(-1).shape)
            first = False

        # CrossEntropyLoss expects (N, C) logits and (N,) targets, so the batch
        # and time dimensions are flattened together before computing the loss
        loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        progress_bar.set_description(
            f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/(i+1):.4f}')

    avg_train_loss = epoch_loss / len(dataloader)
    history['loss'].append(avg_train_loss)

    if avg_train_loss < best_val_loss:
        best_val_loss = avg_train_loss
        counter = 0  # reset the early-stopping counter
        # Save the model weights
        torch.save(model.state_dict(), best_model_path)
        print(
            f"Best model saved at Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}"
        )
    else:
        # No improvement this epoch
        counter += 1  # count consecutive epochs without improvement
        if counter >= patience:
            print(f"Early stopping at Epoch {epoch+1}")
            break

    if (epoch + 1) % sample_every == 0:
        print(f'\nEpoch [{epoch+1}/{epochs}] generated sample')
        print(
            generate_poem(model, char_to_idx, idx_to_char, '春',
                          max_seq_length))

sample_every is a hyperparameter introduced in the previous article; it was set to 20 there, meaning that every 20 training epochs the AI tries to generate a poem.

Here we use the character "春" (spring) as the opening of the poem, and the model generates the following characters from the vocabulary it has learned.

Output:

text
Epoch 1/300, Loss: 9.0558:   0%|          | 1/237 [00:00<01:07,  3.51it/s]
Raw predictions and targets torch.Size([256, 97, 7487]) torch.Size([256, 97])
After reshaping torch.Size([24832, 7487]) torch.Size([24832])
Epoch 1/300, Loss: 3.0043: 100%|██████████| 237/237 [00:19<00:00, 12.38it/s]
Best model saved at Epoch 1, Train Loss: 3.0043
Epoch 2/300, Loss: 2.6506: 100%|██████████| 237/237 [00:19<00:00, 12.34it/s]
Best model saved at Epoch 2, Train Loss: 2.6506
Epoch 3/300, Loss: 2.4866: 100%|██████████| 237/237 [00:19<00:00, 12.01it/s]
Best model saved at Epoch 3, Train Loss: 2.4866
Epoch 4/300, Loss: 2.3974: 100%|██████████| 237/237 [00:19<00:00, 12.24it/s]
Best model saved at Epoch 4, Train Loss: 2.3974
Epoch 5/300, Loss: 2.3380: 100%|██████████| 237/237 [00:19<00:00, 12.18it/s]
Best model saved at Epoch 5, Train Loss: 2.3380
Epoch 6/300, Loss: 2.2921: 100%|██████████| 237/237 [00:19<00:00, 12.19it/s]
Best model saved at Epoch 6, Train Loss: 2.2921
Epoch 7/300, Loss: 2.2568: 100%|██████████| 237/237 [00:19<00:00, 12.22it/s]
Best model saved at Epoch 7, Train Loss: 2.2568
Epoch 8/300, Loss: 2.2274: 100%|██████████| 237/237 [00:19<00:00, 12.28it/s]
Best model saved at Epoch 8, Train Loss: 2.2274
Epoch 9/300, Loss: 2.2020: 100%|██████████| 237/237 [00:19<00:00, 12.28it/s]
Best model saved at Epoch 9, Train Loss: 2.2020
Epoch 10/300, Loss: 2.1800: 100%|██████████| 237/237 [00:19<00:00, 12.25it/s]
Best model saved at Epoch 10, Train Loss: 2.1800
Epoch 11/300, Loss: 2.1645: 100%|██████████| 237/237 [00:19<00:00, 12.28it/s]
Best model saved at Epoch 11, Train Loss: 2.1645
Epoch 12/300, Loss: 2.1526: 100%|██████████| 237/237 [00:19<00:00, 12.23it/s]
Best model saved at Epoch 12, Train Loss: 2.1526
Epoch 13/300, Loss: 2.1360: 100%|██████████| 237/237 [00:19<00:00, 12.30it/s]
Best model saved at Epoch 13, Train Loss: 2.1360
Epoch 14/300, Loss: 2.1252: 100%|██████████| 237/237 [00:19<00:00, 12.30it/s]
Best model saved at Epoch 14, Train Loss: 2.1252
Epoch 15/300, Loss: 2.1134: 100%|██████████| 237/237 [00:19<00:00, 12.24it/s]
Best model saved at Epoch 15, Train Loss: 2.1134
Epoch 16/300, Loss: 2.1025: 100%|██████████| 237/237 [00:19<00:00, 12.27it/s]
Best model saved at Epoch 16, Train Loss: 2.1025
Epoch 17/300, Loss: 2.0924: 100%|██████████| 237/237 [00:19<00:00, 12.16it/s]
Best model saved at Epoch 17, Train Loss: 2.0924
Epoch 18/300, Loss: 2.0960: 100%|██████████| 237/237 [00:19<00:00, 12.28it/s]
Epoch 19/300, Loss: 2.0854: 100%|██████████| 237/237 [00:19<00:00, 12.26it/s]
Best model saved at Epoch 19, Train Loss: 2.0854
Epoch 20/300, Loss: 2.0707: 100%|██████████| 237/237 [00:19<00:00, 12.18it/s]
Best model saved at Epoch 20, Train Loss: 2.0707

Epoch [20/300] generated sample
春草堂,褰翠仪容盛,遨会直指为。李之作含灵,步以和虚。范恭尧,岂启德属聿制遗庆,靡律。授絜,俯不克成。
Epoch 21/300, Loss: 2.0656: 100%|██████████| 237/237 [00:19<00:00, 12.18it/s]
Best model saved at Epoch 21, Train Loss: 2.0656
Epoch 22/300, Loss: 2.0566: 100%|██████████| 237/237 [00:19<00:00, 12.11it/s]
Best model saved at Epoch 22, Train Loss: 2.0566
Epoch 23/300, Loss: 2.0621: 100%|██████████| 237/237 [00:19<00:00, 12.11it/s]
Epoch 24/300, Loss: 2.0492: 100%|██████████| 237/237 [00:19<00:00, 12.12it/s]
Best model saved at Epoch 24, Train Loss: 2.0492
Epoch 25/300, Loss: 2.0441: 100%|██████████| 237/237 [00:19<00:00, 12.17it/s]
Best model saved at Epoch 25, Train Loss: 2.0441
Epoch 26/300, Loss: 2.0389: 100%|██████████| 237/237 [00:19<00:00, 12.19it/s]
Best model saved at Epoch 26, Train Loss: 2.0389
Epoch 27/300, Loss: 2.0441: 100%|██████████| 237/237 [00:19<00:00, 12.14it/s]
Epoch 28/300, Loss: 2.0275: 100%|██████████| 237/237 [00:19<00:00, 12.17it/s]
Best model saved at Epoch 28, Train Loss: 2.0275
Epoch 29/300, Loss: 2.0258: 100%|██████████| 237/237 [00:19<00:00, 12.23it/s]
Best model saved at Epoch 29, Train Loss: 2.0258
Epoch 30/300, Loss: 2.0179: 100%|██████████| 237/237 [00:19<00:00, 12.21it/s]
Best model saved at Epoch 30, Train Loss: 2.0179
Epoch 31/300, Loss: 2.0176: 100%|██████████| 237/237 [00:19<00:00, 12.18it/s]
Best model saved at Epoch 31, Train Loss: 2.0176
Epoch 32/300, Loss: 2.0253: 100%|██████████| 237/237 [00:19<00:00, 11.96it/s]
Epoch 33/300, Loss: 2.0543: 100%|██████████| 237/237 [00:20<00:00, 11.84it/s]
Epoch 34/300, Loss: 2.0174: 100%|██████████| 237/237 [00:20<00:00, 11.78it/s]
Best model saved at Epoch 34, Train Loss: 2.0174
Epoch 35/300, Loss: 2.0141: 100%|██████████| 237/237 [00:19<00:00, 11.92it/s]
Best model saved at Epoch 35, Train Loss: 2.0141
Epoch 36/300, Loss: 2.0094: 100%|██████████| 237/237 [00:20<00:00, 11.84it/s]
Best model saved at Epoch 36, Train Loss: 2.0094
Epoch 37/300, Loss: 2.0010: 100%|██████████| 237/237 [00:19<00:00, 12.11it/s]
Best model saved at Epoch 37, Train Loss: 2.0010
Epoch 38/300, Loss: 1.9977: 100%|██████████| 237/237 [00:19<00:00, 12.00it/s]
Best model saved at Epoch 38, Train Loss: 1.9977
Epoch 39/300, Loss: 1.9924: 100%|██████████| 237/237 [00:19<00:00, 11.91it/s]
Best model saved at Epoch 39, Train Loss: 1.9924
Epoch 40/300, Loss: 2.0381: 100%|██████████| 237/237 [00:19<00:00, 11.92it/s]

Epoch [40/300] generated sample
春风帘外,知酒懒归,归思逢花上,春门深。
Epoch 41/300, Loss: 2.0041: 100%|██████████| 237/237 [00:20<00:00, 11.76it/s]
Epoch 42/300, Loss: 1.9894: 100%|██████████| 237/237 [00:20<00:00, 11.73it/s]
Best model saved at Epoch 42, Train Loss: 1.9894
Epoch 43/300, Loss: 2.0076: 100%|██████████| 237/237 [00:19<00:00, 11.94it/s]
Epoch 44/300, Loss: 2.0054: 100%|██████████| 237/237 [00:19<00:00, 12.08it/s]
Epoch 45/300, Loss: 1.9960: 100%|██████████| 237/237 [00:19<00:00, 12.01it/s]
Epoch 46/300, Loss: 1.9910: 100%|██████████| 237/237 [00:19<00:00, 11.87it/s]
Epoch 47/300, Loss: 1.9791: 100%|██████████| 237/237 [00:20<00:00, 11.76it/s]
Best model saved at Epoch 47, Train Loss: 1.9791
Epoch 48/300, Loss: 1.9723: 100%|██████████| 237/237 [00:19<00:00, 11.96it/s]
Best model saved at Epoch 48, Train Loss: 1.9723
Epoch 49/300, Loss: 1.9674: 100%|██████████| 237/237 [00:19<00:00, 11.94it/s]
Best model saved at Epoch 49, Train Loss: 1.9674
Epoch 50/300, Loss: 1.9648: 100%|██████████| 237/237 [00:19<00:00, 11.91it/s]
Best model saved at Epoch 50, Train Loss: 1.9648
Epoch 51/300, Loss: 1.9701: 100%|██████████| 237/237 [00:19<00:00, 11.93it/s]
Epoch 52/300, Loss: 1.9878: 100%|██████████| 237/237 [00:19<00:00, 11.97it/s]
Epoch 53/300, Loss: 2.0315: 100%|██████████| 237/237 [00:19<00:00, 11.86it/s]
Epoch 54/300, Loss: 2.0441: 100%|██████████| 237/237 [00:19<00:00, 11.99it/s]
Epoch 55/300, Loss: 2.0078: 100%|██████████| 237/237 [00:20<00:00, 11.85it/s]
Epoch 56/300, Loss: 2.0001: 100%|██████████| 237/237 [00:19<00:00, 11.98it/s]
Epoch 57/300, Loss: 1.9852: 100%|██████████| 237/237 [00:19<00:00, 12.12it/s]
Epoch 58/300, Loss: 1.9759: 100%|██████████| 237/237 [00:19<00:00, 12.17it/s]
Epoch 59/300, Loss: 1.9719: 100%|██████████| 237/237 [00:19<00:00, 12.08it/s]
Epoch 60/300, Loss: 1.9659: 100%|██████████| 237/237 [00:19<00:00, 11.96it/s]
Early stopping at Epoch 60

The tqdm in the code provides the progress bar; if the library is not installed, install it first with pip install tqdm.

Let's look at the samples generated at the 20-epoch checkpoints:

text
Epoch [20/300] generated sample
春草堂,褰翠仪容盛,遨会直指为。李之作含灵,步以和虚。范恭尧,岂启德属聿制遗庆,靡律。授絜,俯不克成。
Epoch [40/300] generated sample
春风帘外,知酒懒归,归思逢花上,春门深。

The generated poems admittedly don't look great, but the AI has learned how to put characters together: at the very least, adjacent characters connect reasonably. Bear in mind that this model has only six-million-odd parameters, while large language models have billions or even hundreds of billions of parameters, so compared with them ours is a very small model and there is really no comparison. Its value is that it helps us build a deeper understanding of RNNs. If you spot things worth optimising, try tuning the model yourself to generate better poems. For now, this is only a first, basic RNN text-generation model.
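
If you want to check the parameter count yourself, a quick sketch over the model built earlier:

python
num_params = sum(p.numel() for p in model.parameters())
print(f'{num_params:,} parameters')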

Let's plot a chart to record the training loss:

python
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme()  # use seaborn's default styling
plt.figure(figsize=(10, 6))
plt.plot(history['loss'], label='Train Loss')  # one point per epoch
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

Output: a line chart of the training loss per epoch, which drops quickly in the first few epochs and then flattens out around 2.0.

Finally, we load the saved best model and use it to generate text again:

python
model.load_state_dict(torch.load(best_model_path))
while True:
    user_input = input("Enter a starting character: ")
    if user_input == 'exit':
        break
    print(
        generate_poem(model, char_to_idx, idx_to_char, user_input,
                      max_seq_length))
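
One note on loading: if the checkpoint was saved on a GPU and you later reload it on a CPU-only machine, pass map_location so the tensors are remapped to the current device. A small sketch (not part of the original script):

python
state_dict = torch.load(best_model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # switch to evaluation mode before generating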

Run it:

Type the character "春" at the prompt.

Result:

text
春风景街,鸣簇春鸠舄,独着幽花,冷衾枕次,过语,道似与谁,江天竺玄一峰城内,自炼前时心,悄易得成幽显初,鹤异远孤步,其阴兮桥境,清昏,汲绿阴苔半径,鸟忘,凄凉夜不得行,荒浮生地,猿风昼无精,无劳止
夏腊吃,人闲苦饮吃食腻,劝小焦浆赠味珍吉贞诚非,玉孝明猛备,其厚,朱耀赫执戟,骨颂植,大圣箴规号德晻,景迁超。方惟德极,礼郊祀,曰之浮。刘斯,敢望浮名。神功,庆洪击郜极神颂,四岸,天肩舒焕。
夏雪初,林忧复低开,意高高百尺绳孤峻崇,参怒。岑危,八三灵液,满讹以诚哲,礼惟。韩大九垓,位煌,合象九楹。灵均,石抗。博单,百典,撤泄。县如,耸稽净镇虔。雾凉,祝维皇唐福穰季兆,贵咸亿载轮敢逾九岁
秋毫白云,临江涛开三千场,冰齐十一千万物,半妥万物庖。惠休,千祀,震禹坛上,郁陶逢夏,三四。金风,重而洁天声灵八帝烂唐,无四。潘冠陆哲以符华,穆其。禋克,克配宇通舜,正颂临海安夷海正乾坤,辉其。造
冬春草,江童五马笑,树将采桑。

Notice that even with the same starting character "春", each run generates a different poem. The quality of this first model is honestly not as good as one might hope, but the AI does roughly know which character should come next: it can handle the order of two adjacent characters, even though it cannot keep track of much longer context. That is because the RNN walks through the text one character at a time. For "I love you", the steps are bos -> I (bos is the begin-of-sequence token), then I -> love, then love -> you, and finally you -> eos, the end-of-sequence token.
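
In other words, the training pairs are just the sequence shifted by one position. A tiny sketch of what the input/target pairs for "I love you" look like (the token spellings here are illustrative, not the actual special tokens from the previous article):

python
tokens = ['<bos>', 'I', 'love', 'you', '<eos>']

# Each input token is trained to predict the token that follows it
inputs = tokens[:-1]    # ['<bos>', 'I', 'love', 'you']
targets = tokens[1:]    # ['I', 'love', 'you', '<eos>']

for x, y in zip(inputs, targets):
    print(f'{x} -> {y}')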

And that wraps up this article on RNN model text training and inference.

That's everything on RNN model text training and inference. If anything is unclear, drop your questions in the comments section; I'd love to discuss them there! If I got something wrong, please point it out, or contact me directly, so we can keep improving together. Learning is a long process: we have to keep studying and digesting new material, and whenever a concept is fuzzy or something doesn't make sense, it's best to resolve it right away, otherwise the questions pile up and the gaps only grow. Life's road is long; may the egrets keep you company!!!
