python-pytorch实现lstm模型预测文本输出0.1.00

python-pytorch实现lstm模型预测文本输出0.1.00

有问题还需要完善

数据

一篇新闻:https://news.sina.com.cn/c/2024-04-12/doc-inarqiev0222543.shtml

参考

https://blog.csdn.net/qq_19530977/article/details/120936391

python 复制代码
# https://blog.csdn.net/qq_19530977/article/details/120936391

效果

python 复制代码
"""
布林肯国务卿
布林肯国务卿同王毅
布林肯国务卿同王毅主任
布林肯国务卿同王毅主任以及
布林肯国务卿同王毅主任以及其他
布林肯国务卿同王毅主任以及其他国家
布林肯国务卿同王毅主任以及其他国家敦促
布林肯国务卿同王毅主任以及其他国家敦促伊朗
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿同王毅
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿同王毅主任
"""

导入包

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
from torch.autograd import Variable
import jieba

分词到数组

复制文章到txt文档

python 复制代码
allarray=[]
with open("./howtousercbow/data/news.txt",encoding="utf-8") as afterjieba:
    lines=afterjieba.readlines()
    print(lines)
    for line in lines:
        result=list(jieba.cut(line,False))
        for r in result:
            allarray.append(r.replace("\n",""))

allarray,len(allarray)
    

准备数数据

python 复制代码
word2index={one:i for i,one in enumerate(allarray)}
index2word={i:one for i,one in enumerate(allarray)}
word2index[" "]=len(allarray)-1
index2word[len(allarray)-1]=" "
word2index[" "]

查看频次

python 复制代码
from collections import Counter
Counter(allarray)

获取vacab

python 复制代码
vocab_size = len(allarray)
vocab_size

生成输入数据

python 复制代码
# 生成输入数据
batch_x = []
batch_y = []
window=1
seq_length=vocab_size
for i in range(seq_length - window + 1):
    x = word2index[allarray[i]]
    if i + window >= seq_length:
        y = word2index[" "]
    else:
        y = word2index[allarray[i + 1]]
    batch_x.append([x])
    batch_y.append(y)

# print(batch_x)
# print("=======")
# print(batch_y)
# print(45/0)


# 训练数据
batch_x, batch_y = Variable(torch.LongTensor(batch_x)), Variable(torch.LongTensor(batch_y))
 
# 参数
# vocab_size = len(letters)
embedding_size = 100
n_hidden = 32
batch_size = 10
num_classes = vocab_size
 
dataset = Data.TensorDataset(batch_x, batch_y)
loader = Data.DataLoader(dataset, batch_size, shuffle=True)
 
# 建立模型
class BiLSTM(nn.Module):
    def __init__(self):
        super(BiLSTM, self).__init__()
        self.word_vec = nn.Embedding(vocab_size, embedding_size)
        # bidirectional双向LSTM
        self.bilstm = nn.LSTM(embedding_size, n_hidden, 1, bidirectional=True)
        self.lstm = nn.LSTM(2 * n_hidden, 2 * n_hidden, 1, bidirectional=False)
        self.fc = nn.Linear(n_hidden * 2, num_classes)
 
    def forward(self, input):
        embedding_input = self.word_vec(input)
#         print("embedding_input",embedding_input,embedding_input.size())
        # 调换第一维和第二维度
        embedding_input = embedding_input.permute(1, 0, 2)
        bilstm_output, (h_n1, c_n1) = self.bilstm(embedding_input)
        lstm_output, (h_n2, c_n2)= self.lstm(bilstm_output)
        fc_out = self.fc(lstm_output[-1])
        return fc_out
 
model = BiLSTM()

训练

python 复制代码
print(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
 
# 训练
for epoch in range(300):
    cost = 0
    for input_batch, target_batch in loader:
        pred = model(input_batch)
#         print("pred",pred)
#         print("target_batch",target_batch)
        loss = criterion(pred, target_batch)
        cost += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Epoch: %d,  loss: %.5f " % (epoch, cost))

测试

python 复制代码
def test(str):
    test_text =str
    test_batch = [word2index[str]]
#     print(test_batch)
    test_batch = torch.LongTensor([test_batch])
#     print("test_batch",test_batch)
#     print(test_batch)
    out = model(test_batch)
    predict = torch.max(out, 1)[1].item()
#     print(test_text,"后一个字母为:", index2word[predict])
    return index2word[predict]

连续预测

python 复制代码
import time
s="布林肯"
while True:
    fenci=jieba.cut(s,False)
    fenciList=list(fenci)
    s=s+test(fenciList[-1:][0])
    
    time.sleep(1)
    print(s)
        
相关推荐
阿正的梦工坊2 小时前
深入理解 PyTorch 中的 unsqueeze 操作
人工智能·pytorch·python
FreakStudio3 小时前
硬件版【Cursor】?aily blockly IDE尝鲜封神,实战硬伤尽显
python·单片机·嵌入式·大学生·面向对象·并行计算·电子diy·电子计算机
测试员周周5 小时前
【Appium 系列】第06节-页面对象实现 — LoginPage 实战
开发语言·前端·人工智能·python·功能测试·appium·测试用例
2301_783848655 小时前
优化文本分类中堆叠模型的网格搜索性能:避免训练卡顿的实战指南
jvm·数据库·python
CLX05056 小时前
如何安装Oracle 12c Cloud Control_OMS服务端组件与Agent部署
jvm·数据库·python
老纪7 小时前
SQL中如何查找特定的空值行:WHERE IS NULL深度解析
jvm·数据库·python
噜噜噜阿鲁~7 小时前
python学习笔记 | 10.0、面向对象编程
笔记·python·学习
weixin199701080167 小时前
[特殊字符] RESTful API 接口规范详解:构建高效、可扩展的 Web 服务(附 Python 源码)
前端·python·restful
2301_781571427 小时前
mysql数据库响应缓慢如何排查_使用EXPLAIN分析执行计划
jvm·数据库·python
彳亍1017 小时前
实现倒计时数字在到达1后自动隐藏(2为最后可见数字),同时继续运行至-1再终止
jvm·数据库·python