python-pytorch实现lstm模型预测文本输出0.1.00

python-pytorch实现lstm模型预测文本输出0.1.00

有问题还需要完善

数据

一篇新闻:https://news.sina.com.cn/c/2024-04-12/doc-inarqiev0222543.shtml

参考

https://blog.csdn.net/qq_19530977/article/details/120936391

python 复制代码
# https://blog.csdn.net/qq_19530977/article/details/120936391

效果

python 复制代码
"""
布林肯国务卿
布林肯国务卿同王毅
布林肯国务卿同王毅主任
布林肯国务卿同王毅主任以及
布林肯国务卿同王毅主任以及其他
布林肯国务卿同王毅主任以及其他国家
布林肯国务卿同王毅主任以及其他国家敦促
布林肯国务卿同王毅主任以及其他国家敦促伊朗
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿同王毅
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿同王毅主任
"""

导入包

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
from torch.autograd import Variable
import jieba

分词到数组

复制文章到txt文档

python 复制代码
allarray=[]
with open("./howtousercbow/data/news.txt",encoding="utf-8") as afterjieba:
    lines=afterjieba.readlines()
    print(lines)
    for line in lines:
        result=list(jieba.cut(line,False))
        for r in result:
            allarray.append(r.replace("\n",""))

allarray,len(allarray)
    

准备数数据

python 复制代码
word2index={one:i for i,one in enumerate(allarray)}
index2word={i:one for i,one in enumerate(allarray)}
word2index[" "]=len(allarray)-1
index2word[len(allarray)-1]=" "
word2index[" "]

查看频次

python 复制代码
from collections import Counter
Counter(allarray)

获取vacab

python 复制代码
vocab_size = len(allarray)
vocab_size

生成输入数据

python 复制代码
# 生成输入数据
batch_x = []
batch_y = []
window=1
seq_length=vocab_size
for i in range(seq_length - window + 1):
    x = word2index[allarray[i]]
    if i + window >= seq_length:
        y = word2index[" "]
    else:
        y = word2index[allarray[i + 1]]
    batch_x.append([x])
    batch_y.append(y)

# print(batch_x)
# print("=======")
# print(batch_y)
# print(45/0)


# 训练数据
batch_x, batch_y = Variable(torch.LongTensor(batch_x)), Variable(torch.LongTensor(batch_y))
 
# 参数
# vocab_size = len(letters)
embedding_size = 100
n_hidden = 32
batch_size = 10
num_classes = vocab_size
 
dataset = Data.TensorDataset(batch_x, batch_y)
loader = Data.DataLoader(dataset, batch_size, shuffle=True)
 
# 建立模型
class BiLSTM(nn.Module):
    def __init__(self):
        super(BiLSTM, self).__init__()
        self.word_vec = nn.Embedding(vocab_size, embedding_size)
        # bidirectional双向LSTM
        self.bilstm = nn.LSTM(embedding_size, n_hidden, 1, bidirectional=True)
        self.lstm = nn.LSTM(2 * n_hidden, 2 * n_hidden, 1, bidirectional=False)
        self.fc = nn.Linear(n_hidden * 2, num_classes)
 
    def forward(self, input):
        embedding_input = self.word_vec(input)
#         print("embedding_input",embedding_input,embedding_input.size())
        # 调换第一维和第二维度
        embedding_input = embedding_input.permute(1, 0, 2)
        bilstm_output, (h_n1, c_n1) = self.bilstm(embedding_input)
        lstm_output, (h_n2, c_n2)= self.lstm(bilstm_output)
        fc_out = self.fc(lstm_output[-1])
        return fc_out
 
model = BiLSTM()

训练

python 复制代码
print(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
 
# 训练
for epoch in range(300):
    cost = 0
    for input_batch, target_batch in loader:
        pred = model(input_batch)
#         print("pred",pred)
#         print("target_batch",target_batch)
        loss = criterion(pred, target_batch)
        cost += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Epoch: %d,  loss: %.5f " % (epoch, cost))

测试

python 复制代码
def test(str):
    test_text =str
    test_batch = [word2index[str]]
#     print(test_batch)
    test_batch = torch.LongTensor([test_batch])
#     print("test_batch",test_batch)
#     print(test_batch)
    out = model(test_batch)
    predict = torch.max(out, 1)[1].item()
#     print(test_text,"后一个字母为:", index2word[predict])
    return index2word[predict]

连续预测

python 复制代码
import time
s="布林肯"
while True:
    fenci=jieba.cut(s,False)
    fenciList=list(fenci)
    s=s+test(fenciList[-1:][0])
    
    time.sleep(1)
    print(s)
        
相关推荐
YXWik610 小时前
图片 OCR 文字提取 (Python + AI 模型(ModelScope))
人工智能·python·ocr
烛之武10 小时前
Pytorch学习笔记(1)
pytorch·笔记·学习
Thecozzy10 小时前
写文档教 AI 用代码
开发语言·python
Hanniel11 小时前
装饰器 (中): 进阶篇,解锁框架级玩法
开发语言·python
康哥爱编程11 小时前
鸿蒙应用开发之应用如何实现腾讯云对象存储?
python·云计算·腾讯云
${王小剑}11 小时前
在pycharm中配置pyside6
ide·python·pycharm
烛之武11 小时前
Python速通笔记
windows·python
挨踢诗人11 小时前
旺店通ERP集成金蝶云星空解决方案
python·数据集成
码界索隆11 小时前
Python转Java系列:作者有话说
java·开发语言·python
未来智慧谷11 小时前
【无标题】
人工智能·python·大模型·ai幻觉