Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)

Mindspore框架循环神经网络RNN模型实现情感分类

Mindspore框架循环神经网络RNN模型实现情感分类|(一)IMDB影评数据集准备
Mindspore框架循环神经网络RNN模型实现情感分类|(二)预训练词向量
Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建
Mindspore框架循环神经网络RNN模型实现情感分类|(四)损失函数与优化器
Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练
Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)

Mindspore框架循环神经网络RNN模型实现情感分类|(七)模型导出ONNX与应用部署

一、模型资源下载

  1. RNN升级版LSTM模型:本项目训练好的情感分类模型-下载训练好的IMDB分类模型

二、模型加载与推理

python 复制代码
class RNN(nn.Cell):
    def __init__(self, embeddings, hidden_dim, output_dim, n_layers,
                 bidirectional, pad_idx):
        super().__init__()
        vocab_size, embedding_dim = embeddings.shape
        self.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings),
                                      padding_idx=pad_idx)
        self.rnn = nn.LSTM(embedding_dim,
                           hidden_dim,
                           num_layers=n_layers,
                           bidirectional=bidirectional,
                           batch_first=True)
        weight_init = HeUniform(math.sqrt(5))
        bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))
        self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)

    def construct(self, inputs):
        embedded = self.embedding(inputs)
        _, (hidden, _) = self.rnn(embedded)
        hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)
        output = self.fc(hidden)
        return output

编写预测接口:test_interface

python 复制代码
def predict_sentiment(model, vocab, sentence):
    score_map = {
        1: "Positive",
        0: "Negative"
    }
    model.set_train(False)
    tokenized = sentence.lower().split()
    indexed = vocab.tokens_to_ids(tokenized)
    tensor = ms.Tensor(indexed, ms.int32)
    tensor = tensor.expand_dims(0)
    prediction = model(tensor)
    return score_map[int(np.round(ops.sigmoid(prediction).asnumpy()))]

def test_interface():
    # train()
    score_map = {
        1: "Positive",
        0: "Negative"
    }

    ckpt_file_name = './IMDB/IMDB/sentiment-analysis.ckpt'
    # 预训练词向量表
    glove_path = r"./IMDB/IMDB/glove.6B.zip"
    vocab, embeddings = load_glove(glove_path)  # 预定义词向量表
    hidden_size = 256
    output_size = 1
    num_layers = 2
    bidirectional = True
    pad_idx = vocab.tokens_to_ids('<pad>')
    model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
    param_dict = ms.load_checkpoint(ckpt_file_name)
    ms.load_param_into_net(model, param_dict)

    # 预测
    while True:
        try:
            print("go on!")
            sentence = input("请输入:")
            res = predict_sentiment(model, vocab, sentence)
            print("用户输入的内容为:", sentence, "评价结果是:", res)
        except:
            break

def load_glove(glove_path):
    glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')  # 保存数据词典
    if not os.path.exists(glove_100d_path):
        glove_zip = zipfile.ZipFile(glove_path)
        glove_zip.extractall(cache_dir)

    embeddings = []
    tokens = []
    with open(glove_100d_path, encoding='utf-8') as gf:
        for glove in gf:
            word, embedding = glove.split(maxsplit=1)
            tokens.append(word)
            embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))
    # 添加 <unk>, <pad> 两个特殊占位符对应的embedding
    embeddings.append(np.random.rand(100))
    embeddings.append(np.zeros((100,), np.float32))

    vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)
    embeddings = np.array(embeddings).astype(np.float32)
    return vocab, embeddings

预测推理

python 复制代码
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import os
import zipfile
import numpy as np

test_interface()

预测结果。

相关推荐
weixin_468466853 分钟前
图像连通域分析新手实战指南
图像处理·人工智能·深度学习·ai·机器视觉·连通域
硅谷秋水1 小时前
世界动作模型:具身智能的下一前沿
大数据·人工智能·深度学习·计算机视觉·语言模型·机器人
大江东去浪淘尽千古风流人物3 小时前
【HaMeR】全Transformer架构的单目3D手部网格重建:ViT-H骨干+跨注意力MANO解码器源码深度解析
深度学习·3d·transformer·vit·手部重建·mano
钓了猫的鱼儿3 小时前
基于深度学习+AI的红外电力设备故障目标检测与预警系统(Python源码+数据集+UI可视化界面+YOLOv11训练结果)
人工智能·深度学习·目标检测
LaughingZhu3 小时前
Product Hunt 每日热榜 | 2026-05-30
人工智能·经验分享·深度学习·神经网络·产品运营
蒟蒻的贤4 小时前
深度学习底层核心原理:损失函数、梯度与参数更新
人工智能·深度学习
谷哥的小弟4 小时前
大模型核心基础知识(14)—神经网络的结构
人工智能·深度学习·神经网络·大模型·大语言模型
大模型最新论文速读4 小时前
SkillOpt:把 skill 文档当成模型权重来训练
论文阅读·人工智能·深度学习·机器学习·自然语言处理
z小猫不吃鱼5 小时前
15 InstructGPT 论文精读:SFT + RLHF 如何让模型听懂指令?
人工智能·深度学习·算法·机器学习·语言模型·自然语言处理·gpt-3
zcg19425 小时前
如何在CV中使用transformer
人工智能·深度学习·transformer