ChatGLM2-6B-PT
This project implements fine-tuning of the ChatGLM2-6B model with P-Tuning v2. P-Tuning v2 reduces the number of parameters that need fine-tuning to 0.1% of the full model, and combined with model quantization and gradient checkpointing it can run with as little as 7GB of GPU memory.
The following uses the ADGEN (advertisement generation) dataset as an example to show how to use the code.
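As a quick illustration of the "0.1% trainable parameters" claim, the snippet below is a minimal sketch (not part of the original notebook). It assumes a ChatGLM2-6B model has already been loaded with pre_seq_len set, as done later in this notebook, so that the implementation marks every weight except the prefix encoder as requires_grad=False.
python
# Minimal sketch: check the fraction of trainable parameters under P-Tuning v2.
# Assumes `model` was loaded with a config that sets pre_seq_len, so only the
# prefix encoder remains trainable (requires_grad=True).
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({trainable / total:.2%})")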
In [11]:
!pip install -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
bash
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple/
Requirement already satisfied: protobuf in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 1)) (5.26.1)
Requirement already satisfied: transformers==4.30.2 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (4.30.2)
Requirement already satisfied: cpm_kernels in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 3)) (1.0.11)
Requirement already satisfied: torch>=2.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (2.2.2)
Requirement already satisfied: gradio in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 5)) (3.40.0)
Requirement already satisfied: mdtex2html in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 6)) (1.3.0)
Requirement already satisfied: sentencepiece in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 7)) (0.2.0)
Requirement already satisfied: accelerate in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 8)) (0.28.0)
Requirement already satisfied: sse-starlette in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 9)) (2.0.0)
Requirement already satisfied: filelock in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (3.13.3)
Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (0.22.2)
Requirement already satisfied: numpy>=1.17 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (1.26.4)
Requirement already satisfied: packaging>=20.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (24.0)
Requirement already satisfied: pyyaml>=5.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (6.0.1)
Requirement already satisfied: regex!=2019.12.17 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (2023.12.25)
Requirement already satisfied: requests in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (2.31.0)
Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (0.13.3)
Requirement already satisfied: safetensors>=0.3.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (0.4.2)
Requirement already satisfied: tqdm>=4.27 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (4.66.2)
Requirement already satisfied: typing-extensions>=4.8.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from torch>=2.0->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (4.10.0)
Requirement already satisfied: sympy in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from torch>=2.0->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (1.12)
Requirement already satisfied: networkx in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from torch>=2.0->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (3.2.1)
Requirement already satisfied: jinja2 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from torch>=2.0->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (3.1.3)
Requirement already satisfied: fsspec in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from torch>=2.0->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (2024.2.0)
...
Requirement already satisfied: referencing>=0.28.4 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 5)) (0.34.0)
Requirement already satisfied: rpds-py>=0.7.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 5)) (0.18.0)
Requirement already satisfied: uc-micro-py in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from linkify-it-py<3,>=1->markdown-it-py[linkify]>=2.0.0->gradio->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 5)) (1.0.3)
Requirement already satisfied: six>=1.5 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 5)) (1.16.0)
In [13]:
# Besides the dependencies of ChatGLM2-6B itself, running the fine-tuning also requires the following packages
!pip install rouge_chinese nltk jieba datasets transformers[torch] -i https://pypi.douban.com/simple/
bash
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple/
Requirement already satisfied: rouge_chinese in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (1.0.3)
Requirement already satisfied: nltk in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (3.8.1)
Requirement already satisfied: jieba in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (0.42.1)
Requirement already satisfied: datasets in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (2.18.0)
Requirement already satisfied: transformers[torch] in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (4.30.2)
Requirement already satisfied: six in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from rouge_chinese) (1.16.0)
Requirement already satisfied: click in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from nltk) (8.1.7)
Requirement already satisfied: joblib in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from nltk) (1.3.2)
Requirement already satisfied: regex>=2021.8.3 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from nltk) (2023.12.25)
Requirement already satisfied: tqdm in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from nltk) (4.66.2)
Requirement already satisfied: filelock in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (3.13.3)
Requirement already satisfied: numpy>=1.17 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (1.26.4)
Requirement already satisfied: pyarrow>=12.0.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (15.0.2)
Requirement already satisfied: pyarrow-hotfix in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (0.6)
Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (0.3.8)
Requirement already satisfied: pandas in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (2.2.1)
Requirement already satisfied: requests>=2.19.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (2.31.0)
Requirement already satisfied: xxhash in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (3.4.1)
Requirement already satisfied: multiprocess in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (0.70.16)
Requirement already satisfied: fsspec<=2024.2.0,>=2023.1.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets) (2024.2.0)
Requirement already satisfied: aiohttp in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (3.9.3)
Requirement already satisfied: huggingface-hub>=0.19.4 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (0.22.2)
Requirement already satisfied: packaging in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (24.0)
Requirement already satisfied: pyyaml>=5.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (6.0.1)
...
Requirement already satisfied: pytz>=2020.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from pandas->datasets) (2024.1)
Requirement already satisfied: tzdata>=2022.7 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from pandas->datasets) (2024.1)
Requirement already satisfied: MarkupSafe>=2.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from jinja2->torch!=1.12.0,>=1.9->transformers[torch]) (2.1.5)
Requirement already satisfied: mpmath>=0.19 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from sympy->torch!=1.12.0,>=1.9->transformers[torch]) (1.3.0)
Usage
Download the dataset
The ADGEN task is to generate a piece of advertising copy (summary) from the input keywords (content).
{
"content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
"summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
Download the processed ADGEN dataset from Google Drive or Tsinghua Cloud, and put the extracted AdvertiseGen directory under the ptuning directory.
In this project the ADGEN dataset is already mounted by default.
In [2]:
# The checkpoint files produced by fine-tuning are fairly large. To avoid taking up space in the project directory, we move the working directory into the temp directory for the remaining steps.
!cp -r /mnt/e/AI-lab/ChatGLM2-6B/ptuning /mnt/e/AI-lab/ChatGLM2-6B/temp
In [3]:
python
import os
# Set the directory path to switch to
new_dir = '/mnt/e/AI-lab/ChatGLM2-6B/temp/ptuning'
# Change the current working directory
os.chdir(new_dir)
# Print the current working directory to confirm the switch succeeded
print(os.getcwd())
/mnt/e/AI-lab/ChatGLM2-6B/temp/ptuning
In [4]:
# Copy the ADGEN dataset to the working directory
!cp -r /home/mw/input/adgen9371 AdvertiseGen
In [5]:
# Inspect the dataset
!ls -alh AdvertiseGen
bash
total 52M
drwxrwxrwx 1 ai001 ai001 4.0K Apr 3 21:16 .
drwxrwxrwx 1 ai001 ai001 4.0K Apr 4 17:34 ..
-rwxrwxrwx 1 ai001 ai001 487K Apr 4 17:34 dev.json
-rwxrwxrwx 1 ai001 ai001 52M Apr 4 17:34 train.json
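Each line of train.json and dev.json holds one record with a content field and a summary field. The cell below is a small sketch (not part of the original notebook) that prints the first training example; it assumes the files use the one-record-per-line layout of the ADGEN release.
python
import json

# Peek at the first training record; each line is assumed to be a standalone
# JSON object with "content" (input tags) and "summary" (target ad copy).
with open("AdvertiseGen/train.json", encoding="utf-8") as f:
    example = json.loads(f.readline())
print(example["content"])
print(example["summary"])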
Training
P-Tuning v2
PRE_SEQ_LEN and LR are the soft prompt length and the training learning rate, respectively; both can be tuned for best results. The P-Tuning v2 method freezes all of the model's original parameters. The quantization level of the original model can be changed with quantization_bit; if this option is omitted, the model is loaded in FP16 precision.
Under the default configuration quantization_bit=4, per_device_train_batch_size=1, gradient_accumulation_steps=16, the INT4 model parameters are frozen, and one training step runs 16 accumulated forward/backward passes with a batch size of 1, equivalent to a total batch size of 16; in this setting as little as 6.7GB of GPU memory is needed. To improve training efficiency at the same effective batch size, you can increase per_device_train_batch_size while keeping the product of the two unchanged, at the cost of more GPU memory; adjust it to your actual hardware.
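The arithmetic behind the "equivalent total batch size of 16" can be spelled out with a tiny sketch (the values mirror the default configuration described above; they are not read from any config file):
python
# Effective batch size = per-device batch size * gradient accumulation steps * number of GPUs.
per_device_train_batch_size = 1
gradient_accumulation_steps = 16
num_gpus = 1
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch_size)  # 16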
Finetune
If you want to do full-parameter fine-tuning, install Deepspeed first and then run the following command:
bash ds_train_finetune.sh
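In this notebook environment the full-parameter run would look roughly like the cell below. This is only a sketch: it assumes DeepSpeed is not installed yet, and ds_train_finetune.sh (with its DeepSpeed config) ships with the repository's ptuning directory and is used unmodified.
python
# Full-parameter fine-tuning (sketch, not executed in this notebook):
# install DeepSpeed, then launch the script provided by the repository.
!pip install deepspeed -i https://pypi.tuna.tsinghua.edu.cn/simple/
!bash ds_train_finetune.sh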
As an example, we fine-tune with the P-Tuning v2 method using per_device_train_batch_size=1 and gradient_accumulation_steps=16 (quantization_bit=4 can additionally be enabled to reduce memory usage; in the cell below it is left commented out).
In [14]:
python
# P-tuning v2
import os
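# Restrict this training run to the second GPU (index 1)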
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
PRE_SEQ_LEN=128
LR=2e-2
NUM_GPUS=1
!torchrun --standalone --nnodes=1 --nproc-per-node=1 main.py \
--do_train \
--train_file AdvertiseGen/train.json \
--validation_file AdvertiseGen/dev.json \
--preprocessing_num_workers 10 \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path /mnt/e/AI-lab/ChatGLM2-6B/ \
--output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 128 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 2e-2 \
--pre_seq_len 128 \
--ddp_find_unused_parameters False
#--quantization_bit 4
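Note that inside a `!` shell command, IPython expands `$PRE_SEQ_LEN` and `$LR` from the Python variables defined above, so the output directory resolves to output/adgen-chatglm2-6b-pt-128-0.02, the same path the evaluation cells below load checkpoints from. A quick check (not part of the original notebook):
python
# How IPython renders the Python variables into the shell command above:
# LR = 2e-2 prints as 0.02, so the run writes to output/adgen-chatglm2-6b-pt-128-0.02.
print(f"output/adgen-chatglm2-6b-pt-{PRE_SEQ_LEN}-{LR}")
# -> output/adgen-chatglm2-6b-pt-128-0.02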
In [17]:
python
# Load the model
model_path = "/mnt/e/ai-lab/ChatGLM2-6B"
from transformers import AutoTokenizer, AutoModel
from utils import load_model_on_gpus
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = load_model_on_gpus("/mnt/e/ai-lab/ChatGLM2-6B", num_gpus=2)
model = model.eval()
python
# Print the model output as Markdown
from IPython.display import display, Markdown, clear_output
def display_answer(model, query, history=[]):
    for response, history in model.stream_chat(
            tokenizer, query, history=history):
        clear_output(wait=True)
        display(Markdown(response))
    return history
In [18]:
python
# Before fine-tuning
#model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
#model = model.eval()
display_answer(model, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞")
上衣材质为牛仔布,颜色为白色,风格为简约,图案为刺绣,衣款式为外套,衣样式为破洞。
Out[18]:
[('类型#上衣\\*材质#牛仔布\\*颜色#白色\\*风格#简约\\*图案#刺绣\\*衣样式#外套\\*衣款式#破洞',
'上衣材质为牛仔布,颜色为白色,风格为简约,图案为刺绣,衣款式为外套,衣样式为破洞。')]
In [22]:
python
# After fine-tuning
import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
model_path = "/mnt/e/ai-lab/ChatGLM2-6B"
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
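# Load the prefix-encoder (soft prompt) weights from the P-Tuning v2 checkpoint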
prefix_state_dict = torch.load(os.path.join("/mnt/e/AI-lab/ChatGLM2-6B/temp/ptuning/output/adgen-chatglm2-6b-pt-128-0.02/checkpoint-1000", "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
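# Run the base model in FP16 on the GPU, but keep the prefix encoder in FP32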
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()
# Markdown display helpers (only needed by the commented-out display_answer call below)
from IPython.display import display, Markdown, clear_output
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
response, history = model.chat(tokenizer, "类型#上衣*颜色#黑白*风格#简约*风格#休闲*图案#条纹*衣样式#风衣*衣样式#外套", history=[])
print(response)
#display_answer(model, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞")
response, history = model.chat(tokenizer, "风衣有什么特征呢", history=[])
print(response)
response, history = model.chat(tokenizer, "日常休闲一般穿什么风格的衣服好呢?", history=[])
print(response)
Loading checkpoint shards: 100%
7/7 [02:39<00:00, 23.82s/it]
bash
Some weights of ChatGLMForConditionalGeneration were not initialized from the model checkpoint at /mnt/e/ai-lab/ChatGLM2-6B and are newly initialized: ['transformer.prefix_encoder.embedding.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
简约的条纹风衣,在黑白两色的搭配下,看起来非常的干练利落。经典的条纹元素,带来一种简约休闲的时尚感,将女性优雅的气质完美展现出来。
这款风衣是经典的风衣款式,采用优质的面料制作,质感舒适。在设计上,风衣采用经典的翻领设计,修饰颈部曲线,让你看起来更加优雅。风衣前襟采用斜线处理,整体看起来更加有设计感。
休闲风格是生活中不可或缺的,无论是在职场还是日常休闲,它都是一种很受欢迎的时尚元素。对于休闲的衣装来说,它一般都具有很亲和的气质,可以搭配出各种不同的风格。像这款休闲的连衣裙,它采用柔软的面料,穿着舒适亲肤,而且可以轻松的搭配出各种不同的风格。