transformers生成式对话机器人

简介

生成式对话机器人是一种先进的人工智能系统,它能够通过学习大量的自然语言数据来模拟人类进行开放、连贯且创造性的对话。与基于规则或检索式的聊天机器人不同,生成式对话机器人并不局限于预定义的回答集,而是可以根据对话上下文动态地生成新的。这类机器人通常依赖于深度学习框架,特别是Transformer架构(如GPT-3、BERT等)或其他循环神经网络(RNN),例如长短期记忆网络(LSTM)。

核心技术组件

神经网络架构

现代生成式对话机器人大多基于深度学习模型,尤其是Transformer架构。这种架构因其卓越的并行化能力和处理长距离依赖的能力而被广泛采用。Transformers中的多头注意力机制使得模型可以更有效地捕捉输入序列中各个部分之间的关系,从而生成更加相关和连贯的。

自回归模型

在生成回复的过程中,自回归模型按照词或子词单元的顺序预测下一个单元,直到构建出完整的句子。这种方式确保了文本序列的连续性和上下文的一致性。自回归模型的一个显著特点是它们会逐步构建输出,每一次迭代都会根据之前生成的内容调整后续的预测。

训练数据

高质量的训练数据对于生成式对话机器人的性能至关重要。这些数据可以来源于各种渠道,比如电影剧本、社交媒体对话、论坛帖子、客服记录等。丰富的多样化数据有助于训练出一个能够理解和回应多种话题及情境的对话系统。

注意力机制

特别是在Transformer架构中,注意力机制允许模型聚焦于输入序列的关键部分,这对于理解复杂的查询以及产生恰当的回答尤为重要。多头注意力机制进一步增强了这一能力,因为它可以在同一层内同时关注多个不同的信息源。

强化学习

为了优化对话机器人的行为,有时会结合强化学习策略。这种方法可以帮助模型适应不断变化的环境,并依据用户的反馈调整对话策略,以达到更好的交互效果。通过奖励机制,模型可以学习哪些类型的回答更能满足用户需求,进而改进自身的性能。

对话管理

除了基本的回复生成外,一个完整的对话机器人还需要具备对话管理功能,用以跟踪对话状态,确保对话流程的连贯性,以及适时切换话题或结束对话。这涉及到对对话历史的理解和对未来可能发展的预测。

后处理与控制

为了保证生成内容的质量和安全性,生成式对话机器人可能会包含一些后处理步骤,比如过滤不当内容或者调整语气风格,以避免生成不准确、误导性或是不合适的信息。

基于预训练模型训练生成式对话机器人

1, 训练实施方案

这次使用的模型是Langboat/bloom-389m-zh 是澜舟科技开源的。

数据集:nlpcc_2017

将数据集如何处理传给模型,训练出想要的模型实现对话机器人了。

因为模型是自回归的,所以训练任务就是要将完整的序列输入,基于上下文token预测当前token结束位置要有特殊token,eos_token。自回归上部简介中有介绍(自回归模型按照词或子词单元的顺序预测下一个单元),这样就好理解了

数据处理大概方向已经清楚了,那具体怎么处理了。

在对话中都是一问一答方式,nlpcc_2017也是这样。是对话那么就免不了是多轮的,那么我们喂给模型要是一轮还是多轮实现这样的结果了。能一轮肯定是一轮就要搞定了。

那么数据就要处理成这样:

input部分提问和答复两部分,label只有答复部分,因为计算原因input和label长度要相同,label缺少部分就要用-100补齐。

图中的黄色部分是提问,蓝色是答复最后要介绍标记eos。

这样数据集处理格式,模型可以识别出来,能计算loss。

单轮问答讲解(作为参考):

多轮问答讲解(参考):

2,代码实现
python 复制代码
# 生成式对话机器人
## Step1 导入相关包

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

## Step2 加载数据集

ds = Dataset.load_from_disk("./alpaca_data_zh/")
print(ds)

a=ds[:3]
print(a)

## Step3 数据集预处理

tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-389m-zh")
print(tokenizer)

# 数据集处理
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
print(tokenized_ds)


t = tokenizer.decode(tokenized_ds[1]["input_ids"])
print(t)

p = tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"])))
print(p)

## Step4 创建模型

model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-389m-zh")

## Step5 配置训练参数

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=2
)

## Step6 创建训练器

trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
)
## Step7 模型训练
trainer.train()


## Step8 模型推理


from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)

ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
s = pipe(ipt, max_length=256, do_sample=True, )
print(s)
相关推荐
学历真的很重要3 小时前
VsCode+Roo Code+Gemini 2.5 Pro+Gemini Balance AI辅助编程环境搭建(理论上通过多个Api Key负载均衡达到无限免费Gemini 2.5 Pro)
前端·人工智能·vscode·后端·语言模型·负载均衡·ai编程
普通网友3 小时前
微服务注册中心与负载均衡实战精要,微软 2025 年 8 月更新:对固态硬盘与电脑功能有哪些潜在的影响。
人工智能·ai智能体·技术问答
苍何3 小时前
一人手搓!AI 漫剧从0到1详细教程
人工智能
苍何3 小时前
Gemini 3 刚刷屏,蚂蚁灵光又整活:一句话生成「闪游戏」
人工智能
苍何3 小时前
越来越对 AI 做的 PPT 敬佩了!(附7大用法)
人工智能
苍何3 小时前
超全Nano Banana Pro 提示词案例库来啦,小白也能轻松上手
人工智能
阿杰学AI4 小时前
AI核心知识39——大语言模型之World Model(简洁且通俗易懂版)
人工智能·ai·语言模型·aigc·世界模型·world model·sara
智慧地球(AI·Earth)4 小时前
Vibe Coding:你被取代了吗?
人工智能
大、男人5 小时前
DeepAgent学习
人工智能·学习
测试人社区—66795 小时前
提升测试覆盖率的有效手段剖析
人工智能·学习·flutter·ui·自动化·测试覆盖率