昇思25天学习打卡营第17天|基于 MindSpore 实现 BERT 对话情绪识别

基于 MindSpore 实现 BERT 对话情绪识别

BERT介绍

BERT(Bidirectional Encoder Representations from Transformers)是一种基于Transformer架构的预训练语言模型,由谷歌在2018年提出。从以下6个方面来介绍BERT:

  1. 预训练和微调:BERT采用预训练和微调的策略。首先,在大量无标签文本数据上进行预训练,学习通用的语言表示。然后,在特定任务的有标签数据上进行微调,以适应特定任务的需求。

  2. 双向编码器:BERT的核心是双向编码器,它能够同时考虑上下文信息。与传统的单向编码器(如GPT)相比,BERT能够更好地理解语境中的多义词和长距离依赖关系。

  3. Masked Language Model(MLM):BERT在预训练阶段采用了MLM任务。这个任务会随机屏蔽输入句子中的一些单词,然后让模型预测这些被屏蔽的单词。这有助于模型学习到更丰富的上下文表示。

  4. Next Sentence Prediction(NSP):除了MLM任务,BERT还引入了NSP任务。这个任务的目的是让模型学会判断两个句子是否连续。这有助于模型理解句子之间的关系,提高对文本结构的把握。

  5. Transformer架构:BERT基于Transformer架构,这是一种自注意力机制的变体。Transformer具有并行计算能力强、捕捉长距离依赖关系好等优点,使得BERT在处理大规模文本数据时具有很高的效率。

  6. 广泛应用:BERT在自然语言处理领域取得了显著的成果,广泛应用于文本分类、命名实体识别、情感分析、问答系统等任务。此外,BERT还催生了许多改进版本和变体,如RoBERTa、ALBERT等,进一步推动了预训练语言模型的发展。

总之,BERT作为一种强大的预训练语言模型,凭借其双向编码器、MLM和NSP任务以及基于Transformer的架构,在自然语言处理领域取得了突破性的成果,为深度学习开发工程师提供了一种高效且有效的工具。

对话情绪识别

