基于 ChatGLM 和 LangChain 快速搭建本地知识库

Step1 免费试用服务

阿里云免费的 GPU 服务器

GPU 选择 A10 或者 V100 都行。

ubuntu 系统环境镜像地址:

地区 镜像
杭州地域 dsw-registry-vpc.cn-hangzhou.cr.aliyuncs.com/cloud-dsw/eas-service:aigc-torch113-cu117-ubuntu22.04-v0.2.1_accelerated
北京地域 dsw-registry-vpc.cn-beijing.cr.aliyuncs.com/cloud-dsw/eas-service:aigc-torch113-cu117-ubuntu22.04-v0.2.1_accelerated
上海地域 dsw-registry-vpc.cn-shanghai.cr.aliyuncs.com/cloud-dsw/eas-service:aigc-torch113-cu117-ubuntu22.04-v0.2.1_accelerated
深圳地域 dsw-registry-vpc.cn-shenzhen.cr.aliyuncs.com/cloud-dsw/eas-service:aigc-torch113-cu117-ubuntu22.04-v0.2.1_accelerated

Step2: 部署 ChatGLM2-6B

sh 复制代码
git clone https://github.com/THUDM/ChatGLM2-6B.git

cd ChatGLM2-6B

pip install -r requirements.txt

# 根目录
git clone https://huggingface.co/THUDM/chatglm2-6b

修改 web_demo.py 文件

