昇思25天学习打卡营第17天|基于 MindSpore 实现 BERT 对话情绪识别

基于 MindSpore 实现 BERT 对话情绪识别

BERT介绍

BERT(Bidirectional Encoder Representations from Transformers)是一种基于Transformer架构的预训练语言模型,由谷歌在2018年提出。从以下6个方面来介绍BERT:

  1. 预训练和微调:BERT采用预训练和微调的策略。首先,在大量无标签文本数据上进行预训练,学习通用的语言表示。然后,在特定任务的有标签数据上进行微调,以适应特定任务的需求。

  2. 双向编码器:BERT的核心是双向编码器,它能够同时考虑上下文信息。与传统的单向编码器(如GPT)相比,BERT能够更好地理解语境中的多义词和长距离依赖关系。

  3. Masked Language Model(MLM):BERT在预训练阶段采用了MLM任务。这个任务会随机屏蔽输入句子中的一些单词,然后让模型预测这些被屏蔽的单词。这有助于模型学习到更丰富的上下文表示。

  4. Next Sentence Prediction(NSP):除了MLM任务,BERT还引入了NSP任务。这个任务的目的是让模型学会判断两个句子是否连续。这有助于模型理解句子之间的关系,提高对文本结构的把握。

  5. Transformer架构:BERT基于Transformer架构,这是一种自注意力机制的变体。Transformer具有并行计算能力强、捕捉长距离依赖关系好等优点,使得BERT在处理大规模文本数据时具有很高的效率。

  6. 广泛应用:BERT在自然语言处理领域取得了显著的成果,广泛应用于文本分类、命名实体识别、情感分析、问答系统等任务。此外,BERT还催生了许多改进版本和变体,如RoBERTa、ALBERT等,进一步推动了预训练语言模型的发展。

总之,BERT作为一种强大的预训练语言模型,凭借其双向编码器、MLM和NSP任务以及基于Transformer的架构,在自然语言处理领域取得了突破性的成果,为深度学习开发工程师提供了一种高效且有效的工具。

对话情绪识别

