深度学习之生成唐诗案例(Pytorch版)

主要思路:

对于唐诗生成来说,我们定义 "S" 和 "E" 两个特殊字符,分别作为每首诗的开始标记和结束标记。

示例的唐诗大概有40000多首。

首先进行数据预处理:将唐诗加载到内存,生成对应的 word2idx、idx2word,以及所有唐诗按顺序拼接而成的字索引序列。

运行结果:

复制代码
代码部分:
Dataset_Dataloader.py
python 复制代码
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


def deal_tangshi(path="tangshis.txt"):
    """Load the poem corpus and build the character vocabulary.

    Every usable line of the corpus has the form ``<title>:<poem>``. Lines
    that do not split into exactly two parts on ``:`` are skipped. Each kept
    poem is wrapped as ``"S" + poem + "E"`` so that the start and the end of
    a poem are visible in the character stream.

    Args:
        path: Location of the UTF-8 corpus file. Defaults to
            ``"tangshis.txt"``, the path that used to be hard-coded, so
            existing zero-argument calls behave as before.

    Returns:
        A 5-tuple ``(word2idx, idx2word, tangshis, word2idx_count,
        tangshi_ids)``:

        * ``word2idx``: dict mapping a character to its integer id. ``"S"``
          is 0, ``"E"`` is 1, every other character receives the next free
          id in order of first appearance.
        * ``idx2word``: the inverse mapping of ``word2idx``.
        * ``tangshis``: list of the wrapped poem strings, in file order.
        * ``word2idx_count``: number of entries in ``word2idx``.
        * ``tangshi_ids``: flat list with the ids of all wrapped poems
          concatenated in file order.

    Raises:
        OSError: If the corpus file cannot be opened.
    """
    with open(path, "r", encoding="utf-8") as fr:
        lines = fr.read().strip().split("\n")

    tangshis = []
    for line in lines:
        splits = line.split(":")
        # Keep only "<title>:<poem>" lines; anything else is ignored.
        if len(splits) != 2:
            continue
        tangshis.append("S" + splits[1] + "E")

    word2idx = {"S": 0, "E": 1}
    tangshi_ids = []

    # Single pass: ids are dense (0..len-1), so len(word2idx) is always the
    # next free id, and setdefault only assigns it for unseen characters.
    for tangshi in tangshis:
        for word in tangshi:
            tangshi_ids.append(word2idx.setdefault(word, len(word2idx)))

    idx2word = {idx: w for w, idx in word2idx.items()}
    word2idx_count = len(word2idx)

    return word2idx, idx2word, tangshis, word2idx_count, tangshi_ids


# Build the vocabulary and the flat id corpus once, at import time, and expose
# them as module-level names. NOTE: this means importing the module reads the
# corpus file that deal_tangshi() opens.
word2idx, idx2word, tangshis, word2idx_count, tangshi_ids = deal_tangshi()


class TangShiDataset(Dataset):
    """Fixed-length next-character training windows over a flat id corpus.

    Sample ``idx`` is the window of ``num_chars`` ids starting at offset
    ``idx * num_chars`` (input ``x``) together with the same window shifted
    one position to the right (target ``y``).

    Args:
        tangshi_ids: Flat sequence of integer character ids.
        num_chars: Length of every window.
    """

    def __init__(self, tangshi_ids, num_chars):
        # Flat corpus of character ids
        self.tangshi_ids = tangshi_ids
        # Length of one training window
        self.num_chars = num_chars
        # Total number of ids in the corpus
        self.word_count = len(self.tangshi_ids)
        # Number of windows of length num_chars that fit into the corpus
        self.number = self.word_count // self.num_chars

    def __len__(self):
        """Return the number of windows."""
        return self.number

    def __getitem__(self, idx):
        """Return the ``(x, y)`` tensor pair for window ``idx``.

        Out-of-range indices are clamped rather than rejected: a negative
        ``idx`` yields the first window, a too-large ``idx`` the last one.
        """
        # Last offset at which both x and the shifted y are full length.
        last_start = max(self.word_count - self.num_chars - 1, 0)

        # Stride by num_chars so the windows tile the whole corpus; using the
        # raw idx as the offset would only ever visit the first
        # (number + num_chars) ids. The clamp also pulls the final window back
        # by one id when the corpus length is an exact multiple of num_chars.
        start = min(max(idx * self.num_chars, 0), last_start)

        x = self.tangshi_ids[start: start + self.num_chars]
        y = self.tangshi_ids[start + 1: start + 1 + self.num_chars]

        return torch.tensor(x), torch.tensor(y)


