【chatglm3】(3):在AutoDL上,使用4090显卡,部署ChatGLM3API服务,并微调AdvertiseGen数据集,完成微调并测试成功!附视频

在AutoDL上,使用4090显卡,部署ChatGLM3API服务,并微调AdvertiseGen数据集,完成微调并测试成功!

其他chatgpt 和chatglm3 资料: blog.csdn.net/freewebsys/...

视频地址: www.bilibili.com/video/BV1zQ...

[video(video-DBCaDBO8-1699888999884)(type-bilibili)(url-player.bilibili.com/player.html...(image-https%3A%2F%2Fimg-blog.csdnimg.cn%2Fimg_convert%2Ff325fed9ee313f9176646e7284c7a675.jpeg)(title-%25E5%259C%25A8AutoDL%25E4%25B8%258A%25EF%25BC%258C%25E4%25BD%25BF%25E7%2594%25A84090%25E6%2598%25BE%25E5%258D%25A1%25EF%25BC%258C%25E9%2583%25A8%25E7%25BD%25B2ChatGLM3API%25E6%259C%258D%25E5%258A%25A1%25EF%25BC%258C%25E5%25B9%25B6%25E5%25BE%25AE%25E8%25B0%2583AdvertiseGen%25E6%2595%25B0%25E6%258D%25AE%25E9%259B%2586%25EF%25BC%258C%25E5%25AE%258C%25E6%2588%2590%25E5%25BE%25AE%25E8%25B0%2583%25E5%25B9%25B6%25E6%25B5%258B%25E8%25AF%2595%25E6%2588%2590%25E5%258A%259F%25EF%25BC%2581 "https://player.bilibili.com/player.html?aid=705971984)(image-https://img-blog.csdnimg.cn/img_convert/f325fed9ee313f9176646e7284c7a675.jpeg)(title-%E5%9C%A8AutoDL%E4%B8%8A%EF%BC%8C%E4%BD%BF%E7%94%A84090%E6%98%BE%E5%8D%A1%EF%BC%8C%E9%83%A8%E7%BD%B2ChatGLM3API%E6%9C%8D%E5%8A%A1%EF%BC%8C%E5%B9%B6%E5%BE%AE%E8%B0%83AdvertiseGen%E6%95%B0%E6%8D%AE%E9%9B%86%EF%BC%8C%E5%AE%8C%E6%88%90%E5%BE%AE%E8%B0%83%E5%B9%B6%E6%B5%8B%E8%AF%95%E6%88%90%E5%8A%9F%EF%BC%81"))]

1,显卡市场,租个显卡性价比最高!

www.autodl.com/ 创建完成可以使用 juypter 进入:

也可以监控服务器运行状况:

2,下载源代码,下载模型,启动服务

下载模型速度超级快 :

bash 复制代码
apt update && apt install git-lfs -y
git clone https://www.modelscope.cn/ZhipuAI/chatglm3-6b.git chatglm3-6b-models
Cloning into 'chatglm3-6b-models'...
remote: Enumerating objects: 101, done.
remote: Counting objects: 100% (101/101), done.
remote: Compressing objects: 100% (58/58), done.
remote: Total 101 (delta 42), reused 89 (delta 38), pack-reused 0
Receiving objects: 100% (101/101), 40.42 KiB | 1.84 MiB/s, done.
Resolving deltas: 100% (42/42), done.
Filtering content: 100% (8/8), 11.63 GiB | 203.56 MiB/s, done.

再下载github 项目: github.com/THUDM/ChatG... 或者上传代码

然后安装依赖库:

bash 复制代码
# 安装完成才可以启动:
pip3 install uvicorn fastapi loguru sse_starlette transformers sentencepiece
cd /root/ChatGLM3-main/openai_api_demo
python3 openai_api.py

启动成功,端口 8000 可以运行命令进行测试:

bash 复制代码
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
     "model": "chatglm3-6b",
     "messages": [{"role": "user", "content": "北京景点"}],
     "temperature": 0.7
   }' 

3,使用脚本进行token测试,速度50 tokens/s 速度挺快的

然后使用测试脚本进行 token 测试,修改的 fastcaht的测试脚本:

bash 复制代码
# coding=utf-8
"""

token测试工具:

python3 test_throughput.py
或者:
python3 test_throughput.py --api-address http://localhost:8000 --n-thread 20


"""
import argparse
import json

import requests
import threading
import time


