python-pytorch实现lstm模型预测文本输出0.1.00

python-pytorch实现lstm模型预测文本输出0.1.00

有问题还需要完善

数据

一篇新闻:https://news.sina.com.cn/c/2024-04-12/doc-inarqiev0222543.shtml

参考

https://blog.csdn.net/qq_19530977/article/details/120936391

python 复制代码
# https://blog.csdn.net/qq_19530977/article/details/120936391

效果

python 复制代码
"""
布林肯国务卿
布林肯国务卿同王毅
布林肯国务卿同王毅主任
布林肯国务卿同王毅主任以及
布林肯国务卿同王毅主任以及其他
布林肯国务卿同王毅主任以及其他国家
布林肯国务卿同王毅主任以及其他国家敦促
布林肯国务卿同王毅主任以及其他国家敦促伊朗
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿同王毅
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿同王毅主任
"""

导入包

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
from torch.autograd import Variable
import jieba

分词到数组

复制文章到txt文档

python 复制代码
allarray=[]
with open("./howtousercbow/data/news.txt",encoding="utf-8") as afterjieba:
    lines=afterjieba.readlines()
    print(lines)
    for line in lines:
        result=list(jieba.cut(line,False))
        for r in result:
            allarray.append(r.replace("\n",""))

allarray,len(allarray)
    

准备数数据

python 复制代码
word2index={one:i for i,one in enumerate(allarray)}
index2word={i:one for i,one in enumerate(allarray)}
word2index[" "]=len(allarray)-1
index2word[len(allarray)-1]=" "
word2index[" "]

查看频次

python 复制代码
from collections import Counter
Counter(allarray)

获取vacab

python 复制代码
vocab_size = len(allarray)
vocab_size

生成输入数据

python 复制代码
# 生成输入数据
batch_x = []
batch_y = []
window=1
seq_length=vocab_size
for i in range(seq_length - window + 1):
    x = word2index[allarray[i]]
    if i + window >= seq_length:
        y = word2index[" "]
    else:
        y = word2index[allarray[i + 1]]
    batch_x.append([x])
    batch_y.append(y)

# print(batch_x)
# print("=======")
# print(batch_y)
# print(45/0)


# 训练数据
batch_x, batch_y = Variable(torch.LongTensor(batch_x)), Variable(torch.LongTensor(batch_y))
 
# 参数
# vocab_size = len(letters)
embedding_size = 100
n_hidden = 32
batch_size = 10
num_classes = vocab_size
 
dataset = Data.TensorDataset(batch_x, batch_y)
loader = Data.DataLoader(dataset, batch_size, shuffle=True)
 
# 建立模型
class BiLSTM(nn.Module):
    def __init__(self):
        super(BiLSTM, self).__init__()
        self.word_vec = nn.Embedding(vocab_size, embedding_size)
        # bidirectional双向LSTM
        self.bilstm = nn.LSTM(embedding_size, n_hidden, 1, bidirectional=True)
        self.lstm = nn.LSTM(2 * n_hidden, 2 * n_hidden, 1, bidirectional=False)
        self.fc = nn.Linear(n_hidden * 2, num_classes)
 
    def forward(self, input):
        embedding_input = self.word_vec(input)
#         print("embedding_input",embedding_input,embedding_input.size())
        # 调换第一维和第二维度
        embedding_input = embedding_input.permute(1, 0, 2)
        bilstm_output, (h_n1, c_n1) = self.bilstm(embedding_input)
        lstm_output, (h_n2, c_n2)= self.lstm(bilstm_output)
        fc_out = self.fc(lstm_output[-1])
        return fc_out
 
model = BiLSTM()

训练

python 复制代码
print(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
 
# 训练
for epoch in range(300):
    cost = 0
    for input_batch, target_batch in loader:
        pred = model(input_batch)
#         print("pred",pred)
#         print("target_batch",target_batch)
        loss = criterion(pred, target_batch)
        cost += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Epoch: %d,  loss: %.5f " % (epoch, cost))

测试

python 复制代码
def test(str):
    test_text =str
    test_batch = [word2index[str]]
#     print(test_batch)
    test_batch = torch.LongTensor([test_batch])
#     print("test_batch",test_batch)
#     print(test_batch)
    out = model(test_batch)
    predict = torch.max(out, 1)[1].item()
#     print(test_text,"后一个字母为:", index2word[predict])
    return index2word[predict]

连续预测

python 复制代码
import time
s="布林肯"
while True:
    fenci=jieba.cut(s,False)
    fenciList=list(fenci)
    s=s+test(fenciList[-1:][0])
    
    time.sleep(1)
    print(s)
        
相关推荐
Sherry Wangs1 天前
显卡算力过高导致PyTorch不兼容的救赎指南
人工智能·pytorch·显卡
小叮当⇔1 天前
PYcharm——获取天气
ide·python·pycharm
霍志杰1 天前
记一次csv和xlsx之间的转换处理
python
测试19981 天前
Jmeter是如何实现接口关联的?
自动化测试·软件测试·python·测试工具·jmeter·职场和发展·接口测试
小蕾Java1 天前
PyCharm 2025:最新使用图文教程!
ide·python·pycharm
java1234_小锋1 天前
TensorFlow2 Python深度学习 - 卷积神经网络(CNN)介绍
python·深度学习·tensorflow·tensorflow2
java1234_小锋1 天前
TensorFlow2 Python深度学习 - 循环神经网络(RNN)- 简介
python·深度学习·tensorflow·tensorflow2
大飞记Python1 天前
Chromedriver放项目里就行!Selenium 3 和 4 指定路径方法对比 + 兼容写法
开发语言·python
小薛引路1 天前
office便捷办公06:根据相似度去掉excel中的重复行
windows·python·excel
Hs_QY_FX1 天前
Python 分类模型评估:从理论到实战(以信用卡欺诈检测为例)
人工智能·python·机器学习·数据挖掘·多分类评估