【附jupyter源码】使用长短期记忆网络(LSTM)实现一个小说写作AI——以训练《西游记》为例

使用长短期记忆网络(LSTM)实现一个小说写作AI------以训练《西游记》为例

这个项目使用LSTM长短期记忆网络训练了一个字符级的文本生成模型,喂了一整本《西游记》进去,模型就能自己写出一段神魔小说风的文字。

整个项目的源码统一整理成一个jupyter notebook文件。从数据预处理到模型训练、推理生成,每一块代码都有涉及。

项目源码地址:https://github.com/anjuxi/LSTM-Journey-to-the-West

项目概览

  • 任务:字符级语言模型,给定一段起始文字,续写后续内容
  • 语料:《西游记》全文,约74万字符
  • 模型:3层LSTM + 字符Embedding + 线性层,参数量约900万
  • 训练环境:Kaggle T4 GPU
  • 特色
    • 字符级建模,词汇表大小4198,直接学习汉字分布
    • 使用梯度裁剪、学习率自适应衰减(ReduceLROnPlateau)
    • 支持断点续训、混合精度加速
    • 生成时采用温度采样(Temperature Sampling),控制随机性

1. 基础配置与随机种子

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import os
import re
import pickle
from collections import Counter

set_seed函数固定了PyTorch、NumPy和Python内置随机数的种子,并且设置了torch.backends.cudnn.deterministic = True。这样做是为了让实验可复现,尤其在调试生成效果时,同一个种子跑出来的模型效果应该一致。

python 复制代码
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

2. 超参数集中管理

我做项目习惯把可能调整的变量放在最前面,这样不会到处找。这里直接定义成"宏常量":

python 复制代码
# 数据相关
DATA_PATH = "./data/xyj.txt"      # 西游记文本路径
SEQ_LENGTH = 100                  # 每个样本的字符长度
BATCH_SIZE = 1024                 # 批次大小

# 模型结构
EMBED_DIM = 256                   # 嵌入维度
HIDDEN_DIM = 512                  # LSTM隐藏层维度
NUM_LAYERS = 3                    # LSTM层数
DROPOUT = 0.3                    # Dropout概率

# 训练参数
EPOCHS = 50
LEARNING_RATE = 0.001
CLIP_GRAD = 5                     # 梯度裁剪阈值
PRINT_EVERY = 100
SAVE_EVERY = 1                    # 每5轮保存一次模型并输出一段生成示例

# 续训开关
CONTINUE_TRAIN = True
MODEL_SAVE_PATH = "./models/lstm_xiyouji.pth"
VOCAB_SAVE_PATH = "./models/vocab.pkl"

# 生成参数
GENERATE_LENGTH = 500
TEMPERATURE = 0.1                # 写死没用,后面生成时动态指定

os.makedirs("./models", exist_ok=True)

3. 数据预处理

3.1 加载原始文本

python 复制代码
def load_and_preprocess_data(data_path):
    with open(data_path, 'r', encoding='utf-8') as f:
        text = f.read()
    # 把换行、多余空白压缩成单个空格
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    print(f"文本总长度: {len(text)} 字符")
    print(f"唯一字符数: {len(set(text))}")
    return text

《西游记》原文下载后约有74万多字符。re.sub(r'\s+', ' ', text)把所有连续的空白字符(空格、换行、制表)都替换成了一个空格,这样模型就不需要去学习换行符这些没意义的特征。但这样就丢失了段落结构,如果想让模型自动分段,可以保留换行符作为特殊字符。

3.2 构建字符级数据集

字符级建模的核心是把每一个汉字(包括标点、空格)当成一个类别。TextDataset做了这些事:

python 复制代码
class TextDataset(Dataset):
    def __init__(self, text, seq_length):
        self.seq_length = seq_length
        self.text = text
        self.chars = sorted(list(set(text)))
        self.char2idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx2char = {i: ch for i, ch in enumerate(self.chars)}
        self.vocab_size = len(self.chars)
        self.data = [self.char2idx[ch] for ch in text]
    
    def __len__(self):
        return len(self.data) - self.seq_length
    
    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx:idx+self.seq_length], dtype=torch.long)
        y = torch.tensor(self.data[idx+1:idx+self.seq_length+1], dtype=torch.long)
        return x, y
  • 词汇表不分词,直接字符去重排序,共4198个。
  • 整个文本转成索引列表self.data
  • __getitem__返回一对(x, y)y就是x整体后移一个位置的结果。这是语言模型的标准做法:给定前100个字符,预测下一个字符。

示例:假设文本为"孙悟空大闹天宫",x["孙","悟","空"]y["悟","空","大"]

3.3 保存和加载词汇表

python 复制代码
def save_vocab(char2idx, idx2char, save_path):
    vocab = {'char2idx': char2idx, 'idx2char': idx2char}
    with open(save_path, 'wb') as f:
        pickle.dump(vocab, f)
    print(f"词汇表已保存到: {save_path} !")

def load_vocab(load_path):
    with open(load_path, 'rb') as f:
        vocab = pickle.load(f)
    print(f"词汇表已从 {load_path} 加载")
    return vocab['char2idx'], vocab['idx2char']

