实战案例:chatglm3 基础模型多轮对话微调

chatglm3 发布了,这次还发了base版本的模型,意味着我们可以基于这个base模型去自由地做SFT了。

本项目实现了基于base模型的SFT。

base模型

bash 复制代码
https://huggingface.co/THUDM/chatglm3-6b-base

由于模型较大,建议离线下载后放在代码目录,以"./chatglm3-6b-base"的路径进行调用。

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

建了技术答疑、交流群!想要进交流群、需要资料的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、添加微信号:mlc2060,备注:技术交流

方式②、微信搜索公众号:机器学习社区,后台回复:技术交流

环境依赖

bash 复制代码
pip install protobuf transformers==4.30.2 peft cpm_kernels torch>=2.0 gradio mdtex2html sentencepiece accelerate

除了transformers,其他库的版本一般问题不大,遇到缺失的直接pip install即可。

SFT数据格式

使用自己的数据可以参照formatted_samples.json文件,这里没有考虑system,实际使用可以根据自己的情况加上,需要修改chat_data_module.py中对应的数据处理部分。

附上chatglm3的prompt格式

bash 复制代码
<|system|>
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
<|user|>
Hello
<|assistant|>
Hello, I'm ChatGLM3. What can I assist you today?

其实数据处理chat_data_module.py中会拼接一些token就是拼接user、assistant、换行等特殊token

SFT的方式

假设SFT的数据为

bash 复制代码
Q1,A1,Q2,A2,Q3,A3

SFT的过程只会计算

bash 复制代码
A1,A2,A3

的loss,且一次推理会同时计算多轮对话的loss。

如何微调

如果模型路径为"./chatglm3-6b-base",直接

bash 复制代码
python train.py

就可以运行。train.py 当中有需要可调节的参数可以自行调整。

微调效果

作为没有经过人类意图对齐的模型,ChatGLM3-6B-Base 不能用于多轮对话。但是可以进行文本续写。

这里仅通过27条数据进行SFT,发现模型就能够具有一定的对话能力了。

导入模型并合并

python 复制代码
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, PeftModel, get_peft_model

tokenizer = AutoTokenizer.from_pretrained("./chatglm3-6b-base", trust_remote_code=True)
model = AutoModel.from_pretrained("./chatglm3-6b-base", trust_remote_code=True).half().cuda()

peft_model_id = './trained_model/checkpoint-35'
model = PeftModel.from_pretrained(model, peft_model_id)
python 复制代码
Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]
python 复制代码
history = []
query = "你是谁"
role = "user"
inputs = tokenizer.build_chat_input(query, history=history, role=role)
inputs = inputs.to('cuda')
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
gen_kwargs = {"max_length": 500, "num_beams": 1, "do_sample": True, "top_p": 0.8,
                      "temperature": 0.8}
outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
response = tokenizer.decode(outputs)
history = []
history.append({"role": "user", "content": "你是谁"})
response, history = model.process_response(response, history)
print(response)

我叫MONY,是一个AI机器人。

python 复制代码
query = "你能干嘛呀"
role = "user"
inputs = tokenizer.build_chat_input(query, history=history, role=role)
inputs = inputs.to('cuda')
outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
response = tokenizer.decode(outputs)
history.append({"role": role, "content": query})
response, history = model.process_response(response, history)
print(response)

我能够陪你聊天呀。

python 复制代码
query = "你认识乐乐吗"
role = "user"
inputs = tokenizer.build_chat_input(query, history=history, role=role)
inputs = inputs.to('cuda')
outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
response = tokenizer.decode(outputs)
history.append({"role": role, "content": query})
response, history = model.process_response(response, history)
print(response)

我不认识乐乐。

python 复制代码
query = "可以夸一下乐乐长得好看吗"
role = "user"
inputs = tokenizer.build_chat_input(query, history=history, role=role)
inputs = inputs.to('cuda')
outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
response = tokenizer.decode(outputs)
history.append({"role": role, "content": query})
response, history = model.process_response(response, history)
print(response)

乐乐听起来是一个人名,我不认识他。

python 复制代码
query = "你要夸她长得好看"
role = "user"
inputs = tokenizer.build_chat_input(query, history=history, role=role)
inputs = inputs.to('cuda')
outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
response = tokenizer.decode(outputs)
history.append({"role": role, "content": query})
response, history = model.process_response(response, history)
print(response)

好的,我会记住的。

python 复制代码
query = "你倒是夸一下呀"
role = "user"
inputs = tokenizer.build_chat_input(query, history=history, role=role)
inputs = inputs.to('cuda')
outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
response = tokenizer.decode(outputs)
history.append({"role": role, "content": query})
response, history = model.process_response(response, history)
print(response)

乐乐是一个很可爱的人。

项目地址

https://github.com/minghaochen/chatglm3-base-tuning

References

代码参考自llamatune项目

https://github.com/havenhq/haven/tree/dev/llamatune

相关推荐
极智-9966 分钟前
GitHub 热榜项目-日榜精选(2026-02-02)| AI智能体、终端工具、视频生成等 | openclaw、99、Maestro等
人工智能·github·视频生成·终端工具·ai智能体·电子书管理·rust工具
悟纤19 分钟前
AI 音乐创作中的音乐织体(Texture)完整指南 | Suno高级篇 | 第30篇
人工智能·suno·suno ai·suno api·ai music
可触的未来,发芽的智生25 分钟前
狂想:为AGI代称造字ta,《第三类智慧存在,神的赐名》
javascript·人工智能·python·神经网络·程序人生
莱茶荼菜28 分钟前
yolo26 阅读笔记
人工智能·笔记·深度学习·ai·yolo26
Dingdangcat8644 分钟前
【YOLOv8改进实战】使用Ghost模块优化P2结构提升涂胶缺陷检测精度_1
人工智能·yolo·目标跟踪
希艾席帝恩1 小时前
智慧城市建设中,数字孪生的价值在哪里?
人工智能·低代码·私有化部署·数字孪生·数字化转型
我的offer在哪里2 小时前
开源 AI 生成游戏平台:原理、开源项目与落地实战指南
人工智能·游戏·开源
qidun2102 小时前
埃夫特机器人防护服使用范围详解-避免十大应用误区
网络·人工智能
Σίσυφος19002 小时前
PCL Point-to-Point ICP详解
人工智能·算法
PaperRed ai写作降重助手2 小时前
AI 论文写作工具排名(实测不踩坑)
人工智能·aigc·ai写作·论文写作·智能降重·辅助写作·降重复率