transformers生成式对话机器人

简介

生成式对话机器人是一种先进的人工智能系统,它能够通过学习大量的自然语言数据来模拟人类进行开放、连贯且创造性的对话。与基于规则或检索式的聊天机器人不同,生成式对话机器人并不局限于预定义的回答集,而是可以根据对话上下文动态地生成新的。这类机器人通常依赖于深度学习框架,特别是Transformer架构(如GPT-3、BERT等)或其他循环神经网络(RNN),例如长短期记忆网络(LSTM)。

核心技术组件

神经网络架构

现代生成式对话机器人大多基于深度学习模型,尤其是Transformer架构。这种架构因其卓越的并行化能力和处理长距离依赖的能力而被广泛采用。Transformers中的多头注意力机制使得模型可以更有效地捕捉输入序列中各个部分之间的关系,从而生成更加相关和连贯的。

自回归模型

在生成回复的过程中,自回归模型按照词或子词单元的顺序预测下一个单元,直到构建出完整的句子。这种方式确保了文本序列的连续性和上下文的一致性。自回归模型的一个显著特点是它们会逐步构建输出,每一次迭代都会根据之前生成的内容调整后续的预测。

训练数据

高质量的训练数据对于生成式对话机器人的性能至关重要。这些数据可以来源于各种渠道,比如电影剧本、社交媒体对话、论坛帖子、客服记录等。丰富的多样化数据有助于训练出一个能够理解和回应多种话题及情境的对话系统。

注意力机制

特别是在Transformer架构中,注意力机制允许模型聚焦于输入序列的关键部分,这对于理解复杂的查询以及产生恰当的回答尤为重要。多头注意力机制进一步增强了这一能力,因为它可以在同一层内同时关注多个不同的信息源。

强化学习

为了优化对话机器人的行为,有时会结合强化学习策略。这种方法可以帮助模型适应不断变化的环境,并依据用户的反馈调整对话策略,以达到更好的交互效果。通过奖励机制,模型可以学习哪些类型的回答更能满足用户需求,进而改进自身的性能。

对话管理

除了基本的回复生成外,一个完整的对话机器人还需要具备对话管理功能,用以跟踪对话状态,确保对话流程的连贯性,以及适时切换话题或结束对话。这涉及到对对话历史的理解和对未来可能发展的预测。

后处理与控制

为了保证生成内容的质量和安全性,生成式对话机器人可能会包含一些后处理步骤,比如过滤不当内容或者调整语气风格,以避免生成不准确、误导性或是不合适的信息。

基于预训练模型训练生成式对话机器人

1, 训练实施方案

这次使用的模型是Langboat/bloom-389m-zh 是澜舟科技开源的。

数据集:nlpcc_2017

将数据集如何处理传给模型,训练出想要的模型实现对话机器人了。

因为模型是自回归的,所以训练任务就是要将完整的序列输入,基于上下文token预测当前token结束位置要有特殊token,eos_token。自回归上部简介中有介绍(自回归模型按照词或子词单元的顺序预测下一个单元),这样就好理解了

数据处理大概方向已经清楚了,那具体怎么处理了。

在对话中都是一问一答方式,nlpcc_2017也是这样。是对话那么就免不了是多轮的,那么我们喂给模型要是一轮还是多轮实现这样的结果了。能一轮肯定是一轮就要搞定了。

那么数据就要处理成这样:

input部分提问和答复两部分,label只有答复部分,因为计算原因input和label长度要相同,label缺少部分就要用-100补齐。

图中的黄色部分是提问,蓝色是答复最后要介绍标记eos。

这样数据集处理格式,模型可以识别出来,能计算loss。

单轮问答讲解(作为参考):

多轮问答讲解(参考):

2,代码实现
python 复制代码
# 生成式对话机器人
## Step1 导入相关包

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

## Step2 加载数据集

ds = Dataset.load_from_disk("./alpaca_data_zh/")
print(ds)

a=ds[:3]
print(a)

## Step3 数据集预处理

tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-389m-zh")
print(tokenizer)

# 数据集处理
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
print(tokenized_ds)


t = tokenizer.decode(tokenized_ds[1]["input_ids"])
print(t)

p = tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"])))
print(p)

## Step4 创建模型

model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-389m-zh")

## Step5 配置训练参数

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=2
)

## Step6 创建训练器

trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
)
## Step7 模型训练
trainer.train()


## Step8 模型推理


from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)

ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
s = pipe(ipt, max_length=256, do_sample=True, )
print(s)
相关推荐
金士镧新材料有限公司36 分钟前
稀土化合物:引领科技创新,推动绿色发展
人工智能·科技·安全·全文检索·生活
feifeikon1 小时前
PyTorch DAY1: 基础语法
人工智能·pytorch·python
yuanbenshidiaos1 小时前
【大数据】机器学习 -----关于data.csv数据集分析案例
大数据·人工智能·机器学习
deephub1 小时前
TorchOptimizer:基于贝叶斯优化的PyTorch Lightning超参数调优框架
人工智能·pytorch·python·机器学习·超参数调优
墨绿色的摆渡人1 小时前
pytorch小记(六):pytorch中的clone和detach操作:克隆/复制数据 vs 共享相同数据但 与计算图断开联系
人工智能·pytorch·python
pzx_0012 小时前
【深度学习】神经网络灾难性遗忘(Catastrophic Forgetting,CF)问题
人工智能·深度学习·神经网络·集成学习
玄明Hanko2 小时前
开源LLM:开启DIY你的专属AI之路
人工智能·llm·开源ai模型
说私域2 小时前
仪式感在会员体系建设中的重要性及AI智能名片2+1链动模式S2B2C商城小程序的应用研究
人工智能·小程序
rommel rain3 小时前
投机解码论文阅读:Falcon
论文阅读·语言模型·自然语言处理
我不是千面3 小时前
视频超分(VSR)论文阅读记录/idea积累(一)
人工智能