对话情绪识别是一种先进的人工智能技术,它通过自然语言处理和深度学习技术来分析对话中的情绪状态。下面将详细介绍一下对话情绪识别:

  1. 定义与技术
  • 基本概念:对话情绪识别旨在识别对话中的用户情绪状态,如积极、消极、中性等,以提供更精准的用户体验和服务。

  • 技术基础:该技术基于自然语言处理(NLP)和深度学习,通过处理和分析自然语言来提取情感信息。

  1. 情绪分类与识别
  • 情绪分类:对话情绪识别能识别多种情绪,包括正向情绪如喜爱、愉快、感谢,以及负向情绪如抱怨、愤怒、厌恶、恐惧、悲伤等。

  • 识别技术:现代对话情绪识别技术能够自动检测用户日常对话中的情绪特征,并提供有针对性的参考回复,有助于企业或应用方快速应对客户情绪。

  1. 应用场景与作用
  • 智能客服:对话情绪识别可以提升智能客服系统的效率,通过理解用户情感和需求,提供更合适的服务解决方案[。

  • 智能推荐:在智能推荐系统中,该技术可以根据用户的情绪状态推荐相应的产品或服务,提高推荐的准确性和用户满意度。

  • 市场调研:通过对话情绪识别技术,可以更准确地了解用户对产品的情感态度和反馈,为企业制定市场策略提供依据。

  1. 技术优势与挑战
  • 实时性:这一技术处理速度快,能够实时分析对话数据,及时提供情绪识别结果,适用于需要快速响应的场景。

  • 准确性:通过深度学习技术,对话情绪识别可以达到很高的准确率和精度,但在某些复杂情境下仍面临挑战,如讽刺、幽默等难以识别的言语。

  • 易用性:现代对话情绪识别产品通常提供简单易用的API接口,方便企业集成和应用。

  1. 未来发展趋势与改进
  • 模型优化:随着技术进步,对话情绪识别的模型将持续优化,提升对复杂情绪和语境的识别能力。

  • 应用深化:预计对话情绪识别将更广泛地应用于更多领域,如心理健康、在线教育等,为不同行业提供定制化解决方案。

总的来说,对话情绪识别作为一种基于NLP和深度学习的先进技术,能够有效识别和处理对话中的情绪信息,为智能客服、智能推荐等领域提供了强大的技术支持。随着技术的不断发展,其应用范围将进一步扩大,为企业和个人提供更加智能化、个性化的服务。

实践

环境准备

python: 3.9.19

安装环境依赖

bash 复制代码
pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14


pip install mindnlp

完整的Python环境依赖

bash 复制代码
pip list
Package                        Version
------------------------------ --------------
absl-py                        2.1.0
addict                         2.4.0
aiofiles                       22.1.0
aiohttp                        3.9.5
aiosignal                      1.3.1
aiosqlite                      0.20.0
altair                         5.3.0
annotated-types                0.7.0
anyio                          4.4.0
argon2-cffi                    23.1.0
argon2-cffi-bindings           21.2.0
arrow                          1.3.0
astroid                        3.2.2
asttokens                      2.0.5
astunparse                     1.6.3
async-timeout                  4.0.3
attrs                          23.2.0
auto-tune                      0.1.0
autopep8                       1.5.5
Babel                          2.15.0
backcall                       0.2.0
beautifulsoup4                 4.12.3
black                          24.4.2
bleach                         6.1.0
certifi                        2024.6.2
cffi                           1.16.0
charset-normalizer             3.3.2
click                          8.1.7
cloudpickle                    3.0.0
colorama                       0.4.6
comm                           0.2.1
contextlib2                    21.6.0
contourpy                      1.2.1
cycler                         0.12.1
dataflow                       0.0.1
datasets                       2.20.0
debugpy                        1.6.7
decorator                      5.1.1
defusedxml                     0.7.1
dill                           0.3.8
dnspython                      2.6.1
download                       0.3.5
easydict                       1.13
email_validator                2.2.0
entrypoints                    0.4
evaluate                       0.4.2
exceptiongroup                 1.2.0
executing                      0.8.3
fastapi                        0.111.0
fastapi-cli                    0.0.4
fastjsonschema                 2.20.0
ffmpy                          0.3.2
filelock                       3.15.3
flake8                         3.8.4
fonttools                      4.53.0
fqdn                           1.5.1
frozenlist                     1.4.1
fsspec                         2024.5.0
gitdb                          4.0.11
GitPython                      3.1.43
gradio                         4.26.0
gradio_client                  0.15.1
h11                            0.14.0
hccl                           0.1.0
hccl-parser                    0.1
httpcore                       1.0.5
httptools                      0.6.1
httpx                          0.27.0
huggingface-hub                0.23.4
hypothesis                     6.105.1
idna                           3.7
importlib-metadata             7.0.1
importlib_resources            6.4.0
iniconfig                      2.0.0
ipykernel                      6.28.0
ipympl                         0.9.4
ipython                        8.15.0
ipython-genutils               0.2.0
ipywidgets                     8.1.3
isoduration                    20.11.0
isort                          5.13.2
jedi                           0.17.2
jieba                          0.42.1
Jinja2                         3.1.4
joblib                         1.4.2
json5                          0.9.25
jsonpointer                    3.0.0
jsonschema                     4.22.0
jsonschema-specifications      2023.12.1
jupyter_client                 7.4.9
jupyter_core                   5.7.2
jupyter-events                 0.10.0
jupyter-lsp                    2.2.5
jupyter-resource-usage         0.7.2
jupyter_server                 2.14.1
jupyter_server_fileid          0.9.2
jupyter-server-mathjax         0.2.6
jupyter_server_terminals       0.5.3
jupyter_server_ydoc            0.8.0
jupyter-ydoc                   0.2.5
jupyterlab                     3.6.7
jupyterlab_code_formatter      2.2.1
jupyterlab_git                 0.50.1
jupyterlab-language-pack-zh-CN 4.2.post1
jupyterlab-lsp                 4.3.0
jupyterlab_pygments            0.3.0
jupyterlab_server              2.27.2
jupyterlab-system-monitor      0.8.0
jupyterlab-topbar              0.6.1
jupyterlab_widgets             3.0.11
kiwisolver                     1.4.5
markdown-it-py                 3.0.0
MarkupSafe                     2.1.5
matplotlib                     3.9.0
matplotlib-inline              0.1.6
mccabe                         0.6.1
mdurl                          0.1.2
mindnlp                        0.3.1
mindspore                      2.2.14
mindvision                     0.1.0
mistune                        3.0.2
ml_collections                 0.1.1
ml-dtypes                      0.4.0
mpmath                         1.3.0
msadvisor                      1.0.0
multidict                      6.0.5
multiprocess                   0.70.16
mypy-extensions                1.0.0
nbclassic                      1.1.0
nbclient                       0.10.0
nbconvert                      7.16.4
nbdime                         4.0.1
nbformat                       5.10.4
nest-asyncio                   1.6.0
notebook                       6.5.7
notebook_shim                  0.2.4
numpy                          1.26.4
op-compile-tool                0.1.0
op-gen                         0.1
op-test-frame                  0.1
opc-tool                       0.1.0
opencv-contrib-python-headless 4.10.0.84
opencv-python                  4.10.0.84
opencv-python-headless         4.10.0.84
orjson                         3.10.5
overrides                      7.7.0
packaging                      23.2
pandas                         2.2.2
pandocfilters                  1.5.1
parso                          0.7.1
pathlib2                       2.3.7.post1
pathspec                       0.12.1
pexpect                        4.8.0
pickleshare                    0.7.5
pillow                         10.3.0
pip                            24.1
platformdirs                   4.2.2
pluggy                         1.5.0
prometheus_client              0.20.0
prompt-toolkit                 3.0.43
protobuf                       5.27.1
psutil                         5.9.0
ptyprocess                     0.7.0
pure-eval                      0.2.2
pyarrow                        16.1.0
pyarrow-hotfix                 0.6
pycodestyle                    2.6.0
pycparser                      2.22
pyctcdecode                    0.5.0
pydantic                       2.7.4
pydantic_core                  2.18.4
pydocstyle                     6.3.0
pydub                          0.25.1
pyflakes                       2.2.0
Pygments                       2.15.1
pygtrie                        2.5.0
pylint                         3.2.3
pyparsing                      3.1.2
pytest                         7.2.0
python-dateutil                2.9.0.post0
python-dotenv                  1.0.1
python-json-logger             2.0.7
python-jsonrpc-server          0.4.0
python-language-server         0.36.2
python-multipart               0.0.9
pytoolconfig                   1.3.1
pytz                           2024.1
PyYAML                         6.0.1
pyzmq                          25.1.2
referencing                    0.35.1
regex                          2024.5.15
requests                       2.32.3
rfc3339-validator              0.1.4
rfc3986-validator              0.1.1
rich                           13.7.1
rope                           1.13.0
rpds-py                        0.18.1
ruff                           0.4.10
safetensors                    0.4.3
schedule-search                0.0.1
scikit-learn                   1.5.0
scipy                          1.13.1
semantic-version               2.10.0
Send2Trash                     1.8.3
sentencepiece                  0.2.0
setuptools                     69.5.1
shellingham                    1.5.4
six                            1.16.0
smmap                          5.0.1
sniffio                        1.3.1
snowballstemmer                2.2.0
sortedcontainers               2.4.0
soupsieve                      2.5
stack-data                     0.2.0
starlette                      0.37.2
sympy                          1.12.1
synr                           0.5.0
te                             0.4.0
terminado                      0.18.1
threadpoolctl                  3.5.0
tinycss2                       1.3.0
tokenizers                     0.19.1
toml                           0.10.2
tomli                          2.0.1
tomlkit                        0.12.0
toolz                          0.12.1
tornado                        6.4.1
tqdm                           4.66.4
traitlets                      5.14.3
typer                          0.12.3
types-python-dateutil          2.9.0.20240316
typing_extensions              4.11.0
tzdata                         2024.1
ujson                          5.10.0
uri-template                   1.3.0
urllib3                        2.2.2
uvicorn                        0.30.1
uvloop                         0.19.0
watchfiles                     0.22.0
wcwidth                        0.2.5
webcolors                      24.6.0
webencodings                   0.5.1
websocket-client               1.8.0
websockets                     11.0.3
wheel                          0.43.0
widgetsnbextension             4.0.11
xxhash                         3.4.1
y-py                           0.6.2
yapf                           0.40.2
yarl                           1.9.4
ypy-websocket                  0.8.4
zipp                           3.17.0

数据集介绍

数据集是已标注的、经过分词预处理的机器人聊天数据集,来自于百度飞桨团队。数据由两列组成,以制表符('\t')分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。

label--text_a

0--谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?

1--我有事等会儿就回来和你聊

2--我见到你很高兴谢谢你帮我

这部分主要包括数据集读取,数据格式转换,数据 Tokenize 处理和 pad 操作。

python 复制代码
# download dataset
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
!tar xvf emotion_detection.tar.gz

实践代码

python 复制代码
import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, context

from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy

# prepare dataset
class SentimentDataset:
    """Sentiment Dataset"""

    def __init__(self, path):
        self.path = path
        self._labels, self._text_a = [], []
        self._load()

    def _load(self):
        with open(self.path, "r", encoding="utf-8") as f:
            dataset = f.read()
        lines = dataset.split("\n")
        for line in lines[1:-1]:
            label, text_a = line.split("\t")
            self._labels.append(int(label))
            self._text_a.append(text_a)

    def __getitem__(self, index):
        return self._labels[index], self._text_a[index]

    def __len__(self):
        return len(self._labels)


# 数据加载和数据预处理
# 新建 process_dataset 函数用于数据加载和数据预处理,具体内容可见下面代码注释。

import numpy as np

def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'

    column_names = ["label", "text_a"]
    
    dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
    # transforms
    type_cast_op = transforms.TypeCast(mindspore.int32)
    def tokenize_and_pad(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text)
        return tokenized['input_ids'], tokenized['attention_mask']
    # map dataset
    dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                         'attention_mask': (None, 0)})

    return dataset


