一、基于PaddleNLP的ChatGLM-6B模型lora微调实现Data-To-Text 硬约束下的受控文本生成
- 比赛地址:www.datafountain.cn/competition...
- 另外:该比赛已进入决赛阶段,仅供大家学习借鉴,打比赛时间上已经失去意义。
1.大赛背景
全国大数据与计算智能挑战赛是由国防科技大学系统工程学院大数据与决策实验室组织的年度赛事活动,旨在深入挖掘大数据应用实践中亟需破解的能力生成难题,选拔汇聚数据领域优势团队,促进大数据领域的技术创新和面向需求的成果生成,推动形成"集智众筹、联合攻关、共享共用"的研建用一体迭代演进创新模式。 本届大赛以"发榜挑战、集智攻关"为主题,面向全国大数据与计算智能领域的相关单位,将围绕自然语言处理、图像检测识别、时空数据分析、知识建模分析等前沿技术难点开设赛道,以"揭榜打擂"的形式组织创研竞赛,通过线上打榜与现场评审相结合的方式决出优势团队,进一步创新大数据管理与应用模式,推动大数据与计算智能领域技术发展与生态构建。 热烈欢迎全国各工业部门、科研院所、各高校及民营企业的业内优势团队踊跃参赛揭榜!
了解更多赛事信息 2023全国大数据与计算智能挑战赛
2.赛题介绍
• 赛题名称 Data-To-Text:硬约束下的受控文本生成
• 赛题背景 受控文本生成,作为人工智能技术在自然语言生成方面的一种应用,越来越受到学界和各行业的关注,并在体育、财经、气象等领域取得了一定的成果。目前受控文本生成的数据集和技术方面主要依赖以英语为主体、具有显著领域特征的公开数据集,在中文的军事相关的文本上难以直接落地应用。另一方面,现有受控文本生成技术仍存在遗漏关键信息、语句表达不通顺、错误推断结论等问题,导致文本可读性、可信性不够。
截至目前,已经通过工作实践,积累了大量军事领域的中文结构化数据和文本数据,并经过整理和标注,形成可供学习训练的数据集。期望通过参赛者的聪明才智和共同努力,探索构建从结构化数据到非结构化文本的模型算法,推动Data to Text的技术进步和实践运用。
• 赛题任务 考虑数据集的开源性和赛题的难度梯度,区分互联网初赛和线下复赛2个环节,设计给定关键词的文本生成技术攻关和给定表格的文本生成技术攻关2个课题,二者在技术上同属一域、内容上一脉相承、难度上梯次递增。 初赛任务需要在一组给定顺序的关键词的情况下,生成一段包含所有关键词的文本,生成文本需要具有领域相关性和表达流畅性。复赛任务是在给定一个表格的情况下,生成一段包含表格关键信息的文本,生成文本必须忠实于表格,且满足语法使用正确、表达简洁清晰、语义自然连贯的要求。
硬约束下的受控文本生成需要解决以下难题:
(1)语言复杂性:参赛者需要考虑如何正确使用中文的语法,高效、准确地表达复杂语义。 (2)中文词汇的多义性和歧义性:在不同的语境下中文词汇表达的意思可能不同,参赛者需要对上下文进行更精确的建模以正确理解和表达文本的含义。 (3)受控机制的设计:生成的文本需要具有领域相关性,同时要保持句子的语义流畅性。 (4)关键词的覆盖率和顺序:生成的文本需要包含所有给定的关键词,并且保证关键词的顺序不变。 (5)可参考的资料和文献较少:中文硬约束文本生成的相关工作较少,参赛者需要充分理解试题并设计针对性的解决方案。
3.数据简介
初赛数据集来源于公开军事新闻网站(如新浪军事、中国军网和环球网等)。复赛数据集来源于特定的结构化数据在线采集平台和相应的综合报告。
4.数据说明
初赛数据集共15000条数据,其中训练集10400条数据,验证集2600条数据,测试集2000条数据。初赛任务为给定关键词,生成领域相关的文本。要求生成的文本中包含所有的关键词,并且关键词按顺序出现。在数据集中,关键词(key_words)为给定信息,文本(text)为期望的输出结果。
共包含3个文件:
训练集(train_set.json)
验证集(valid_set.json)
测试集(test_set.json)
• 训练集和验证集:
训练集(train_set.json)共10400条数据,验证集(valid_set.json)共2600条数据。每条的数据格式为:
json
{
"id": 1,
"text": "最近,因为某知名坦克博主的一篇微博,我国的99A型坦克又被推到了风口浪尖,其中最主要的问题就是,原来传的神乎其神的99A坦克的俯角,似乎并没有很多人想象的那么大,那么今天我们不妨来谈一谈坦克俯角的那些事。",
"key_words": ["最近", "坦克", "一篇", "微博", "99A型坦克", "99A坦克", "今天","坦克", "那些事"]
}
二、环境设置
1.升级pip
避免pip安装失败
python
%%capture
#要更新pip要不容易安装失败
!pip install --upgrade pip
2.安装PaddleNLP开发版
python
# https://gitee.com/livingbody/PaddleNLP为我同步的最新的PaddleNLP,官方gitee的有一些旧
!git clone https://gitee.com/livingbody/PaddleNLP -b develop --depth=1
python
%%capture
!pip install -e PaddleNLP/
三、数据转换
1.JSONL简介
JSONL(JSON Lines)是一种文本格式,每行包含一个独立的的数据结构,通常用于大规模数据存储和传输。因此,将JSON转换为JSONL相对简单,只需要将每个JSON对象放在单独的行中即可。
python
import json
# JSON数据
json_data = [
{"name": "Alice", "age": 25},
{"name": "Bob", "age": 30},
{"name": "Charlie", "age": 35}
]
# 打开输出文件
with open("output.jsonl", "w") as f:
# 逐行写入JSON对象
for data in json_data:
f.write(json.dumps(data) + "\n")
在上面的示例中,我们使用Python的内置json模块将JSON对象转换为JSON字符串,然后在输出文件中逐行写入每个JSON对象。注意,每个JSON对象都需要以独立的行为基础,并以换行符结尾。这样,就可以将JSON转换为JSONL格式了。
python
!cat output.jsonl
json
{"name": "Alice", "age": 25}
{"name": "Bob", "age": 30}
{"name": "Charlie", "age": 35}
2.数据格式转换
- train \ dev 转jsonl格式
python
!head data/data222851/test_set.json
json
[
{
"id": 1,
"key_words": [
"拉脱维亚",
"前苏联",
"部署",
"核武器",
"一个",
"核反应堆",
python
# 训练集
import json
from pprint import pprint
def json2jsonl(source_file, target_file):
# 读取 JSON 文件
with open(source_file, 'r', encoding='utf-8') as f:
data = json.load(f)
pprint(len(data))
# 写入文件
f = open(target_file, 'w', encoding="utf-8")
result_list = []
for item in data:
temp = dict()
temp['content'] = ' '.join(item['key_words'])
temp['summary'] = item['text']
f.write(json.dumps(temp, ensure_ascii=False) + "\n")
result_list.append(temp)
f.close()
python
!mkdir converted
json2jsonl('data/data222851/valid_set.json', 'converted/dev.json' )
json2jsonl('data/data222851/train_set.json', 'converted/train.json' )
python
!head -n3 converted/dev.json
css
{"content": "贵飞 2017年 山鹰 一架 后机身 山鹰 后期 批次 山鹰 歼7L 天线", "summary": "据贵飞消息,2017年批产山鹰最后一架份后机身已经交付总装线,至此山鹰系列已生产到了第七批次的第22架,从贵飞发布的照片可以看到较后期批次的山鹰垂尾都装了类似歼7L的数据链天线。"}
{"content": "分毫 火箭军 导弹旅 邱黄成 一个 22年 邱黄成 多次 导弹发射 两次 三等功 火箭军 去年6月 40岁 一名 新时代", "summary": "但凡是跟打仗有关的事,必须一丝不苟,分毫不能差。这是火箭军某导弹旅任务规划队原队长邱黄成对待工作的一贯态度。在战友们眼中,他就是这样一个始终专注于练兵备战的人。入伍22年来,邱黄成先后多次参与导弹发射任务,两次荣立三等功,被评为火箭军基层好主官标兵。去年6月,他因突发颅内淋巴瘤倒在了战位上,用年仅40岁的生命,诠释了一名新时代基层干部的奉献和担当。"}
{"content": "读者网 郭广杰 徐伟 陈聪 尹志伟 2月3日 浙江 安徽 演练 春节 武警部队 节前 演练 突发事件 演练 总指挥组", "summary": "军事读者网讯郭广杰徐伟陈聪尹志伟摄影报道:2月3日,浙江、安徽武警反恐特战队员实装进行了反恐任务演练。春节来临之际,武警部队有针对性地进行了节前处突反恐演练,为贴近实战需求,在演练中设置突发事件,在演练总指挥组的导调下不断变换新情况"}
3.chatglm-6b模型预置
- 使用数据集复制到临时存储模型目录.paddlenlp/models/THUDM/chatglm-6b/,避免等待长耗时的下载
python
!mkdir .paddlenlp/models/THUDM/chatglm-6b/ -p
!cp data/data217141/* .paddlenlp/models/THUDM/chatglm-6b/
四、微调
ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 General Language Model (GLM) 架构,具有 62 亿参数。ChatGLM-6B 使用了和 ChatGLM 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。
1.参数说明
model_name_or_path
: 预训练模型内置名称或者模型所在目录,默认为THUDM/chatglm-6b
。task_name_or_path
: 数据集存储目录。max_steps
: 模型训练步数。learning_rate
: 参数更新的学习率。warmup_steps
: 学习率热启的步数。eval_steps
: 模型评估的间隔步数。logging_steps
: 训练日志打印的间隔步数。save_steps
: 模型参数保存的间隔步数。save_total_limit
: 模型 checkpoint 保存的份数。output_dir
: 模型参数保存目录。src_length
: 上下文的最大输入长度,默认为128.tgt_length
: 生成文本的最大长度,默认为160.gradient_accumulation_steps
: 模型参数梯度累积的步数,可用于扩大 batch size。实际的 batch_size = per_device_train_batch_size * gradient_accumulation_steps。fp16
: 使用 float16 精度进行模型训练和推理。fp16_opt_level
: float16 精度训练模式,O2
表示纯 float16 训练。recompute
: 使用重计算策略,开启后可节省训练显存。do_train
: 是否训练模型。do_eval
: 是否评估模型。tensor_parallel_degree
: 模型并行数量。eval_with_do_generation
: 在评估的时候是否调用model.generate,默认为False。
2.lora微调
python
!python ~/PaddleNLP/llm/chatglm/finetune_generation.py \
--output_dir ./checkpoints/chatglm-6b \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 2 \
--per_device_eval_batch_size 8 \
--model_name_or_path THUDM/chatglm-6b \
--task_name_or_path converted/ \
--num_train_epochs 2 \
--learning_rate 3e-4 \
--warmup_steps 30 \
--logging_steps 1 \
--evaluation_strategy epoch \
--save_strategy epoch \
--src_length 1024 \
--tgt_length 1024 \
--fp16 \
--fp16_opt_level O2 \
--do_train \
--do_eval \
--disable_tqdm True \
--load_best_model_at_end True \
--metric_for_best_model accuracy \
--eval_with_do_generation False \
--recompute \
--save_total_limit 1 \
--overwrite_output_dir \
--lora True \
--lora_rank 8
yaml
r_second: 1.09, ppl: 3.9061367028048903, epoch: 1.1923
[2023-08-02 20:31:35,256] [ INFO] - loss: 1.37255859, learning_rate: 0.0001233, global_step: 1551, interval_runtime: 0.8586, interval_samples_per_second: 9.318, interval_steps_per_second: 1.165, ppl: 3.945432536555322, epoch: 1.1931
[2023-08-02 20:31:36,232] [ INFO] - loss: 1.80371094, learning_rate: 0.0001232, global_step: 1552, interval_runtime: 0.9752, interval_samples_per_second: 8.204, interval_steps_per_second: 1.025, ppl: 6.072139049825416, epoch: 1.1938
[2023-08-02 20:31:37,414] [ INFO] - loss: 1.47412109, learning_rate: 0.000123, global_step: 1553, interval_runtime: 1.1831, interval_samples_per_second: 6.762, interval_steps_per_second: 0.845, ppl: 4.367195713659017, epoch: 1.1946
[2023-08-02 20:31:38,292] [ INFO] - loss: 1.87402344, learning_rate: 0.0001229, global_step: 1554, interval_runtime: 0.8781, interval_samples_per_second: 9.11, interval_steps_per_second: 1.139, ppl: 6.514454257550992, epoch: 1.1954
[2023-08-02 20:31:39,253] [ INFO] - loss: 1.85986328, learning_rate: 0.0001228, global_step: 1555, interval_runtime: 0.9606, interval_samples_per_second: 8.328, interval_steps_per_second: 1.041, ppl: 6.422858578172402, epoch: 1.1962
五、预测
1.修改预测脚本
修改PaddleNLP/llm/chatglm/predict_generation.py第167行,加入预测数据看看效果。
ini
if __name__ == "__main__":
args = parse_arguments()
paddle.set_device(args.device)
predictor = Predictor(args)
if args.data_file is None:
all_texts = [
"拉脱维亚 前苏联 部署 核武器 一个 核反应堆 苏美 1987年 中程导弹 前苏联 核武器 核反应堆 苏联解体 拉脱维亚 一个 无核国家",
"通用动力公司 一份 1710万美元 M1A2坦克 今年4月 一份 7.41亿美元 跨年度 307辆 SEP坦克 此前 通用动力公司 580辆 SEP坦克 206辆 坦克",
"法国 国防部 网站 法国 今年10月8日至11月2日 派遣 1200名 埃及 联合军事演习 每隔两年 一次 20多个 埃及 美国"
]
2.预测数据集处置
以空格隔开各个词汇,存储为json
python
# 训练集
import json
from pprint import pprint
def save_test(source_file, target_file):
# 读取 JSON 文件
with open(source_file, 'r', encoding='utf-8') as f:
data = json.load(f)
pprint(len(data))
# 写入文件
f = open(target_file, 'w', encoding="utf-8")
result_list = []
for item in data:
temp = dict()
temp['content'] = ' '.join(item['key_words'])
temp['summary'] = ''
f.write(json.dumps(temp, ensure_ascii=False) + "\n")
result_list.append(temp)
f.close()
python
test_file='data/data222851/test_set.json'
test_convert='test_set.json'
save_test(test_file, test_convert)
yaml
2000
python
!head -n3 test_set.json
css
{"content": "拉脱维亚 前苏联 部署 核武器 一个 核反应堆 苏美 1987年 中程导弹 前苏联 核武器 核反应堆 苏联解体 拉脱维亚 一个 无核国家", "summary": ""}
{"content": "通用动力公司 一份 1710万美元 M1A2坦克 今年4月 一份 7.41亿美元 跨年度 307辆 SEP坦克 此前 通用动力公司 580辆 SEP坦克 206辆 坦克", "summary": ""}
{"content": "法国 国防部 网站 法国 今年10月8日至11月2日 派遣 1200名 埃及 联合军事演习 每隔两年 一次 20多个 埃及 美国", "summary": ""}
3.开始预测
对前三条数据进行预测
css
{"content": "拉脱维亚 前苏联 部署 核武器 一个 核反应堆 苏美 1987年 中程导弹 前苏联 核武器 核反应堆 苏联解体 拉脱维亚 一个 无核国家", "summary": ""}
{"content": "通用动力公司 一份 1710万美元 M1A2坦克 今年4月 一份 7.41亿美元 跨年度 307辆 SEP坦克 此前 通用动力公司 580辆 SEP坦克 206辆 坦克", "summary": ""}
{"content": "法国 国防部 网站 法国 今年10月8日至11月2日 派遣 1200名 埃及 联合军事演习 每隔两年 一次 20多个 埃及 美国", "summary": ""}
python
# 覆盖预测文件
!cp predict_generation.py PaddleNLP/llm/chatglm/predict_generation.py -rf
python
!python PaddleNLP/llm/chatglm/predict_generation.py \
--model_name_or_path THUDM/chatglm-6b \
--lora_path ./model_checkpoints/chatglm-6b/checkpoint-1300
less
/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.
warnings.warn("Setuptools is replacing distutils.")
[32m[2023-08-02 21:34:02,338] [ INFO][0m - Found /home/aistudio/.paddlenlp/models/THUDM/chatglm-6b/tokenizer_config.json[0m[32m[2023-08-02 21:34:02,339] [ INFO][0m - We are using <class 'paddlenlp.transformers.chatglm.tokenizer.ChatGLMTokenizer'> to load 'THUDM/chatglm-6b'.[0m[32m[2023-08-02 21:34:02,339] [ INFO][0m - Already cached /home/aistudio/.paddlenlp/models/THUDM/chatglm-6b/ice_text.model[0m[32m[2023-08-02 21:34:02,339] [ INFO][0m - Downloading https://bj.bcebos.com/paddlenlp/models/community/THUDM/chatglm-6b/added_tokens.json and saved to /home/aistudio/.paddlenlp/models/THUDM/chatglm-6b[0m[33m[2023-08-02 21:34:02,374] [ WARNING][0m - file<https://bj.bcebos.com/paddlenlp/models/community/THUDM/chatglm-6b/added_tokens.json> not exist[0m[32m[2023-08-02 21:34:02,374] [ INFO][0m - Already cached /home/aistudio/.paddlenlp/models/THUDM/chatglm-6b/special_tokens_map.json[0m[32m[2023-08-02 21:34:02,374] [ INFO][0m - Already cached /home/aistudio/.paddlenlp/models/THUDM/chatglm-6b/tokenizer_config.json[0m[32m[2023-08-02 21:34:02,686] [ INFO][0m - Found /home/aistudio/.paddlenlp/models/THUDM/chatglm-6b/config.json[0m[32m[2023-08-02 21:34:02,687] [ INFO][0m - We are using <class 'paddlenlp.transformers.chatglm.modeling.ChatGLMForCausalLM'> to load 'THUDM/chatglm-6b'.[0m[33m[2023-08-02 21:34:02,687] [ WARNING][0m - `load_state_as_np` is deprecated, please delete it![0m[32m[2023-08-02 21:34:02,711] [ INFO][0m - Found /home/aistudio/.paddlenlp/models/THUDM/chatglm-6b/config.json[0m[32m[2023-08-02 21:34:02,711] [ INFO][0m - Loading configuration file /home/aistudio/.paddlenlp/models/THUDM/chatglm-6b/config.json[0m[32m[2023-08-02 21:34:02,712] [ INFO][0m - Already cached /home/aistudio/.paddlenlp/models/THUDM/chatglm-6b/model_state.pdparams[0m[32m[2023-08-02 21:34:02,712] [ INFO][0m - Loading weights file model_state.pdparams from cache at /home/aistudio/.paddlenlp/models/THUDM/chatglm-6b/model_state.pdparams[0m[32m[2023-08-02 21:35:38,457] [ INFO][0m - Loaded weights file from disk, setting weights to model.[0mW0802 21:35:38.460880 6418 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.6
W0802 21:35:38.468986 6418 gpu_resources.cc:149] device: 0, cuDNN Version: 8.4.
[33m[2023-08-02 21:35:50,552] [ WARNING][0m - Some weights of the model checkpoint at THUDM/chatglm-6b were not used when initializing ChatGLMForCausalLM: ['lm_head.weight']
- This IS expected if you are initializing ChatGLMForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ChatGLMForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).[0m[33m[2023-08-02 21:35:50,552] [ WARNING][0m - Some weights of ChatGLMForCausalLM were not initialized from the model checkpoint at THUDM/chatglm-6b and are newly initialized: ['transformer.rotary_embeddings.inv_freq', 'lm_head.decoder_weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.[0m[33m[2023-08-02 21:35:51,265] [ WARNING][0m - Reset tensor_parallel_degree of lora_config to 1.[0m[32m[2023-08-02 21:35:51,299] [ INFO][0m - Loading the LoRA weights from ./model_checkpoints/chatglm-6b/checkpoint-1300/lora_model_state.pdparams[0m[32m[2023-08-02 21:35:51,330] [ INFO][0m - Load lora weight successfully[0m拉脱维亚 前苏联 部署 核武器 一个 核反应堆 苏美 1987年 中程导弹 前苏联 核武器 核反应堆 苏联解体 拉脱维亚 一个 无核国家拉脱维亚是前苏联部署核武器的一个核反应堆,苏美于1987年进行了中程导弹试射,表明前苏联已经具备核武器,但核反应堆没有建成,苏联解体后,拉脱维亚成为一个独立的国家,但仍然是一个独立的核国家。通用动力公司 一份 1710万美元 M1A2坦克 今年4月 一份 7.41亿美元 跨年度 307辆 SEP坦克 此前 通用动力公司 580辆 SEP坦克 206辆 坦克通用动力公司一份1710万美元的订单,将用于购买M1A2坦克。今年4月,通用动力公司一份价值7.41亿美元的合同,将用于购买307辆SEP坦克。此前通用动力公司已经向军方交付了580辆SEP坦克,还有206辆坦克等待交付。法国 国防部 网站 法国 今年10月8日至11月2日 派遣 1200名 埃及 联合军事演习 每隔两年 一次 20多个 埃及 美国据法国国防部网站消息,法国今年10月8日至11月2日派遣了大约1200名士兵参加埃及的联合军事演习。该演习每隔两年举行一次,20多个国家参加,埃及是其中之一。图为美国士兵参加演习。
4.效果分析 结果如下
yaml
拉脱维亚 前苏联 部署 核武器 一个 核反应堆 苏美 1987年 中程导弹 前苏联 核武器 核反应堆 苏联解体 拉脱维亚 一个 无核国家
拉脱维亚是前苏联部署核武器的一个核反应堆,苏美于1987年进行了中程导弹试射,表明前苏联已经具备核武器,但核反应堆没有建成,苏联解体后,拉脱维亚成为一个独立的国家,但仍然是一个独立的核国家。
通用动力公司 一份 1710万美元 M1A2坦克 今年4月 一份 7.41亿美元 跨年度 307辆 SEP坦克 此前 通用动力公司 580辆 SEP坦克 206辆 坦克
通用动力公司一份1710万美元的订单,将用于购买M1A2坦克。今年4月,通用动力公司一份价值7.41亿美元的合同,将用于购买307辆SEP坦克。此前通用动力公司已经向军方交付了580辆SEP坦克,还有206辆坦克等待交付。
法国 国防部 网站 法国 今年10月8日至11月2日 派遣 1200名 埃及 联合军事演习 每隔两年 一次 20多个 埃及 美国
据法国国防部网站消息,法国今年10月8日至11月2日派遣了大约1200名士兵参加埃及的联合军事演习。该演习每隔两年举行一次,20多个国家参加,埃及是其中之一。图为美国士兵参加演习。
感觉较为流畅,效果不错!
六、总结
-
通过使用lora对chatglm-6b进行微调,实现了根据给出多个关键词组,生成流畅的新闻。在认知域作战领域有一定用武之地。
-
该题目应该是5月左右发布的,当时使用飞桨没有思路,目前大模型开源层出不穷,终于可以轻松解决该应用了。