chatglm3-6b尝试

十月底智谱开源了chatglm3,果断来尝试一下。

1.ChatGLM3 亮点:

ChatGLM3 是智谱AI和清华大学 KEG 实验室联合发布的新一代对话预训练模型。ChatGLM3-6B 是 ChatGLM3 系列中的开源模型,在保留了前两代模型对话流畅、部署门槛低等众多优秀特性的基础上,ChatGLM3-6B 引入了 代码执行(Code Interpreter)和 Agent 任务等新特性。

2.前期准备:

python:3.10+,transformers 库版本推荐为 4.30.2torch 推荐使用 2.0 及以上的版本,以获得最佳的推理性能。

chatglm3-6b支持cpu推理,但是在本人32g内存上的电脑实测,emmm,还是准备一台gpu服务器吧。

3.环境安装:

首先需要下载本仓库:

bash 复制代码
git clone https://github.com/THUDM/ChatGLM3

然后使用 pip 安装依赖:

复制代码
pip install -r requirements.txt

模型下载:

Hugging Face 速度慢的建议从modelscope.cn/models/Zhip...

bash 复制代码
git clone https://www.modelscope.cn/ZhipuAI/chatglm3-6b.git

4.本地加载模型:

ini 复制代码
import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer


# 载入Tokenizer
model_dir='/root/autodl-tmp/chatglm3-6b'
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_dir, config=config, trust_remote_code=True).cuda()

简单的对话:

bash 复制代码
history=[]
query = "你好呀"
response, history = model.chat(tokenizer, query, history=history)
print(query)
print(response)

运行包里的web_demo.py和web_demo2.py可以生成一个可以在线对话的网页

5.模型微调:

模型微调大概分为多轮对话和单轮对话两种方式,先准备训练数据。

多轮对话微调

先准备一个如下数据格式的json文件

less 复制代码
[  {    "tools": [],
    "conversations": [      {        "role": "system",        "content": "我是搬砖小助手"      },      {        "role": "user",        "content": "你好呀"      },      {        "role": "assistant",        "content": "你好呀,今天又是搬砖的一天"      }    ]
  }
 // ...
]

创建sh文件并运行

ini 复制代码
set -ex

PRE_SEQ_LEN=128
LR=2e-2
NUM_GPUS=1
MAX_SEQ_LEN=2048
DEV_BATCH_SIZE=4
GRAD_ACCUMULARION_STEPS=4
MAX_STEP=1000
SAVE_INTERVAL=200

DATASTR=`date +%Y%m%d-%H%M%S`
RUN_NAME=tool_alpaca_pt


BASE_MODEL_PATH='/root/autodl-tmp/chatglm3-6b'
DATASET_PATH='/root/chatglm3-6b/finetune_demo/formatted_data/train.json'
OUTPUT_DIR='/root/autodl-tmp/check-point'
mkdir -p $OUTPUT_DIR

torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
    --train_format multi-turn \
    --train_file $DATASET_PATH \
    --max_seq_length $MAX_SEQ_LEN \
    --preprocessing_num_workers 1 \
    --model_name_or_path $BASE_MODEL_PATH \
    --output_dir $OUTPUT_DIR \
    --per_device_train_batch_size $DEV_BATCH_SIZE \
    --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \
    --max_steps $MAX_STEP \
    --logging_steps 1 \
    --save_steps $SAVE_INTERVAL \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN 2>&1 | tee ${OUTPUT_DIR}/train.log

试试微调后的结果吧

ini 复制代码
import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

model_dir='/root/autodl-tmp/chatglm3-6b'
CHECKPOINT_PATH='/root/autodl-tmp/check-point/checkpoint-1000'
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_dir, config=config, trust_remote_code=True).cuda()
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
model = model.eval()

system_info = {"role": "system", "content": "我是搬砖小助手"}
history = [system_info]
query = "你好呀"
response, history = model.chat(tokenizer, query, history=history)
print(query)
print(response)

单轮对话微调:

先准备一个如下数据格式的json文件:

json 复制代码
[
  {
     "prompt": "你好呀",
      "response": "你好呀,今天又是搬砖的一天"
    }
     // ...

]

创建sh文件并运行:

ini 复制代码
set -ex
PRE_SEQ_LEN=128
LR=2e-2
NUM_GPUS=1
MAX_SOURCE_LEN=1024
MAX_TARGET_LEN=128
DEV_BATCH_SIZE=4
GRAD_ACCUMULARION_STEPS=4
MAX_STEP=1000
SAVE_INTERVAL=100

DATESTR=`date +%Y%m%d-%H%M%S`
RUN_NAME=advertise_gen_pt

BASE_MODEL_PATH='/root/autodl-tmp/chatglm3-6b'
DATASET_PATH='/root/chatglm3-6b/finetune_demo/formatted_data/train.json'
OUTPUT_DIR='/root/autodl-tmp/check-point'
mkdir -p $OUTPUT_DIR

torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
    --train_format input-output \
    --train_file $DATASET_PATH \
    --preprocessing_num_workers 1 \
    --model_name_or_path $BASE_MODEL_PATH \
    --output_dir $OUTPUT_DIR \
    --max_source_length $MAX_SOURCE_LEN \
    --max_target_length $MAX_TARGET_LEN \
    --per_device_train_batch_size $DEV_BATCH_SIZE \
    --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \
    --max_steps $MAX_STEP \
    --logging_steps 1 \
    --save_steps $SAVE_INTERVAL \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN 2>&1 | tee ${OUTPUT_DIR}/train.log

6.工具调用:

ini 复制代码
tools = [
    {
        "name": "track",
        "description": "追踪指定股票的实时价格",
        "parameters": {
            "type": "object",
            "properties": {
                "symbol": {
                    "description": "需要追踪的股票代码"
                }
            },
            "required": ['symbol']
        }
    }
]
system_info = {"role": "system", "content": "Answer the following questions as best as you can. You have access to the following tools:", "tools": tools}

注意:目前 ChatGLM3-6B 的工具调用只支持通过 chat 方法,不支持 stream_chat 方法。根本原因是stream_chat 是一个个吐字的,没法中间做手脚将工具调用结果进行处理。具体可以看这位大佬的文章:zhuanlan.zhihu.com/p/664233831

ini 复制代码
history = [system_info]
query = "帮我查询股票10111的价格"
response, history = model.chat(tokenizer, query, history=history)
print(response)
result = json.dumps({"price": 12412}, ensure_ascii=False)
response, history = model.chat(tokenizer, result, history=history, role="observation")
print(response)
相关推荐
IT_陈寒15 小时前
React 18实战:7个被低估的Hooks技巧让你的开发效率提升50%
前端·人工智能·后端
数据智能老司机16 小时前
精通 Python 设计模式——分布式系统模式
python·设计模式·架构
逛逛GitHub16 小时前
飞书多维表“独立”了!功能强大的超出想象。
人工智能·github·产品
机器之心16 小时前
刚刚,DeepSeek-R1论文登上Nature封面,通讯作者梁文锋
人工智能·openai
数据智能老司机17 小时前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言
数据智能老司机17 小时前
精通 Python 设计模式——测试模式
python·设计模式·架构
数据智能老司机17 小时前
精通 Python 设计模式——性能模式
python·设计模式·架构
c8i17 小时前
drf初步梳理
python·django
每日AI新事件17 小时前
python的异步函数
python
这里有鱼汤18 小时前
miniQMT下载历史行情数据太慢怎么办?一招提速10倍!
前端·python