# 昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理:
from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

tokenizer.pad_token_id
python 复制代码
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)


print(dataset_train.get_col_names())
print(next(dataset_train.create_tuple_iterator()))

模型构建

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。

python 复制代码
from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision

# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
python 复制代码
metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])

%%time
# start training
trainer.run(tgt_columns="labels")

模型验证

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

python 复制代码
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

模型推理

遍历推理数据集,将结果与标签进行统一展示。

python 复制代码
dataset_infer = SentimentDataset("data/infer.tsv")

def predict(text, label=None):
    label_map = {0: "消极", 1: "中性", 2: "积极"}

    text_tokenized = Tensor([tokenizer(text).input_ids])
    logits = model(text_tokenized)
    predict_label = logits[0].asnumpy().argmax()
    info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
    if label is not None:
        info += f" , label: '{label_map[label]}'"
    print(info)


from mindspore import Tensor

for label, text in dataset_infer:
    predict(text, label)

自定义推理数据集

自己输入推理数据,展示模型的泛化能力。

相关推荐
Python极客之家3 分钟前
基于深度学习的眼部疾病检测识别系统
人工智能·python·深度学习·毕业设计·卷积神经网络
繁依Fanyi20 分钟前
828 华为云征文|华为 Flexus 云服务器部署 RustDesk Server,打造自己的远程桌面服务器
运维·服务器·开发语言·人工智能·pytorch·华为·华为云
shuxianshrng22 分钟前
鹰眼降尘系统怎么样
大数据·服务器·人工智能·数码相机·物联网
优思学院26 分钟前
优思学院|如何从零开始自己学习六西格玛?
大数据·运维·服务器·学习·六西格玛黑带·cssbb
说私域26 分钟前
开源 AI 智能名片小程序:开启内容营销新境界
人工智能·小程序
红米煮粥32 分钟前
OpenCV-直方图
人工智能·opencv·计算机视觉
LN花开富贵42 分钟前
stm32g431rbt6芯片中VREF+是什么?在电路中怎么设计?
笔记·stm32·单片机·嵌入式硬件·学习
怀九日42 分钟前
C++(学习)2024.9.18
开发语言·c++·学习·面向对象·引用·
一道秘制的小菜43 分钟前
C++第七节课 运算符重载
服务器·开发语言·c++·学习·算法
DisonTangor1 小时前
上海人工智能实验室开源视频生成模型Vchitect 2.0 可生成20秒高清视频
人工智能·音视频