对话情绪识别是一种先进的人工智能技术,它通过自然语言处理和深度学习技术来分析对话中的情绪状态。下面将详细介绍一下对话情绪识别:

  1. 定义与技术
  • 基本概念:对话情绪识别旨在识别对话中的用户情绪状态,如积极、消极、中性等,以提供更精准的用户体验和服务。

  • 技术基础:该技术基于自然语言处理(NLP)和深度学习,通过处理和分析自然语言来提取情感信息。

  1. 情绪分类与识别
  • 情绪分类:对话情绪识别能识别多种情绪,包括正向情绪如喜爱、愉快、感谢,以及负向情绪如抱怨、愤怒、厌恶、恐惧、悲伤等。

  • 识别技术:现代对话情绪识别技术能够自动检测用户日常对话中的情绪特征,并提供有针对性的参考回复,有助于企业或应用方快速应对客户情绪。

  1. 应用场景与作用
  • 智能客服:对话情绪识别可以提升智能客服系统的效率,通过理解用户情感和需求,提供更合适的服务解决方案[。

  • 智能推荐:在智能推荐系统中,该技术可以根据用户的情绪状态推荐相应的产品或服务,提高推荐的准确性和用户满意度。

  • 市场调研:通过对话情绪识别技术,可以更准确地了解用户对产品的情感态度和反馈,为企业制定市场策略提供依据。

  1. 技术优势与挑战
  • 实时性:这一技术处理速度快,能够实时分析对话数据,及时提供情绪识别结果,适用于需要快速响应的场景。

  • 准确性:通过深度学习技术,对话情绪识别可以达到很高的准确率和精度,但在某些复杂情境下仍面临挑战,如讽刺、幽默等难以识别的言语。

  • 易用性:现代对话情绪识别产品通常提供简单易用的API接口,方便企业集成和应用。

  1. 未来发展趋势与改进
  • 模型优化:随着技术进步,对话情绪识别的模型将持续优化,提升对复杂情绪和语境的识别能力。

  • 应用深化:预计对话情绪识别将更广泛地应用于更多领域,如心理健康、在线教育等,为不同行业提供定制化解决方案。

总的来说,对话情绪识别作为一种基于NLP和深度学习的先进技术,能够有效识别和处理对话中的情绪信息,为智能客服、智能推荐等领域提供了强大的技术支持。随着技术的不断发展,其应用范围将进一步扩大,为企业和个人提供更加智能化、个性化的服务。

实践

环境准备

python: 3.9.19

安装环境依赖

bash 复制代码
pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14


pip install mindnlp

完整的Python环境依赖

bash 复制代码
pip list
Package                        Version
------------------------------ --------------
absl-py                        2.1.0
addict                         2.4.0
aiofiles                       22.1.0
aiohttp                        3.9.5
aiosignal                      1.3.1
aiosqlite                      0.20.0
altair                         5.3.0
annotated-types                0.7.0
anyio                          4.4.0
argon2-cffi                    23.1.0
argon2-cffi-bindings           21.2.0
arrow                          1.3.0
astroid                        3.2.2
asttokens                      2.0.5
astunparse                     1.6.3
async-timeout                  4.0.3
attrs                          23.2.0
auto-tune                      0.1.0
autopep8                       1.5.5
Babel                          2.15.0
backcall                       0.2.0
beautifulsoup4                 4.12.3
black                          24.4.2
bleach                         6.1.0
certifi                        2024.6.2
cffi                           1.16.0
charset-normalizer             3.3.2
click                          8.1.7
cloudpickle                    3.0.0
colorama                       0.4.6
comm                           0.2.1
contextlib2                    21.6.0
contourpy                      1.2.1
cycler                         0.12.1
dataflow                       0.0.1
datasets                       2.20.0
debugpy                        1.6.7
decorator                      5.1.1
defusedxml                     0.7.1
dill                           0.3.8
dnspython                      2.6.1
download                       0.3.5
easydict                       1.13
email_validator                2.2.0
entrypoints                    0.4
evaluate                       0.4.2
exceptiongroup                 1.2.0
executing                      0.8.3
fastapi                        0.111.0
fastapi-cli                    0.0.4
fastjsonschema                 2.20.0
ffmpy                          0.3.2
filelock                       3.15.3
flake8                         3.8.4
fonttools                      4.53.0
fqdn                           1.5.1
frozenlist                     1.4.1
fsspec                         2024.5.0
gitdb                          4.0.11
GitPython                      3.1.43
gradio                         4.26.0
gradio_client                  0.15.1
h11                            0.14.0
hccl                           0.1.0
hccl-parser                    0.1
httpcore                       1.0.5
httptools                      0.6.1
httpx                          0.27.0
huggingface-hub                0.23.4
hypothesis                     6.105.1
idna                           3.7
importlib-metadata             7.0.1
importlib_resources            6.4.0
iniconfig                      2.0.0
ipykernel                      6.28.0
ipympl                         0.9.4
ipython                        8.15.0
ipython-genutils               0.2.0
ipywidgets                     8.1.3
isoduration                    20.11.0
isort                          5.13.2
jedi                           0.17.2
jieba                          0.42.1
Jinja2                         3.1.4
joblib                         1.4.2
json5                          0.9.25
jsonpointer                    3.0.0
jsonschema                     4.22.0
jsonschema-specifications      2023.12.1
jupyter_client                 7.4.9
jupyter_core                   5.7.2
jupyter-events                 0.10.0
jupyter-lsp                    2.2.5
jupyter-resource-usage         0.7.2
jupyter_server                 2.14.1
jupyter_server_fileid          0.9.2
jupyter-server-mathjax         0.2.6
jupyter_server_terminals       0.5.3
jupyter_server_ydoc            0.8.0
jupyter-ydoc                   0.2.5
jupyterlab                     3.6.7
jupyterlab_code_formatter      2.2.1
jupyterlab_git                 0.50.1
jupyterlab-language-pack-zh-CN 4.2.post1
jupyterlab-lsp                 4.3.0
jupyterlab_pygments            0.3.0
jupyterlab_server              2.27.2
jupyterlab-system-monitor      0.8.0
jupyterlab-topbar              0.6.1
jupyterlab_widgets             3.0.11
kiwisolver                     1.4.5
markdown-it-py                 3.0.0
MarkupSafe                     2.1.5
matplotlib                     3.9.0
matplotlib-inline              0.1.6
mccabe                         0.6.1
mdurl                          0.1.2
mindnlp                        0.3.1
mindspore                      2.2.14
mindvision                     0.1.0
mistune                        3.0.2
ml_collections                 0.1.1
ml-dtypes                      0.4.0
mpmath                         1.3.0
msadvisor                      1.0.0
multidict                      6.0.5
multiprocess                   0.70.16
mypy-extensions                1.0.0
nbclassic                      1.1.0
nbclient                       0.10.0
nbconvert                      7.16.4
nbdime                         4.0.1
nbformat                       5.10.4
nest-asyncio                   1.6.0
notebook                       6.5.7
notebook_shim                  0.2.4
numpy                          1.26.4
op-compile-tool                0.1.0
op-gen                         0.1
op-test-frame                  0.1
opc-tool                       0.1.0
opencv-contrib-python-headless 4.10.0.84
opencv-python                  4.10.0.84
opencv-python-headless         4.10.0.84
orjson                         3.10.5
overrides                      7.7.0
packaging                      23.2
pandas                         2.2.2
pandocfilters                  1.5.1
parso                          0.7.1
pathlib2                       2.3.7.post1
pathspec                       0.12.1
pexpect                        4.8.0
pickleshare                    0.7.5
pillow                         10.3.0
pip                            24.1
platformdirs                   4.2.2
pluggy                         1.5.0
prometheus_client              0.20.0
prompt-toolkit                 3.0.43
protobuf                       5.27.1
psutil                         5.9.0
ptyprocess                     0.7.0
pure-eval                      0.2.2
pyarrow                        16.1.0
pyarrow-hotfix                 0.6
pycodestyle                    2.6.0
pycparser                      2.22
pyctcdecode                    0.5.0
pydantic                       2.7.4
pydantic_core                  2.18.4
pydocstyle                     6.3.0
pydub                          0.25.1
pyflakes                       2.2.0
Pygments                       2.15.1
pygtrie                        2.5.0
pylint                         3.2.3
pyparsing                      3.1.2
pytest                         7.2.0
python-dateutil                2.9.0.post0
python-dotenv                  1.0.1
python-json-logger             2.0.7
python-jsonrpc-server          0.4.0
python-language-server         0.36.2
python-multipart               0.0.9
pytoolconfig                   1.3.1
pytz                           2024.1
PyYAML                         6.0.1
pyzmq                          25.1.2
referencing                    0.35.1
regex                          2024.5.15
requests                       2.32.3
rfc3339-validator              0.1.4
rfc3986-validator              0.1.1
rich                           13.7.1
rope                           1.13.0
rpds-py                        0.18.1
ruff                           0.4.10
safetensors                    0.4.3
schedule-search                0.0.1
scikit-learn                   1.5.0
scipy                          1.13.1
semantic-version               2.10.0
Send2Trash                     1.8.3
sentencepiece                  0.2.0
setuptools                     69.5.1
shellingham                    1.5.4
six                            1.16.0
smmap                          5.0.1
sniffio                        1.3.1
snowballstemmer                2.2.0
sortedcontainers               2.4.0
soupsieve                      2.5
stack-data                     0.2.0
starlette                      0.37.2
sympy                          1.12.1
synr                           0.5.0
te                             0.4.0
terminado                      0.18.1
threadpoolctl                  3.5.0
tinycss2                       1.3.0
tokenizers                     0.19.1
toml                           0.10.2
tomli                          2.0.1
tomlkit                        0.12.0
toolz                          0.12.1
tornado                        6.4.1
tqdm                           4.66.4
traitlets                      5.14.3
typer                          0.12.3
types-python-dateutil          2.9.0.20240316
typing_extensions              4.11.0
tzdata                         2024.1
ujson                          5.10.0
uri-template                   1.3.0
urllib3                        2.2.2
uvicorn                        0.30.1
uvloop                         0.19.0
watchfiles                     0.22.0
wcwidth                        0.2.5
webcolors                      24.6.0
webencodings                   0.5.1
websocket-client               1.8.0
websockets                     11.0.3
wheel                          0.43.0
widgetsnbextension             4.0.11
xxhash                         3.4.1
y-py                           0.6.2
yapf                           0.40.2
yarl                           1.9.4
ypy-websocket                  0.8.4
zipp                           3.17.0

数据集介绍

数据集是已标注的、经过分词预处理的机器人聊天数据集,来自于百度飞桨团队。数据由两列组成,以制表符('\t')分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。

label--text_a

0--谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?

1--我有事等会儿就回来和你聊

2--我见到你很高兴谢谢你帮我

这部分主要包括数据集读取,数据格式转换,数据 Tokenize 处理和 pad 操作。

python 复制代码
# download dataset
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
!tar xvf emotion_detection.tar.gz

实践代码

python 复制代码
import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, context

from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy

# prepare dataset
class SentimentDataset:
    """Sentiment Dataset"""

    def __init__(self, path):
        self.path = path
        self._labels, self._text_a = [], []
        self._load()

    def _load(self):
        with open(self.path, "r", encoding="utf-8") as f:
            dataset = f.read()
        lines = dataset.split("\n")
        for line in lines[1:-1]:
            label, text_a = line.split("\t")
            self._labels.append(int(label))
            self._text_a.append(text_a)

    def __getitem__(self, index):
        return self._labels[index], self._text_a[index]

    def __len__(self):
        return len(self._labels)


# 数据加载和数据预处理
# 新建 process_dataset 函数用于数据加载和数据预处理,具体内容可见下面代码注释。

import numpy as np

def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'

    column_names = ["label", "text_a"]
    
    dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
    # transforms
    type_cast_op = transforms.TypeCast(mindspore.int32)
    def tokenize_and_pad(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text)
        return tokenized['input_ids'], tokenized['attention_mask']
    # map dataset
    dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                         'attention_mask': (None, 0)})

    return dataset


