Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练

Mindspore框架循环神经网络RNN模型实现情感分类

Mindspore框架循环神经网络RNN模型实现情感分类|(一)IMDB影评数据集准备
Mindspore框架循环神经网络RNN模型实现情感分类|(二)预训练词向量
Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建
Mindspore框架循环神经网络RNN模型实现情感分类|(四)损失函数与优化器
Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练
Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)

Mindspore框架循环神经网络RNN模型实现情感分类|(七)模型导出ONNX与应用部署

模型训练与推理

1. 模型加载

python 复制代码
hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')
# 模型加载
model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)

其中:vocab, embeddings = load_glove(glove_path)
模型构建和实例化参数:

python 复制代码
  embeddings:输入向量,是数据集经过glove模型统一处理的词向量数值特征,
  hidden_dim:隐藏层特征的维度, 
  output_dim:输出维数, 
  n_layers:RNN 层的数量,
  bidirectional:是否为双向 RNN, 
  pad_idx:padding_idx参数用于标记输入中的填充值(padding value)。在自然语言处理任务中,文本序列的长度不一致是非常常见的。为了能够对不同长度的文本序列进行批处理,我们通常会使用填充值对较短的序列进行填补。

2.模型训练

python 复制代码
def train():
    # 音频数据集
    imdb_path = r'./IMDB/aclImdb_v1.tar.gz'

    # 训练集和测试集生成
    imdb_train, imdb_test = load_imdb(imdb_path)  # review评论-标签,数据集

    # 预训练词向量表
    glove_path = r"./IMDB/glove.6B.zip"
    vocab, embeddings = load_glove(glove_path)  # 预定义词向量表

    # 语句标签-数据集。将文本序列统一长度,不足的使用<pad>补齐,超出的进行截断。每条评论500字。
    lookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
    pad_op = ds.transforms.PadEnd([500],
                                  pad_value=vocab.tokens_to_ids('<pad>'))  # 使用PadEnd接口,定义最大长度和补齐值(pad_value),取最大长度为500
    type_cast_op = ds.transforms.TypeCast(ms.float32)  # 将label数据转为float32格式
    # 预处理操作流水线
    imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])
    imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])

    # 由于IMDB数据集本身不包含验证集,我们手动将其分割为训练和验证两部分,比例取0.7, 0.3。
    imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
    # 调用数据集的map、split、batch为数据集处理流水线增加对应操作,返回值为新的Dataset类型。现在仅定义流水线操作,在执行时开始执行数据处理流水线,获取最终处理好的数据并送入模型进行训练。
    imdb_train = imdb_train.batch(64, drop_remainder=True)
    imdb_valid = imdb_valid.batch(64, drop_remainder=True)

    # 定义训练参数
    hidden_size = 256
    output_size = 1
    num_layers = 2
    bidirectional = True
    lr = 0.001
    pad_idx = vocab.tokens_to_ids('<pad>')

    model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)

    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss

    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

    def train_step(data, label):
        loss, grads = grad_fn(data, label)
        optimizer(grads)
        return loss

    def train_one_epoch(model, train_dataset, epoch=0):
        model.set_train()
        total = train_dataset.get_dataset_size()
        loss_total = 0
        step_total = 0
        with tqdm(total=total) as t:
            t.set_description('Epoch %i' % epoch)
            for i in train_dataset.create_tuple_iterator():
                loss = train_step(*i)
                loss_total += loss.asnumpy()
                step_total += 1
                t.set_postfix(loss=loss_total / step_total)
                t.update(1)

    num_epochs = 50
    best_valid_loss = float('inf')
    ckpt_file_name = os.path.join(cache_dir, 'sentiment-analysis.ckpt')

    for epoch in range(num_epochs):
        train_one_epoch(model, imdb_train, epoch)
        valid_loss = evaluate(model, imdb_valid, loss_fn, epoch)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            ms.save_checkpoint(model, ckpt_file_name)


if __name__ == "__main__":
    train()

训练完成:

3.小结

本节实现了情感分类模型的训练,精度达到99.9%,损失值降到0.0027。下一节将进行模型部署应用。

相关推荐
ShuQiHere1 分钟前
【ShuQiHere】 探索数据挖掘的世界:从概念到应用
人工智能·数据挖掘
嵌入式杂谈1 分钟前
OpenCV计算机视觉:探索图片处理的多种操作
人工智能·opencv·计算机视觉
时光追逐者3 分钟前
分享6个.NET开源的AI和LLM相关项目框架
人工智能·microsoft·ai·c#·.net·.netcore
东隆科技3 分钟前
PicoQuant公司:探索铜铟镓硒(CIGS)太阳能电池技术,引领绿色能源革新
人工智能·能源
DisonTangor15 分钟前
上海AI气象大模型提前6天预测“贝碧嘉”台风登陆浦东 今年已多次精准预测
人工智能
人工智能培训咨询叶梓32 分钟前
生成式人工智能在无人机群中的应用、挑战和机遇
人工智能·语言模型·自然语言处理·aigc·无人机·多模态·生成式人工智能
羊小猪~~43 分钟前
深度学习基础案例5--VGG16人脸识别(体验学习的痛苦与乐趣)
人工智能·python·深度学习·学习·算法·机器学习·cnn
Zhangci]43 分钟前
OpenCv(一)
人工智能·opencv·计算机视觉
钡铼技术1 小时前
通过iFIX在ARMxy边缘计算网关上实现维护管理
人工智能·物联网·边缘计算·钡铼技术·armxy边缘计算网关
m0_609000422 小时前
向日葵好用吗?4款稳定的远程控制软件推荐。
运维·服务器·网络·人工智能·远程工作