昇思25天学习打卡营第23天|基于mindspore bert对话情绪识别

Interesting thing!

About Bert you just need to know that it is like gpt, but focus on pre-training Encoder instead of decoder. It has a mask method which enhances its precision remarkbably. (judge not only the word before the blank but the later one )

model : BertForSequenceClassfication constructs the model and load the config and set the sentiment classification to 3 kinds

python 复制代码
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels = 3)
model = auto_mixed_precision(model, '01')
optimizer = nn.Adam(model.trainable_params(), learning_rate = 2e-5)
metric = Accuracy()
ckpoint_cb =  CheckpointCallback(save_path = 'checkpoint', ckpt_name = 'bert_emotect', epochs = 1, keep_checkpoint_max = 2)
best_model_cb = BestModelCallback(save_path = 'checkpoint', ckpt_name = 'bert_emotect_best', auto_load = True)
trainer = Trainer(network = model, train_dataset = dataset_train,
                    eval_dataset=dataset_val, metrics = metric,
                    epochs = 5, optimizer = optimizer, callback = [ckpoint_cb, best_model_cb])
trainer.run(tgt_columns = 'labels')

the model validation and prediction are the same mostly like Sentiment by any model:

python 复制代码
evaluator = Evaluator(network = model, eval_dataset = dataset_test, metrics= metric)
evaluator.run(tgt_columns='labels')

dataset_infer = SentimentDataset('data/infer.tsv')
def predict(text, label = None):
    label_map = {0:'消极', 1:'中性', 2:'积极'}
    text_tokenized = Tensor([tokenizer(text).input_ids])
    logits = model(text_tokenized)
    predict_label = logits[0].asnumpy().argmax()
    info = f"inputs:'{text}',predict:
'{label_map[predict_label]}'"
    if label is not None:
        info += f", label:'{label_map[label]}'"
    print(info)
相关推荐
星越华夏2 小时前
计算机视觉:YOLOv12安装环境
人工智能·yolo·计算机视觉
Yolanda943 小时前
【人工智能】《从零搭建AI问答助手项目(九):Prompt优化》
人工智能·prompt
wj3055853784 小时前
课程 9:模型测试记录与 Prompt 策略
linux·人工智能·python·comfyui
小和尚同志4 小时前
深入使用 skill-creator:结合真实生产级实践
人工智能·aigc
DevSecOps选型指南4 小时前
安全419专访悬镜安全 | 穿越周期在 AI 浪潮中定义数字供应链安全新范式
人工智能
沪漂阿龙4 小时前
面试题详解:GraphRAG 全面解析——知识图谱增强 RAG、Local Search、Global Search、社区摘要、工程落地与评估指标一次讲透
人工智能·知识图谱
WangN24 小时前
Unitree RL Lab 学习笔记【通识】
人工智能·机器学习
吃好睡好便好4 小时前
在Matlab中绘制横直方图
开发语言·学习·算法·matlab
haina20194 小时前
海纳AI亮相《科创中国》,解码招聘“智”变之路
人工智能·ai面试·ai招聘
阿星AI工作室4 小时前
刘润年中大课笔记:一句话说清AI落地之战的本质
大数据·人工智能·创业创新·商业