# 昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理:
from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

tokenizer.pad_token_id
python 复制代码
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)


print(dataset_train.get_col_names())
print(next(dataset_train.create_tuple_iterator()))

模型构建

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。

python 复制代码
from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision

# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
python 复制代码
metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])

%%time
# start training
trainer.run(tgt_columns="labels")

模型验证

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

python 复制代码
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

模型推理

遍历推理数据集,将结果与标签进行统一展示。

python 复制代码
dataset_infer = SentimentDataset("data/infer.tsv")

def predict(text, label=None):
    label_map = {0: "消极", 1: "中性", 2: "积极"}

    text_tokenized = Tensor([tokenizer(text).input_ids])
    logits = model(text_tokenized)
    predict_label = logits[0].asnumpy().argmax()
    info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
    if label is not None:
        info += f" , label: '{label_map[label]}'"
    print(info)


from mindspore import Tensor

for label, text in dataset_infer:
    predict(text, label)

自定义推理数据集

自己输入推理数据,展示模型的泛化能力。

相关推荐
一只特立独行的程序猿9 分钟前
关于GCC内联汇编(也可以叫内嵌汇编)的简单学习
汇编·学习·gcc
虾球xz15 分钟前
游戏引擎学习第10天
学习·游戏引擎
Chef_Chen18 分钟前
从0开始学习机器学习--Day25--SVM作业
学习·机器学习·支持向量机
矢量赛奇22 分钟前
比ChatGPT更酷的AI工具
人工智能·ai·ai写作·视频
L_cl26 分钟前
Python学习从0到1 day28 Python 高阶技巧 ⑧ 递归
学习
KuaFuAI30 分钟前
微软推出的AI无代码编程微应用平台GitHub Spark和国产AI原生无代码工具CodeFlying比到底咋样?
人工智能·github·aigc·ai编程·codeflying·github spark·自然语言开发软件
Make_magic39 分钟前
Git学习教程(更新中)
大数据·人工智能·git·elasticsearch·计算机视觉
shelly聊AI43 分钟前
语音识别原理:AI 是如何听懂人类声音的
人工智能·语音识别
vortex544 分钟前
Vim 编辑器学习笔记
学习·编辑器·vim
源于花海1 小时前
论文学习(四) | 基于数据驱动的锂离子电池健康状态估计和剩余使用寿命预测
论文阅读·人工智能·学习·论文笔记