def main():

    headers = {"User-Agent": "openai client", "Content-Type": "application/json"}
    ploads = {
        "model": args.model_name,
        "messages": [{"role": "user", "content": "生成一个50字的故事,内容随即生成。"}],
        "temperature": 1,
    }
    thread_api_addr = args.api_address

    def send_request(results, i):
        print(f"thread {i} goes to {thread_api_addr}")
        response = requests.post(
            thread_api_addr + "/v1/chat/completions",
            headers=headers,
            json=ploads,
            stream=False,
        )
        print(response.text)
        response_new_words = json.loads(response.text)["usage"]["completion_tokens"]
        print(f"=== Thread {i} ===, words: {response_new_words} ")
        results[i] = response_new_words

    # use N threads to prompt the backend
    tik = time.time()
    threads = []
    results = [None] * args.n_thread
    for i in range(args.n_thread):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        # time.sleep(0.5)
        threads.append(t)

    for t in threads:
        t.join()

    print(f"Time (POST): {time.time() - tik} s")
    n_words = sum(results)
    time_seconds = time.time() - tik
    print(
        f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, "
        f"throughput: {n_words / time_seconds} words/s."
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--api-address", type=str, default="http://localhost:8000")
    parser.add_argument("--model-name", type=str, default="chatglm3-6b")
    parser.add_argument("--n-thread", type=int, default=10)
    args = parser.parse_args()

    main()

测下下服务:

bash 复制代码
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.23.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:42:00.0 Off |                  Off |
| 30%   39C    P2              56W / 450W |  12429MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

速度特别快:

bash 复制代码
Time (POST): 22.13719415664673 s
Time (Completion): 22.137234687805176, n threads: 10, throughput: 51.22591037193507 words/s.

完全可以满足内部使用了。

3,下载微调数据,并进行模型训练

cloud.tsinghua.edu.cn/f/b3f119a00...

AdvertiseGen以商品网页的标签与文案的信息对应关系为基础构造

载处理好的 AdvertiseGen 数据集,将解压后的 AdvertiseGen 目录放到本目录下。

bash 复制代码
./scripts/format_advertise_gen.py --path "AdvertiseGen/train.json"

来下载和将数据集处理成上述格式。

微调模型

bash 复制代码
# 安装依赖库
pip install transformers==4.30.2 accelerate sentencepiece astunparse deepspeed

./scripts/finetune_pt.sh  # P-Tuning v2 微调

为了验证演示,调整参数,快速训练:

bash 复制代码
#! /usr/bin/env bash

set -ex

PRE_SEQ_LEN=128
LR=2e-2
NUM_GPUS=1
MAX_SOURCE_LEN=1024
MAX_TARGET_LEN=128
DEV_BATCH_SIZE=1
GRAD_ACCUMULARION_STEPS=8
MAX_STEP=10
SAVE_INTERVAL=10

DATESTR=`date +%Y%m%d-%H%M%S`
RUN_NAME=advertise_gen_pt

BASE_MODEL_PATH=/root/chatglm3-6b-models
DATASET_PATH=formatted_data/advertise_gen.jsonl
OUTPUT_DIR=output/${RUN_NAME}-${DATESTR}-${PRE_SEQ_LEN}-${LR}

mkdir -p $OUTPUT_DIR

torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
    --train_format input-output \
    --train_file $DATASET_PATH \
    --preprocessing_num_workers 1 \
    --model_name_or_path $BASE_MODEL_PATH \
    --output_dir $OUTPUT_DIR \
    --max_source_length $MAX_SOURCE_LEN \
    --max_target_length $MAX_TARGET_LEN \
    --per_device_train_batch_size $DEV_BATCH_SIZE \
    --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \
    --max_steps $MAX_STEP \
    --logging_steps 1 \
    --save_steps $SAVE_INTERVAL \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN 2>&1 | tee ${OUTPUT_DIR}/train.log

4,推理验证,使用命令行的方式

对于输入输出格式的微调,可使用 inference.py 进行基本的推理验证。

bash 复制代码
python inference.py \
    --model /root/chatglm3-6b-models \
    --pt-checkpoint "output/advertise_gen_pt-20231113-222811-128-2e-2" 
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.32it/s]
Some weights of ChatGLMForConditionalGeneration were not initialized from the model checkpoint at /root/chatglm3-6b-models and are newly initialized: ['transformer.prefix_encoder.embedding.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Prompt:类型#裙*材质#网纱*颜色#粉红色*裙袖长#短袖*裙领型#圆领
Response: *裙下摆流苏设计,轻轻松松演绎甜美可爱风。这条裙子真的太仙了,粉红色网纱,在阳光的照耀下,真的太仙了,仿佛置身于童话故事中。短袖的设计,既不会过于露肤,也能展示出修长的身材线条。裙摆处流苏的设计,让整个裙子的层次感更加明显,给人一种飘逸的感觉。

5,总结

在 4090 上面运行 chatgm3 速度还是挺快的。 然后找到官方的 AdvertiseGen 数据集,就是对商品的标签和文案的匹配数据。 然后根据内容进行训练,然后再输入相关类似的标签,就可以自动生成广告文案了。 这个是AIGC的挺好的落地场景。

可以在 4090 上完成训练,并验证成功了!

相关推荐
手握风云-20 分钟前
零基础Java第十六期:抽象类接口(二)
数据结构·算法
笨小古1 小时前
路径规划——RRT-Connect算法
算法·路径规划·导航
<但凡.1 小时前
编程之路,从0开始:知识补充篇
c语言·数据结构·算法
f狐0狸x2 小时前
【数据结构副本篇】顺序表 链表OJ
c语言·数据结构·算法·链表
paopaokaka_luck2 小时前
基于Spring Boot+Vue的多媒体素材管理系统的设计与实现
java·数据库·vue.js·spring boot·后端·算法
视觉小萌新2 小时前
VScode+opencv——关于opencv多张图片拼接成一张图片的算法
vscode·opencv·算法
2的n次方_2 小时前
二维费用背包问题
java·算法·动态规划
simple_ssn3 小时前
【C语言刷力扣】1502.判断能否形成等差数列
c语言·算法·leetcode
寂静山林3 小时前
UVa 11855 Buzzwords
算法
Curry_Math3 小时前
LeetCode 热题100之技巧关卡
算法·leetcode