python-pytorch实现lstm模型预测文本输出0.1.00

python-pytorch实现lstm模型预测文本输出0.1.00

有问题还需要完善

数据

一篇新闻:https://news.sina.com.cn/c/2024-04-12/doc-inarqiev0222543.shtml

参考

https://blog.csdn.net/qq_19530977/article/details/120936391

python 复制代码
# https://blog.csdn.net/qq_19530977/article/details/120936391

效果

python 复制代码
"""
布林肯国务卿
布林肯国务卿同王毅
布林肯国务卿同王毅主任
布林肯国务卿同王毅主任以及
布林肯国务卿同王毅主任以及其他
布林肯国务卿同王毅主任以及其他国家
布林肯国务卿同王毅主任以及其他国家敦促
布林肯国务卿同王毅主任以及其他国家敦促伊朗
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿同王毅
布林肯国务卿同王毅主任以及其他国家敦促伊朗驻叙利亚使馆的安全不容侵犯,布林肯国务卿同王毅主任
"""

导入包

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
from torch.autograd import Variable
import jieba

分词到数组

复制文章到txt文档

python 复制代码
allarray=[]
with open("./howtousercbow/data/news.txt",encoding="utf-8") as afterjieba:
    lines=afterjieba.readlines()
    print(lines)
    for line in lines:
        result=list(jieba.cut(line,False))
        for r in result:
            allarray.append(r.replace("\n",""))

allarray,len(allarray)
    

准备数数据

python 复制代码
word2index={one:i for i,one in enumerate(allarray)}
index2word={i:one for i,one in enumerate(allarray)}
word2index[" "]=len(allarray)-1
index2word[len(allarray)-1]=" "
word2index[" "]

查看频次

python 复制代码
from collections import Counter
Counter(allarray)

获取vacab

python 复制代码
vocab_size = len(allarray)
vocab_size

生成输入数据

python 复制代码
# 生成输入数据
batch_x = []
batch_y = []
window=1
seq_length=vocab_size
for i in range(seq_length - window + 1):
    x = word2index[allarray[i]]
    if i + window >= seq_length:
        y = word2index[" "]
    else:
        y = word2index[allarray[i + 1]]
    batch_x.append([x])
    batch_y.append(y)

# print(batch_x)
# print("=======")
# print(batch_y)
# print(45/0)


# 训练数据
batch_x, batch_y = Variable(torch.LongTensor(batch_x)), Variable(torch.LongTensor(batch_y))
 
# 参数
# vocab_size = len(letters)
embedding_size = 100
n_hidden = 32
batch_size = 10
num_classes = vocab_size
 
dataset = Data.TensorDataset(batch_x, batch_y)
loader = Data.DataLoader(dataset, batch_size, shuffle=True)
 
# 建立模型
class BiLSTM(nn.Module):
    def __init__(self):
        super(BiLSTM, self).__init__()
        self.word_vec = nn.Embedding(vocab_size, embedding_size)
        # bidirectional双向LSTM
        self.bilstm = nn.LSTM(embedding_size, n_hidden, 1, bidirectional=True)
        self.lstm = nn.LSTM(2 * n_hidden, 2 * n_hidden, 1, bidirectional=False)
        self.fc = nn.Linear(n_hidden * 2, num_classes)
 
    def forward(self, input):
        embedding_input = self.word_vec(input)
#         print("embedding_input",embedding_input,embedding_input.size())
        # 调换第一维和第二维度
        embedding_input = embedding_input.permute(1, 0, 2)
        bilstm_output, (h_n1, c_n1) = self.bilstm(embedding_input)
        lstm_output, (h_n2, c_n2)= self.lstm(bilstm_output)
        fc_out = self.fc(lstm_output[-1])
        return fc_out
 
model = BiLSTM()

训练

python 复制代码
print(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
 
# 训练
for epoch in range(300):
    cost = 0
    for input_batch, target_batch in loader:
        pred = model(input_batch)
#         print("pred",pred)
#         print("target_batch",target_batch)
        loss = criterion(pred, target_batch)
        cost += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Epoch: %d,  loss: %.5f " % (epoch, cost))

测试

python 复制代码
def test(str):
    test_text =str
    test_batch = [word2index[str]]
#     print(test_batch)
    test_batch = torch.LongTensor([test_batch])
#     print("test_batch",test_batch)
#     print(test_batch)
    out = model(test_batch)
    predict = torch.max(out, 1)[1].item()
#     print(test_text,"后一个字母为:", index2word[predict])
    return index2word[predict]

连续预测

python 复制代码
import time
s="布林肯"
while True:
    fenci=jieba.cut(s,False)
    fenciList=list(fenci)
    s=s+test(fenciList[-1:][0])
    
    time.sleep(1)
    print(s)
        
相关推荐
aqi008 小时前
15天学会AI应用开发(十)把文本嵌入模型换成国产模型
人工智能·python·ai编程
金銀銅鐵1 天前
[Python] 扩展欧几里得算法
python·数学·算法
Duckdblab1 天前
DuckDB 性能调优终极指南:打造闪电般的分析体验
python
带派擂总1 天前
Python全栈开发精华版最全合集(包含各种面试题) Day24_异常和错误
python
金銀銅鐵1 天前
n^5 和 n 的个位数是否总相等?
python·数学
aqi001 天前
15天学会AI应用开发(九)利用Chroma持久化向量数据
人工智能·python·大模型·ai编程·ai应用
金銀銅鐵1 天前
借助 Pygame 探索最大公约数的规律
python·数学·游戏
weiwei228442 天前
神经网络模型导出及开放标准格式ONNX
pytorch·onnx