py 复制代码
# 改为本地模型路径
tokenizer = AutoTokenizer.from_pretrained("/mnt/workspace/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("/mnt/workspace/chatglm2-6b", trust_remote_code=True).cuda()

# 有需要的话在这里修改本地访问
demo.queue().launch(share=True, inbrowser=True, server_name='0.0.0.0', server_port=9527)

python web_demo.py 启动。

Step3: 基于 P-Tuning 微调 ChatGLM2-6B

文档 ChatGLM2-6B-PT

sh 复制代码
# 安装依赖
pip install rouge_chinese nltk jieba datasets
# 有报错的话
pip install transformers==4.27.1

准备数据集

官方文档里面推荐从 Tsinghua Cloud 下载处理好的数据集,将解压后的 AdvertiseGen 文件夹中的 train.jsondev.json文件放到 root/ptuning 目录下。

不过我们就是测试,所以只需要简单的一些数据进行测试就好。ptuning 文件夹新建 train.jsondev.json

实际使用时候这里需要替换成大量训练数据。

ps: 没看错,别修改,就是这样的 json 格式。

json 复制代码
{"content": "你好,你是谁", "summary": "你好,我是kanelogger的小助理二狗。"}
{"content": "你是谁", "summary": "你好,我是kanelogger的小助理二狗。"}
{"content": "kanelogger是谁", "summary": "kanelogger是麻瓜。"}
{"content": "介绍下kanelogger", "summary": "kanelogger是麻瓜。"}
{"content": "kanelogger", "summary": "kanelogger是麻瓜。"}

调整参数

修改 ptuning/train.sh 文件

  1. train_file/validation_file 改为刚刚新建的数据集路径
  2. max_source_length/max_target_length 是匹配数据集中的最大输入和输出的长度。
  3. --model_name_or_path 将模型路径改为本地的模型路径。

PS:

  • prompt_column/response_column 是 JSON 文件中输入文本和输出文本对应的 KEY
  • PRE_SEQ_LEN 是 soft prompt 长度
  • LR 是 训练的学习率,可以进行调节以取得最佳的效果。
  • quantization_bit 可通过这个配置调整改变原始模型的量化等级,不加此选项则为 FP16 精度加载。
sh 复制代码
# PRE_SEQ_LEN=128
PRE_SEQ_LEN=32
LR=2e-2
NUM_GPUS=1

torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
    --do_train \
    # --train_file AdvertiseGen/train.json \
    --train_file train.json \
    # --validation_file AdvertiseGen/dev.json \
    --validation_file dev.json \
    --preprocessing_num_workers 10 \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    # --model_name_or_path THUDM/chatglm2-6b \
    --model_name_or_path /mnt/workspace/chatglm2-6b \
    --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    # --max_source_length 64 \
    --max_source_length 128 \
    --max_target_length 128 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    # --quantization_bit 4

修改 evaluate.sh 文件

sh 复制代码
# PRE_SEQ_LEN=128
PRE_SEQ_LEN=32
# CHECKPOINT=adgen-chatglm2-6b-pt-128-2e-2
CHECKPOINT=adgen-chatglm2-6b-pt-32-2e-2
STEP=3000
NUM_GPUS=1

torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
    --do_predict \
    # --validation_file AdvertiseGen/dev.json \
    --validation_file dev.json \
    # --test_file AdvertiseGen/dev.json \
    --test_file dev.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
    # --model_name_or_path THUDM/chatglm2-6b \
    --model_name_or_path /mnt/workspace/chatglm2-6b \
    --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
    --output_dir ./output/$CHECKPOINT \
    --overwrite_output_dir \
    # --max_source_length 64 \
    # --max_target_length 64 \
    --max_source_length 128 \
    --max_target_length 128 \
    --per_device_eval_batch_size 1 \
    --predict_with_generate \
    --pre_seq_len $PRE_SEQ_LEN \
    # --quantization_bit 4
  1. 训练。当前文件夹下打开 terminal bash train.sh。这个时间很长,大概 40 -60 min。
  2. 继续执行 bash evaluate.sh 进行推理测试

生成的结果保存在 ./output/adgen-chatglm2-6b-pt-32-2e-2/generated_predictions.txt

如果不满意调整之前PS里的参数,再次训练。

PS: /mnt/workspace/ChatGLM2-6B/ptuning/output/adgen-chatglm2-6b-pt-32-2e-2/checkpoint-3000 以及 pytorch_model.bin 会在后续开发的时候会用到,可以特别记一下。

部署微调后的模型

修改 web_demo.sh。也是修改模型地址和 checkpoint 地址。

sh 复制代码
# PRE_SEQ_LEN=128
PRE_SEQ_LEN=32

CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
    # --model_name_or_path THUDM/chatglm2-6b \
    --model_name_or_path /mnt/workspace/chatglm2-6b \
    # --ptuning_checkpoint output/adgen-chatglm2-6b-pt-128-2e-2/checkpoint-3000 \
    --ptuning_checkpoint output/adgen-chatglm2-6b-pt-32-2e-2/checkpoint-3000 \
    --pre_seq_len $PRE_SEQ_LEN

执行 bash web_demo.sh

Step4: 直接用 langchain-ChatGLM 构建知识库

根目录

sh 复制代码
git clone https://github.com/imClumsyPanda/langchain-ChatGLM.git
cd langchain-ChatGLM
pip install -r requirements.txt
pip install --upgrade protobuf==3.19.6

# 安装 git lfs
git lfs install

# 根目录下载 Embedding 模型
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese text2vec

修改文件 configs/model_config.py

py 复制代码
embedding_model_dict = {
    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
    "ernie-base": "nghuyong/ernie-3.0-base-zh",
    "text2vec-base": "shibing624/text2vec-base-chinese",
    # 修改刚刚下载的路径
    "text2vec": "/mnt/workspace/text2vec",
    "m3e-small": "moka-ai/m3e-small",
    "m3e-base": "moka-ai/m3e-base",
}

llm_model_dict = {
    ...
    "chatglm2-6b": {
        "name": "chatglm2-6b",
        # 本地模型路径
        "pretrained_model_name": "/mnt/workspace/chatglm2-6b",
        "local_model_path": None,
        "provides": "ChatGLM"
    },
    ...
}

# LLM 名称改成 chatglm2-6b
LLM_MODEL = "chatglm2-6b"

python webui.py 启动网页 或者 python api.py 启动 api。项目内置了 vue 搭建的页面,在 views 里面。

参考资料

相关推荐
极客代码2 分钟前
【Python TensorFlow】入门到精通
开发语言·人工智能·python·深度学习·tensorflow
义小深4 分钟前
TensorFlow|咖啡豆识别
人工智能·python·tensorflow
Tianyanxiao44 分钟前
如何利用探商宝精准营销,抓住行业机遇——以AI技术与大数据推动企业信息精准筛选
大数据·人工智能·科技·数据分析·深度优先·零售
撞南墙者1 小时前
OpenCV自学系列(1)——简介和GUI特征操作
人工智能·opencv·计算机视觉
OCR_wintone4211 小时前
易泊车牌识别相机,助力智慧工地建设
人工智能·数码相机·ocr
王哈哈^_^1 小时前
【数据集】【YOLO】【VOC】目标检测数据集,查找数据集,yolo目标检测算法详细实战训练步骤!
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·pyqt
一者仁心1 小时前
【AI技术】PaddleSpeech
人工智能
是瑶瑶子啦1 小时前
【深度学习】论文笔记:空间变换网络(Spatial Transformer Networks)
论文阅读·人工智能·深度学习·视觉检测·空间变换
EasyCVR2 小时前
萤石设备视频接入平台EasyCVR多品牌摄像机视频平台海康ehome平台(ISUP)接入EasyCVR不在线如何排查?
运维·服务器·网络·人工智能·ffmpeg·音视频
柳鲲鹏2 小时前
OpenCV视频防抖源码及编译脚本
人工智能·opencv·计算机视觉