基于 双向RNN网络 的中文文本预测模型

目录

一、项目背景与数据准备

[1.1 数据来源与结构](#1.1 数据来源与结构)

[1.2 环境依赖](#1.2 环境依赖)

二、数据预处理

[2.1 文本提取与分割](#2.1 文本提取与分割)

[2.2 构建中文分词器](#2.2 构建中文分词器)

[2.3 构建训练数据](#2.3 构建训练数据)

[四、搭建双向 RNN 模型](#四、搭建双向 RNN 模型)

五、模型训练

[5.1 训练配置](#5.1 训练配置)

[5.2 训练与验证流程](#5.2 训练与验证流程)

[5.3 训练结果](#5.3 训练结果)

六、文本预测

七、完整代码如下


循环神经网络(RNN)是一类专为处理序列数据设计的递归神经网络,核心优势是通过循环结构实现记忆能力,能利用历史信息辅助当前数据的处理,在自然语言处理、语音识别等领域曾广泛应用,以下是其详细介绍及缺点说明:

  1. 核心介绍
    • 核心特性:与传统神经网络不同,RNN 的隐藏层会将上一时刻的隐藏状态反馈到当前时刻的计算中,相当于拥有 "记忆"。同时它还具备权重共享特性,循环节点在所有时间步使用相同权重系数,既能减少参数数量,又能实现对序列中前期信息的编码传递,让模型可处理随时间变化的特征。
    • 基本结构:基础结构包含输入层、隐藏层和输出层。输入层负责接收序列中的单个元素数据;隐藏层是核心,其输入结合当前输入和上一时刻隐藏状态,通过激活函数处理后生成当前隐藏状态;输出层接收隐藏层结果并转化为最终输出,比如文本分类结果、语音识别文字等。
    • **主要应用:**早期在语言建模、机器翻译、语音识别等领域表现突出。例如在情感分析任务中,可通过分析句子中词语的先后顺序及关联,判断文本的情感倾向;也能用于简单的时间序列预测,如短期气温变化、商品短期销量预估等。
  2. 主要缺点
    • **梯度消失与梯度爆炸:**这是 RNN 最核心的问题。训练时采用随时间反向传播算法,梯度需沿时间步反向传递。由于梯度计算包含多次权重连乘,若权重绝对值小于 1,多次相乘后梯度会逐渐衰减至趋近于 0,导致早期时间步的参数难以更新,无法学习长期依赖;若权重绝对值大于 1,梯度会不断增大,造成模型训练不稳定,甚至无法收敛。
    • 难以处理长期依赖:受梯度消失问题的直接影响,RNN 无法有效记住序列中距离较远的关键信息。比如处理一篇长文章时,它很难将文末的词语与文章开头的核心概念关联起来,这使得其在长文本分析、长周期时间序列预测等任务中表现极差。

在自然语言处理(NLP)领域,文本预测是一项基础且重要的任务,它广泛应用于输入法联想、智能写作辅助等场景。本文将详细介绍如何基于 PyTorch 框架,利用双向 RNN 构建一个中文文本预测模型,从数据预处理到模型训练、预测全流程进行拆解说明。

一、项目背景与数据准备

1.1 数据来源与结构

本次实验使用的数据集是包含多轮对话的 JSONL 格式文件,每条数据包含topic(话题)、user1user2(对话双方)和dialog(对话内容)四个字段,总计 2476 行对话数据。我们的核心目标是从对话文本中提取语句,训练模型实现 "给定前 5 个分词,预测第 6 个分词" 的文本预测任务。

1.2 环境依赖

实验所需核心库如下:

python 复制代码
import torch
import pandas as pd
from typing import List
from sklearn import model_selection
import jieba
from tqdm import tqdm
import json
from torch.utils.data import Dataset, DataLoader
from torch import nn,optim
from torch.utils.tensorboard import SummaryWriter

二、数据预处理

2.1 文本提取与分割

首先读取数据集,并从dialog字段中提取纯文本内容(去除user1:/user2:前缀):

python 复制代码
# 读取数据
data=pd.read_json('data/synthesized_.jsonl',lines=True,orient='records')

# 提取对话文本
sentence_list=[]
for row in data['dialog']:
    for item in row:
        item=item.split(':')[1]  # 去除说话人前缀
        sentence_list.append(item)

# 划分训练集和测试集
train_list,test_list=model_selection.train_test_split(sentence_list,test_size=0.2)

数据集如图

2.2 构建中文分词器

由于中文文本无天然分隔符,我们基于 jieba 分词构建自定义 Tokenizer,实现 "分词 - 编码 - 词表构建" 功能:

python 复制代码
class JieBaTokenizer:
    unk_index=1  # 未知词索引
    def __init__(self,vocab_list):
        self.vocab_list=vocab_list
        self.vocab_size = len(vocab_list)
        self.world2index={value:index for index,value in enumerate(vocab_list)}
        self.index2world={index:value for index,value in enumerate(vocab_list)}

    @staticmethod
    def tokenize(text:str)->List[str]:
        return jieba.lcut(text)  # jieba精准分词

    def encode(self,text:str)->List[int]:
        # 将文本转换为索引序列
        tokens=self.tokenize(text)
        tokens_index=[self.world2index.get(token,self.unk_index) for token in tokens]
        return tokens_index

    @classmethod
    def build_vocab(cls,sentences:List[str],unk_token:str='<unknown>',vocab_path:str='./vocab.json'):
        # 从训练集构建词表
        vocab_set=set()
        for sentence in tqdm(sentences,desc='构建词表:'):
            vocab_set.update(jieba.lcut(sentence))
        vocab_list = [unk_token] + sorted(list(vocab_set))
        vocab_dict={index:value for index,value in enumerate(vocab_list)}
        with open(vocab_path,'w',encoding='utf-8') as f:
            json.dump(vocab_dict,f,ensure_ascii=False,indent=2)

    @classmethod        
    def read_vocab(cls,vocab_path:str='./vocab.json'):
        # 加载词表
        with open(vocab_path,'r',encoding='utf-8') as f:
            json_dict=json.load(f)
            sentences=[value for key,value in json_dict.items()] 
        return cls(sentences)

# 构建并加载词表
JieBaTokenizer.build_vocab(sentences=train_list,unk_token='<unknown>',vocab_path='./vocab.json')
tokenizer=JieBaTokenizer.read_vocab(vocab_path='./vocab.json')

2.3 构建训练数据

将文本转换为 "输入序列(5 个分词索引)- 目标(第 6 个分词索引)" 的格式,并保存为 JSONL 文件:

python 复制代码
def build_dataset(dataset:list,save_path:str):
    # 编码文本为索引
    dataset_id=[tokenizer.encode(item) for item in tqdm(dataset,desc='构建索引')]
    dataset_list=[]
    # 构造输入-目标对
    for item in tqdm(dataset_id,desc='构建数据列表'):
        for i in range(len(item)-5):
            input=item[i:i+5]
            target=item[i+5]
            dataset_list.append({'input':input,'target':target})
    # 保存数据
    with open(save_path, 'w', encoding='utf-8') as f:
        for line in tqdm(dataset_list, desc='保存文件'):
            json.dump(line, f, ensure_ascii=False)
            f.write('\n')

# 生成训练/测试数据
build_dataset(train_list,'data/train_dataset.jsonl')
build_dataset(test_list,'data/test_dataset.jsonl')

四、搭建双向 RNN 模型

使用 Embedding 层将索引转换为向量,结合双向 RNN 提取上下文特征,最后通过全连接层输出预测结果:

python 复制代码
class Network(nn.Module):
    def __init__(self,vocab_size):
        super(Network,self).__init__()
        # 词嵌入层
        self.embeding=nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=128
        )
        # 双向RNN层
        self.rnn=nn.RNN(
            input_size=128,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )
        # 全连接层(双向RNN输出维度=256*2)
        self.linear = nn.Linear(in_features=256*2,out_features=vocab_size)
        
    def forward(self,x):
        embeding=self.embeding(x)  # [batch_size, 5, 128]
        output,_hn=self.rnn(embeding)  # [batch_size, 5, 512]
        return self.linear(output[:,-1,:])  # 取最后一个时间步输出 [batch_size, vocab_size]

# 初始化模型
device='cuda' if torch.cuda.is_available() else 'cpu'
network=Network(vocab_size=tokenizer.vocab_size).to(device)

五、模型训练

5.1 训练配置

设置训练轮数、损失函数、优化器,并使用 TensorBoard 记录训练过程:

python 复制代码
writer=SummaryWriter(log_dir='./logs')  # TensorBoard日志
epochs=3  # 训练轮数
lossfn=nn.CrossEntropyLoss()  # 交叉熵损失(适配分类任务)
lr=1e-3  # 学习率
optimizer=optim.Adam(network.parameters(),lr=lr)  # Adam优化器
best_loss=float('inf')  # 最优验证损失(用于保存最佳模型)

5.2 训练与验证流程

python 复制代码
for epoch in range(epochs):
    print(f'==========第{epoch+1}轮===========')
    
    # 训练阶段
    network.train()
    train_total_loss = 0.0
    train_correct = 0
    train_total = 0
    train_pbar = tqdm(train_dataloader, desc='训练')
    for index,(batch_x, batch_y) in enumerate(train_pbar):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        pred_y=network(batch_x)
        loss=lossfn(pred_y,batch_y)
        loss.backward()
        optimizer.step()
        # 计算训练指标
        train_total_loss += loss.item()
        batch_avg_loss = train_total_loss / (index + 1)
        pred_idx = torch.argmax(pred_y, dim=1) 
        batch_correct = (pred_idx == batch_y).sum().item()
        train_correct += batch_correct
        train_total += batch_y.size(0) 
        train_acc = train_correct / train_total 
        train_pbar.postfix = ({
            "平均损失": f"{batch_avg_loss:.4f}",
            "准确率": f"{train_acc:.4f}" 
        })
    
    # 验证阶段
    network.eval()
    total_test_loss = 0.0
    test_correct = 0
    test_total = 0
    test_pbar=tqdm(test_dataloader,desc='验证')
    with torch.no_grad():  # 关闭梯度计算
        for index,(batch_x, batch_y) in enumerate(test_pbar):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            pred_y=network(batch_x)
            loss=lossfn(pred_y,batch_y)
            total_test_loss += loss.item()
            # 计算验证指标
            batch_avg_loss = total_test_loss / (index + 1)
            pred_idx = torch.argmax(pred_y, dim=1) 
            batch_correct = (pred_idx == batch_y).sum().item()
            test_correct += batch_correct
            test_total += batch_y.size(0)
            test_acc = test_correct / test_total
            test_pbar.postfix = ({
            "平均损失": f"{batch_avg_loss:.4f}",
            "准确率": f"{test_acc:.4f}" 
        })
    
    # 记录训练/验证指标
    train_avg_loss=train_total_loss/len(train_dataloader)
    train_avg_acc = train_correct / train_total
    test_avg_loss=total_test_loss/len(test_dataloader)
    test_avg_acc = test_correct / test_total
    print(f'训练平均损失为 {train_avg_loss:.4f},训练平均准确率为 {train_avg_acc:.4f},验证平均损失为 {test_avg_loss:.4f},验证平均准确率为{test_avg_acc:.4f}')
    writer.add_scalar('loss/train', train_avg_loss, epoch)
    writer.add_scalar('loss/val', test_avg_loss, epoch)
    writer.add_scalar('acc/train', train_avg_acc, epoch)
    writer.add_scalar('acc/val', test_avg_acc, epoch)
    
    # 保存最优模型
    if test_avg_loss < best_loss:
        best_loss = test_avg_loss
        torch.save(network,'best_model.pt')

5.3 训练结果

经过 3 轮训练,模型表现如下:

  • 训练平均损失从 5.2361 降至 4.8160,训练准确率从 19.03% 提升至 21.92%;
  • 验证平均损失从 5.0281 降至 4.9753,验证准确率从 21.31% 提升至 22.34%;
  • 模型在验证集上的损失持续降低,无明显过拟合现象。

由于我这里是取的argmax的最高的一个,导致准确率不高,可以取TopK的前5个来重新计算准确率,这样比较合理。

六、文本预测

加载最优模型,实现 "输入文本→预测 Top5 候选词→用户选择→拼接文本" 的交互式预测:

python 复制代码
def predict(text:str):
    # 文本编码
    input_id=tokenizer.encode(text)
    input_tensor=torch.tensor(input_id,dtype=torch.long).unsqueeze(0).to(device)
    # 加载模型
    model=torch.load('best_model.pt').to(device)
    model.eval()
    # 预测Top5
    with torch.no_grad():
        pred_y=model(input_tensor)
        top5_index=torch.topk(pred_y,k=5).indices.squeeze()
        top5_world=[tokenizer.index2world.get(id) for id in top5_index.tolist()]
        return top5_world

# 交互式预测
input_text=input('请输入预测的词:')
while True:
    top5_world=predict(input_text)
    word_dict={index:value for index,value in enumerate(top5_world)}
    choose_world=input(f'请选择预测的词:{word_dict}')
    if choose_world=='q':
        print('已退出!')
        break
    input_text+=top5_world[int(choose_world)]
    print(f'输入历史为:{input_text}')

七、完整代码如下

python 复制代码
import torch
import pandas as pd
from typing import List
from sklearn import model_selection
import jieba
from tqdm import tqdm
import json
from torch.utils.data import Dataset, DataLoader
import torch
from torch import nn,optim
from torch.utils.tensorboard import SummaryWriter


data=pd.read_json('data/synthesized_.jsonl',lines=True,orient='records') # 读取文件

# 划分训练集和测试集
train_list,test_list=model_selection.train_test_split(sentence_list,test_size=0.2)


class JieBaTokenizer:  # 构建tokenizer
    unk_index=1
    def __init__(self,vocab_list):
        self.vocab_list=vocab_list
        self.vocab_size = len(vocab_list)
        self.world2index={value:index for index,value in enumerate(vocab_list)}
        self.index2world={index:value for index,value in enumerate(vocab_list)}

    @staticmethod
    def tokenize(text:str)->List[str]:
        return jieba.lcut(text)

    def encode(self,text:str)->List[int]:
        tokens=self.tokenize(text)
        tokens_index=[self.world2index.get(token,self.unk_index) for token in tokens]
        return tokens_index

    @classmethod
    def build_vocab(
            cls,sentences:List[str],
            unk_token:str='<unknown>',
            vocab_path:str='./vocab.json'
    ):
        vocab_set=set()
        for sentence in tqdm(sentences,desc='构建词表:'):
            vocab_set.update(jieba.lcut(sentence))
        vocab_list = [unk_token] + sorted(list(vocab_set))
        vocab_dict={index:value for index,value in enumerate(vocab_list)}
        vocab_dict[cls.unk_index]=unk_token
        with open(vocab_path,'w',encoding='utf-8') as f:
            json.dump(vocab_dict,f,ensure_ascii=False,indent=2)

    @classmethod        
    def read_vocab(cls,vocab_path:str='./vocab.json'):
        with open(vocab_path,'r',encoding='utf-8') as f:
            json_dict=json.load(f)
            sentences=[value for key,value in json_dict.items()] 
        return cls(sentences)


JieBaTokenizer.build_vocab(sentences=train_list,unk_token='<unknown>',vocab_path='./vocab.json')
tokenizer=JieBaTokenizer.read_vocab(vocab_path='./vocab.json')


# 构建数据
def build_dataset(dataset:list,save_path:str):
    dataset_id=[tokenizer.encode(item) for item in tqdm(dataset,desc='构建索引')]
    dataset_list=[]
    for item in tqdm(dataset_id,desc='构建数据列表'):
        for i in range(len(item)-5):
            input=item[i:i+5]
            target=item[i+5]
            dataset_list.append({'input':input,'target':target})
    with open(save_path, 'w', encoding='utf-8') as f:
        for line in tqdm(dataset_list, desc='保存文件'):
            json.dump(line, f, ensure_ascii=False)
            f.write('\n')


build_dataset(train_list,'data/train_dataset.jsonl')
build_dataset(test_list,'data/test_dataset.jsonl')


# 构建Dataloader
class SearchDataset(Dataset):
    def __init__(self,path):
        self.data=pd.read_json(path,lines=True,orient='records').to_dict(orient='records')
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,index):
        input_rensor=torch.tensor(self.data[index]['input'],dtype=torch.long)
        target_rensor=torch.tensor(self.data[index]['target'],dtype=torch.long)
        return input_rensor,target_rensor


train_dataset=SearchDataset('data/train_dataset.jsonl')
test_dataset=SearchDataset('data/test_dataset.jsonl')
train_dataloader=DataLoader(dataset=train_dataset,batch_size=16,drop_last=True)
test_dataloader=DataLoader(dataset=test_dataset,batch_size=16,drop_last=True)


# 构建双向RNN模型
class Network(nn.Module):
    def __init__(self,vocab_size):
        super(Network,self).__init__()
        self.embeding=nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=128
        )
        self.rnn=nn.RNN(
            input_size=128,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )
        self.linear = nn.Linear(in_features=256*2,out_features=vocab_size)
        
    def forward(self,x):
        embeding=self.embeding(x)
        output,_hn=self.rnn(embeding)
        return self.linear(output[:,-1,:])


device='cuda' if torch.cuda.is_available() else 'cpu'
network=Network(vocab_size=tokenizer.vocab_size).to(device)
writer=SummaryWriter(log_dir='./logs')


epochs=3
lossfn=nn.CrossEntropyLoss()
lr=1e-3
optimizer=optim.Adam(network.parameters(),lr=lr)


# 训练并保存模型
best_loss=float('inf')
for epoch in range(epochs):
    print(f'==========第{epoch+1}轮===========')
    
    network.train()
    train_total_loss = 0.0
    train_correct = 0
    train_total = 0
    train_pbar = tqdm(train_dataloader, desc='训练')
    for index,(batch_x, batch_y) in enumerate(train_pbar):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        pred_y=network(batch_x)
        loss=lossfn(pred_y,batch_y)
        loss.backward()
        optimizer.step()
        # 指标
        train_total_loss += loss.item()
        batch_avg_loss = train_total_loss / (index + 1)
        pred_idx = torch.argmax(pred_y, dim=1) 
        batch_correct = (pred_idx == batch_y).sum().item()
        train_correct += batch_correct
        train_total += batch_y.size(0) 
        train_acc = train_correct / train_total 
        train_pbar.postfix = ({
            "平均损失": f"{batch_avg_loss:.4f}",
            "准确率": f"{train_acc:.4f}" 
        })
    
    network.eval()
    total_test_loss = 0.0
    test_correct = 0
    test_total = 0
    test_pbar=tqdm(test_dataloader,desc='验证')
    with torch.no_grad():
        for index,(batch_x, batch_y) in enumerate(test_pbar):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            pred_y=network(batch_x)
            loss=lossfn(pred_y,batch_y)
            total_test_loss += loss.item()
            # 指标
            batch_avg_loss = total_test_loss / (index + 1)
            pred_idx = torch.argmax(pred_y, dim=1) 
            batch_correct = (pred_idx == batch_y).sum().item()
            test_correct += batch_correct
            test_total += batch_y.size(0)
            test_acc = test_correct / test_total
            test_pbar.postfix = ({
            "平均损失": f"{batch_avg_loss:.4f}",
            "准确率": f"{test_acc:.4f}" 
        })
        
    train_avg_loss=train_total_loss/len(train_dataloader)
    train_avg_acc = train_correct / train_total
    test_avg_loss=total_test_loss/len(test_dataloader)
    test_avg_acc = test_correct / test_total
    print(f'训练平均损失为 {train_avg_loss:.4f},训练平均准确率为 {train_avg_acc:.4f},验证平均损失为 {test_avg_loss:.4f},验证平均准确率为{test_avg_acc:.4f}')
    writer.add_scalar('loss/train', train_avg_loss, epoch)
    writer.add_scalar('loss/val', test_avg_loss, epoch)
    writer.add_scalar('acc/train', train_avg_acc, epoch)
    writer.add_scalar('acc/val', test_avg_acc, epoch)
    
    if test_avg_loss < best_loss:
        best_loss = test_avg_loss
        torch.save(network,'best_model.pt')
        

# 测试模型
def predict(text:str): # 预测
    input_id=tokenizer.encode(text)
    input_tensor=torch.tensor(input_id,dtype=torch.long).unsqueeze(0).to(device)
    model=torch.load('best_model.pt').to(device)
    model.eval()
    with torch.no_grad():
        pred_y=model(input_tensor)
        top5_index=torch.topk(pred_y,k=5).indices.squeeze()
        top5_world=[tokenizer.index2world.get(id) for id in top5_index.tolist()]
        return top5_world


input_text=input('请输入预测的词:')
while True:
    top5_world=predict(input_text)
    word_dict={index:value for index,value in enumerate(top5_world)}
    choose_world=input(f'请选择预测的词:{word_dict}')
    if choose_world=='q':
        print('已退出!')
        break
    input_text+=top5_world[int(choose_world)]
    print(f'输入历史为:{input_text}')
相关推荐
橘子师兄2 小时前
C++AI大模型接入SDK—deepseek接入封装
c++·人工智能·chatgpt
gdutxiaoxu2 小时前
browser-use - 让AI Agent真正“会“用浏览器
人工智能·ai agent
Fairy要carry2 小时前
面试-OnlyDecoder用于嵌入模型
人工智能
陈天伟教授2 小时前
人工智能应用-机器视觉:AI 鉴伪 03.换脸伪造技术
人工智能·神经网络·生成对抗网络
Yeats_Liao2 小时前
长文本优化:KV Cache机制与显存占用平衡策略
人工智能·深度学习·学习·机器学习·华为
石逸凡2 小时前
基于实体地图的金融大模型落地方法探索与前景展望
人工智能·金融
cooldream20092 小时前
辩核AI具身辩论数字人训练系统:技术架构与功能体系全解析
人工智能·架构·具身数字人
星爷AG I2 小时前
9-13 视知觉组织(AGI基础理论)
人工智能·agi
小草cys2 小时前
qwen3-max识别电气主接线图,不需要yolo辅助
人工智能·yolo·目标跟踪·多模态大模型·电气主接线图