推理时必须用和训练时完全一样的词汇表,否则索引对不上。所以训练时用save_vocab存一份,后面加载模型时一并载入。

3.4 创建DataLoader

python 复制代码
dataset = TextDataset(text, SEQ_LENGTH)
vocab_size = dataset.vocab_size
char2idx = dataset.char2idx
idx2char = dataset.idx2char
save_vocab(char2idx, idx2char, VOCAB_SAVE_PATH)

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True if torch.cuda.is_available() else False
)

print(f"数据集大小: {len(dataset)} 样本")
print(f"词汇表大小: {vocab_size}")
print(f"批次数量: {len(dataloader)}")

输出:

makefile 复制代码
数据集大小: 734850 样本
词汇表大小: 4198
批次数量: 718

73万多个样本,每个epoch大概718个batch。num_workers=4多进程加载数据,避免了IO成为瓶颈。


4. 模型搭建

模型结构并不复杂,一个Embedding层、一个3层LSTM、一个Dropout、一个线性层:

python 复制代码
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)
  • batch_first=True 让输入张量的维度是 (batch, seq_len, feature),处理起来更直观。
  • LSTM的dropout参数只在num_layers > 1时有效,它是层与层之间的dropout,和后面额外加的self.dropout不冲突:前者是LSTM内部多层传递时的dropout,后者是LSTM输出后、进入全连接前的dropout。
  • 输出层fc把512维的隐藏状态映射到4198维的词汇表空间,不做softmax,因为计算loss时用CrossEntropyLoss会自动处理。

在前向传播时:

python 复制代码
def forward(self, x, hidden=None):
    batch_size = x.size(0)
    embed = self.embedding(x)               # (batch, seq_len, embed_dim)
    if hidden is None:
        lstm_out, hidden = self.lstm(embed)
    else:
        lstm_out, hidden = self.lstm(embed, hidden)
    lstm_out = self.dropout(lstm_out)
    output = self.fc(lstm_out)              # (batch, seq_len, vocab_size)
    return output, hidden

这里支持传入hidden,是因为文本生成时需要一步一步迭代,保持隐藏状态就能避免每次都重新计算整个序列的历史,效率高很多。

初始化隐藏状态的方法:

python 复制代码
def init_hidden(self, batch_size):
    weight = next(self.parameters())
    hidden = (
        weight.new_zeros(self.num_layers, batch_size, self.hidden_dim),
        weight.new_zeros(self.num_layers, batch_size, self.hidden_dim)
    )
    return hidden

这样初始化的张量会放在和模型参数相同的设备上,不用手动.to(device)

模型打印:

ini 复制代码
LSTMModel(
  (embedding): Embedding(4198, 256)
  (lstm): LSTM(256, 512, num_layers=3, batch_first=True, dropout=0.3)
  (dropout): Dropout(p=0.3, inplace=False)
  (fc): Linear(in_features=512, out_features=4198, bias=True)
)

总参数量: 9,007,718
可训练参数量: 9,007,718
模型大小: 34.36 MB

900万参数对RNN来说不小,不过这里用GPU完全能应付。如果用单卡,34MB的模型很小,推理也很快。


5. 训练模块

训练函数是整个代码中最复杂的一块,值得细细拆分。

5.1 断点续训

continue_train=True时,如果检测到已有模型文件,就会加载权重继续训练:

python 复制代码
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
state_dict = checkpoint['model_state_dict']
if isinstance(model, nn.DataParallel):
    model.module.load_state_dict(state_dict)
else:
    model.load_state_dict(state_dict)
start_epoch = checkpoint.get('epoch', 0)

需要注意的是,多GPU训练保存模型时state_dict的key会是module.xxx,加载时如果当前不是DataParallel就要去掉前缀,或者直接按照实例类型加载。这里我直接根据isinstance判断。

5.2 优化器与调度器

python 复制代码
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

我用了ReduceLROnPlateau而不是StepLR,因为语言模型loss下降趋势不稳定,当连续3个epoch的loss不降时就将学习率减半,比固定epoch衰减更灵活。一般前面几个epoch loss下降得很快,后面开始震荡,此时减少学习率能让模型继续慢慢收敛。

5.3 混合精度与梯度裁剪

python 复制代码
scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

在循环内部:

python 复制代码
if scaler is not None:
    with torch.cuda.amp.autocast():
        outputs, _ = model(inputs)
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)
    scaler.step(optimizer)
    scaler.update()
else:
    # 非混合精度路径,同样做梯度裁剪

核心要点:

  • outputs.reshape(-1, vocab_size)(batch, seq_len, vocab_size)展平成(batch*seq_len, vocab_size),targets也展平,这样计算交叉熵损失时就是把每一个位置的预测与真实下一个字符比较。
  • 梯度裁剪必须在scaler.unscale_(optimizer)之后,step之前;对于非混合精度分支,直接clip_grad_norm_后再optimizer.step()
  • 如果不裁剪,损失值偶尔会突然爆炸,一个batch的训练就能把参数冲烂,恢复不了的。

5.4 定期生成示例

每1个epoch,保存模型后,用generate_text生成200个字符看效果:

python 复制代码
sample = generate_text(model, "话说", 200, device, temperature=0.8)
print(sample)

这是一个很好的debug手段。通过看生成的文字是完全是乱码还是有点模样,就能判断模型学习到了什么程度,比光看loss数字直观多了。


6. 文本生成(推理)

generate_text实现了自回归逐个字符生成:

python 复制代码
def generate_text(model, start_text, length, device, temperature=0.1):
    model.eval()
    input_seq = [char2idx.get(ch, 0) for ch in start_text]
    input_tensor = torch.tensor([input_seq], dtype=torch.long).to(device)
    generated = list(start_text)
    hidden = None
    with torch.no_grad():
        if len(input_seq) > 0:
            output, hidden = model(input_tensor, hidden)
        for _ in range(length):
            logits = output[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_char_idx = torch.multinomial(probs, 1).item()
            next_char = idx2char[next_char_idx]
            generated.append(next_char)
            input_tensor = torch.tensor([[next_char_idx]], dtype=torch.long).to(device)
            output, hidden = model(input_tensor, hidden)
    return ''.join(generated)
  • 先对start_text做一次完整的前向传播,拿最后一个时间步的logits。
  • 用温度temperature缩放logits,然后softmax得到概率分布。temperature越小分布越尖锐(倾向于选高概率词,文本更保守),越大分布越平滑(更具多样性,但也更容易出错)。
  • torch.multinomial按照概率采样,而不是直接取argmax,这样每次生成的结果会有变化。
  • 生成下一个字符后,把它作为下一轮的输入,同时hidden继续传递,保持了历史上下文。

一个小细节:如果从空字符串开始生成,input_seq为空,就不会执行第一个if分支,hidden保持None直接进入循环,这时相当于从初始状态随机生成,通常效果不太"语意连贯",最好给个开头词。


7. 模型保存与加载工具

保存函数save_checkpoint把模型权重、优化器状态、epoch和loss一起存入字典:

python 复制代码
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model_state,
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}
torch.save(checkpoint, MODEL_SAVE_PATH)

这样续训时不仅恢复模型参数,还恢复了优化器的动量等状态,无缝接着训练。

推理加载函数load_model_for_inference专门处理了可能的module.前缀:

python 复制代码
new_state_dict = {}
for k, v in state_dict.items():
    name = k.replace('module.', '') if k.startswith('module.') else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)

因为训练时可能用到了DataParallel,推理时往往是单卡,这种清洗可以直接去掉前缀。


8. 推理效果测试

最后一段代码遍历多个提示词和温度参数,生成文本展示:

python 复制代码
test_prompts = ["话说", "悟空", "唐僧", "那妖怪"]
temperatures = [0.5, 0.8, 1.0]
for prompt in test_prompts:
    for temp in temperatures:
        generated = generate_text(inference_model, prompt, GENERATE_LENGTH, device, temperature=temp)
        print(generated)

温度低的时候,模型输出的文本更像原文摘录,甚至会大段复制原著。温度高到1.0时,可能生成一些奇特的新词组合,有"形似神魔"的怪诞感,偶尔语法错乱。实际玩耍时我偏向0.6~0.8,既有创意又不跑偏。


总结:

这个项目虽然规模不大,但覆盖了语言模型训练的很多核心环节:数据处理、模型设计、训练策略(续训、梯度裁剪、学习率调度)、生成采样。跑通一遍之后,对RNN的训练细节会有一个很完整的理解。后面如果你想把"写小说"模型扩展成"对话模型"或"智能续写助手",很多基础架构可以直接复用。

源码我放到了GitHub:https://github.com/anjuxi/LSTM-Journey-to-the-West,欢迎去star和讨论。如有疑问欢迎在评论区交流。


注:文中涉及《西游记》原文仅作技术研究使用,版权归原作者。

相关推荐
YJlio2 小时前
1 1.2 Windows 账户的分类:管理员 / 标准 / 来宾 + 微软账户 vs 本地账户
人工智能·python·microsoft·ai·chatgpt·openai·agent
Luhui Dev2 小时前
如何画圆的切线?几何作图技巧
人工智能·数学·ai
stolentime2 小时前
线段树套?——洛谷P7312 [COCI 2018/2019 #2] Sunčanje题解
c++·算法·图论·洛谷
chaofan9802 小时前
一张照片秒变3D模型!微软Copilot 3D正在颠覆三维创作的游戏规则
人工智能·microsoft·copilot
数字时代全景窗2 小时前
智能体架构进化路线:从Manus、OpenClaw到Evolver——与Palantir本体架构的比较研究
大数据·人工智能·架构·软件工程
wayz112 小时前
Day 12:支持向量机(SVM)原理与实践
算法·机器学习·支持向量机
kcuwu.2 小时前
1950-2024 AI百年跃迁:从图灵测试到ChatGPT
人工智能·chatgpt
JGDT_2 小时前
直播回顾2|底层逻辑重构:AI驱动下的财务工作五大范式转移
大数据·人工智能·系统架构·系统安全·软件工程
knight_9___2 小时前
RAG面试篇8
人工智能·python·面试·agent·rag