def __test_Dataset():
    """Smoke test: print the first (input, target) pair for windows of 8 ids."""
    first_sample = TangShiDataset(tangshi_ids, 8)[0]
    inputs, targets = first_sample

    print(inputs, targets)


if __name__ == '__main__':
    # Run the dataset smoke test when this file is executed directly.
    __test_Dataset()
复制代码
TangShiModel.py:唐诗的模型
python 复制代码
import torch
import torch.nn as nn
from Dataset_Dataloader import *
import torch.nn.functional as F


class TangShiRNN(nn.Module):
    """Character-level language model: embedding -> RNN -> linear projection.

    Args:
        vocab_size: Number of distinct character ids.
        embedding_dim: Size of each character embedding (default 128).
        hidden_size: Size of the RNN hidden state (default 128).
        num_layers: Number of stacked RNN layers (default 1).

    The defaults reproduce the previously hard-coded layer sizes, so
    checkpoints saved from the earlier definition still load.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_size=128, num_layers=1):
        super().__init__()
        # Remembered so init_hidden() can build a matching state tensor.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Embedding layer
        self.ebd = nn.Embedding(vocab_size, embedding_dim)
        # Recurrent layer (sequence-first: nn.RNN is built without batch_first)
        self.rnn = nn.RNN(embedding_dim, hidden_size, num_layers)
        # Output projection onto the vocabulary
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, inputs, hidden):
        """Run one forward pass.

        Args:
            inputs: Integer id tensor laid out batch-first, ``(batch, seq)``;
                it is transposed to sequence-first before entering the RNN.
            hidden: Initial hidden state, ``(num_layers, batch, hidden_size)``.

        Returns:
            ``(output, hidden)``: vocabulary scores per position in
            sequence-first layout, and the final hidden state. Size-1
            dimensions of the RNN output are squeezed away before the
            projection, so a 1x1 input yields a 1-D score vector.
        """
        embed = self.ebd(inputs)

        # Regularisation. training=self.training is required: the functional
        # form defaults to training=True and would otherwise keep dropping
        # units after model.eval().
        embed = F.dropout(embed, p=0.2, training=self.training)

        output, hidden = self.rnn(embed.transpose(0, 1), hidden)

        # Regularisation on the RNN output. The result must be bound back to
        # `output`, otherwise it never reaches the projection layer.
        output = F.dropout(output, p=0.2, training=self.training)

        output = self.out(output.squeeze())

        return output, hidden

    def init_hidden(self, batch_size=64):
        """Return an all-zero hidden state for ``batch_size`` sequences.

        The default of 64 keeps zero-argument calls working as before. The
        tensor is created on the CPU; callers move it to their device.
        """
        return torch.zeros(self.num_layers, batch_size, self.hidden_size)

main.py:

python 复制代码
import time

import torch

from Dataset_Dataloader import *
from TangShiModel import *
import torch.optim as optim
from tqdm import tqdm

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train():
    """Train the poem model and save one checkpoint per epoch.

    Uses windows of 128 ids, batches of 64, Adam with lr=1e-3 and 100 epochs.
    After every epoch the average per-character loss, the per-character
    accuracy and the elapsed time are printed, and the model weights are
    written to ``./modules/tangshi_module_<epoch>.bin``.

    Raises:
        ValueError: If the dataset is too small to form a single full batch.
    """
    import os

    dataset = TangShiDataset(tangshi_ids, 128)
    epochs = 100
    model = TangShiRNN(word2idx_count).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # One loader serves every epoch: with shuffle=True a fresh permutation is
    # drawn each time it is iterated. The batch size must stay 64 because
    # model.init_hidden() returns a state sized for 64 sequences, and
    # drop_last=True guarantees no smaller trailing batch.
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
    if len(dataloader) == 0:
        raise ValueError("dataset is too small to form a single batch of 64 windows")

    # Create the checkpoint directory up front so the first save cannot fail
    # after a whole epoch of work.
    os.makedirs("./modules", exist_ok=True)

    model.train()

    for idx in range(epochs):
        start_time = time.time()
        total_loss = 0.0
        total_tokens = 0
        total_correct = 0

        for x, y in tqdm(dataloader):
            x = x.to(device)
            y = y.to(device)
            # Fresh hidden state for every batch
            hidden = model.init_hidden().to(device)
            # Forward pass
            output, hidden = model(x, hidden)
            # Loss: permute to (batch, vocab, seq), the layout CrossEntropyLoss
            # expects next to (batch, seq) targets.
            loss = criterion(output.permute(1, 2, 0), y)
            # Reset gradients
            optimizer.zero_grad()
            # Backward pass
            loss.backward()
            # Parameter update
            optimizer.step()

            # The criterion returns the mean over all characters in the batch,
            # so weight it by the character count to get a true per-character
            # average over the epoch.
            num_tokens = y.numel()
            total_loss += loss.item() * num_tokens
            total_tokens += num_tokens
            predictions = torch.argmax(output.detach().permute(1, 0, 2), dim=-1)
            total_correct += (predictions == y).sum().item()

        print("epoch : %d average_loss : %.3f average_correct : %.3f use_time : %ds" %
              (idx + 1, total_loss / total_tokens, total_correct / total_tokens, time.time() - start_time))

        torch.save(model.state_dict(), f"./modules/tangshi_module_{idx + 1}.bin")


if __name__ == '__main__':
    # Start training when this file is executed directly.
    train()

predict.py

python 复制代码
import torch
import torch.nn as nn
from Dataset_Dataloader import *
from TangShiModel import *

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): predict() as shown here does not reference this name; it loads
# the checkpoint with map_location set to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def predict():
    """Interactively generate a poem from a user-supplied first character.

    Loads the epoch-100 checkpoint on the CPU, primes the hidden state with
    the start marker ``"S"``, then greedily feeds each predicted character
    back in until the end marker ``"E"`` is predicted or the length cap is
    reached. The growing character list is printed after every step, so the
    last generated character is shown too.

    Raises:
        KeyError: If the entered text is not a single known vocabulary entry.
    """
    model = TangShiRNN(word2idx_count)
    model.load_state_dict(torch.load("./modules/tangshi_module_100.bin", map_location=torch.device('cpu')))

    model.eval()

    hidden = torch.zeros(1, 1, 128)

    start_word = input("输入第一个字:")
    # Validate before generating anything, so a bad input fails with a
    # readable message instead of a bare lookup error in the middle of the loop.
    if start_word not in word2idx:
        raise KeyError(f"start character {start_word!r} is not in the vocabulary")

    # Greedy decoding can cycle without ever predicting "E"; cap the length so
    # the loop always terminates.
    max_chars = 512

    tangshi_strs = ["S"]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Prime the hidden state with the start marker; its scores are unused.
        _, hidden = model(torch.tensor([[word2idx["S"]]], dtype=torch.long), hidden)
        print(tangshi_strs)

        for _ in range(max_chars):
            tangshi_strs.append(start_word)
            # Print right after appending so the final character is shown even
            # when the loop exits on the next line's end-marker check.
            print(tangshi_strs)

            outputs, hidden = model(torch.tensor([[word2idx[start_word]]], dtype=torch.long), hidden)
            top_i = torch.argmax(outputs, dim=-1)

            if top_i.item() == word2idx["E"]:
                break

            start_word = idx2word[top_i.item()]


if __name__ == '__main__':
    # Start interactive generation when this file is executed directly.
    predict()

完整代码如下:

https://github.com/STZZ-1992/tangshi-generator.git

相关推荐
好家伙VCC3 分钟前
**神经编码新视角:用Python实现生物启发的神经信号压缩与解码算法**在人工智能飞速发展的今天
java·人工智能·python·算法
Navicat中国3 分钟前
如何使用 Ollama 配置 AI 助手 | Navicat 教程
数据库·人工智能·ai·navicat·ollama
@小匠4 小时前
Read Frog:一款开源的 AI 驱动浏览器语言学习扩展
人工智能·学习
网教盟人才服务平台7 小时前
“方班预备班盾立方人才培养计划”正式启动!
大数据·人工智能
芯智工坊7 小时前
第15章 Mosquitto生产环境部署实践
人工智能·mqtt·开源
菜菜艾7 小时前
基于llama.cpp部署私有大模型
linux·运维·服务器·人工智能·ai·云计算·ai编程
TDengine (老段)8 小时前
TDengine IDMP 可视化 —— 分享
大数据·数据库·人工智能·时序数据库·tdengine·涛思数据·时序数据
小真zzz8 小时前
搜极星:第三方多平台中立GEO洞察专家全面解析
人工智能·搜索引擎·seo·geo·中立·第三方平台
GreenTea8 小时前
从 Claw-Code 看 AI 驱动的大型项目开发:2 人 + 10 个自治 Agent 如何产出 48K 行 Rust 代码
前端·人工智能·后端
火山引擎开发者社区9 小时前
秒级创建实例,火山引擎 Milvus Serverless 让 AI Agent 开发更快更省
人工智能