本文内容主要基于以下开源项目探索实践,
Awesome-Text2SQL:github.com/eosphoros-a...
DB-GPT-Hub:github.com/eosphoros-a...
DB-GPT:github.com/eosphoros-a...
DeepSpeedExamples:github.com/microsoft/D...
开源不易,希望大家给个star支持一下,感谢!
摘要
本文主要介绍了Text2SQL的基本概念,以及RLHF的概念和框架,最后结合DB-GPT-Hub项目,将RLHF方法应用于Text2SQL任务进行实践探索。
Text2SQL简介
本章主要对Text2SQL的基本定义、使用的开源数据集和评测指标做了介绍,同时也介绍了一些实践项目,供大家参考。
定义
Text-to-SQL(简写为Text2SQL),顾名思义就是把文本转化为SQL语言,更学术一点的定义 是:把数据库领域下的自然语言(Natural Language,简写为NL )问题,转化为在关系型数据库中可以执行的结构化查询语言(Structured Query Language,简写为SQL ),因此Text2SQL也可以被简写为NL2SQL。
举个例子比较直观:
-
输入:自然语言问题。
查询表t_user的所有信息,结果按id降序排序,只保留前10个数据
-
输出:SQL语句。
SELECT * FROM t_user ORDER BY id DESC LIMIT 10
-
实验:如图1所示,在DB-GPT项目中,直接使用原生对话,使用Proxy LLM(GPT-3.5)提问上述问题,大模型可以准确给出SQL答案,这也是因为LLM本身语言理解能力强大,同时提问的自然语言问题比较easy。
图1 DB-GPT项目原生对话示意图
数据集
公开的Text2SQL数据集比较多,这里仅介绍目前使用较多的几个数据集:
-
2017年9月,Salesforce提出的一个大型的Text-to-SQL数据集,数据来源于Wikipedia,属于单领域,包含了80654个自然语言问题,77840个SQL语句,SQL语句形式比较简单,不包含排序、分组、子查询等复杂操作。
-
2018年9月,耶鲁大学提出的多数据库、多表、单轮查询的Text-to-SQL数据集,也是业界公认难度最大的大规模跨领域评测榜单,包含了10181个自然语言问题,5693个SQL语句,涉及138个不同领域的200多个数据库,难易程度分为:简单、中等、困难、特别困难。
-
2019/09, 耶鲁大学和Salesforce Research提出了一种跨域数据库CoSQL,它由30k+轮次和10k+带注释的SQL查询组成,这些查询是从Wizard-of-Oz (WOZ)集合中获得的,该集合包含3k个对话,查询跨越 138个域的200个复杂数据库。
-
2021年8月,西安交通大学和微软等提出了首个跨领域、多轮Text-to-SQL中文数据集,包含了5459个多轮问题组成的列表,17940个<query, SQL>二元组。
-
2023年5月,香港大学和阿里巴巴提出了一个大规模跨域数据集BIRD,其中包含超过12751个独特的问题 SQL、95个大数据库,总大小为33.4GB。它还涵盖区块链、曲棍球、医疗保健和教育等超过37个专业领域。
如何还想了解更多数据集以及Text2SQL的基本知识,可以查看我之前知乎的Text2QL综述文章:Text-to-SQL小白入门(一)综述文章学习
评测指标
以Spider数据集为例:主要有两个指标,分别是执行准确率(Execution Accuracy,简称EX)和逻辑形式准确率(Exact Match,简称EM)
-
EX
-
计算SQL执行结果正确的数量在数据集中的比例,结果存在高估的可能。
-
EM
-
计算模型生成的SQL和标注SQL的匹配程度,结果存在低估的可能。
在Awesome-Text2SQL项目中,列举了常见的数据以及对应的指标榜单,如图2所示,比如Spider数据集上,目前EX得分第一是MiniSeek组织提交的91.2,EM得分第一也是MiniSeek提交的81.5,因为运用了GPT-4以及一些其他的trick,所以得分最高。
图2 Awesome-Text2SQL项目数据集得分榜单
实验方法
Text2SQL研究主要有基于模版和匹配的方法、基于Seq2Seq框架的方法和基于模型预训练的方法,随着LLM的崛起,如今利用LLM微调完成Text2SQL任务也越来越常见,比如在DB-GPT-Hub项目中,就实现了利用各种开源模型在Spider数据集上进行lora和qlora方法微调,亲测好用!(方法详情可以参考代码仓库)
RLHF简介
本章主要介绍了RLHF的基本定义,以及介绍了强化学习的基础概念和RLHF框架。
定义
RLHF:R einforcement L earning from H uman Feedback,通过强化学习方式方式根据人类反馈优化语言模型,使得在一般文本数据语料库的语言模型能够和复杂人类价值观对齐。
强化学习基础概念
RL:指的是Reinforcement learning。
- 强化学习是一种机器学习方法,旨在通过智能体(agent )与环境(environment )的交互学习如何在动态环境中做出决策(action )以最大化累积回报(reward )。在强化学习中,智能体通过观察环境的状态、采取行动和接收奖励来学习与环境的交互。智能体的目标是通过学习最优的策略(policy),在不断尝试和调整中,使得长期累积的奖励最大化。
- 强化学习最早在游戏中应用比较多。
为了更好理解强化学习,我们可以先了解一下比较常见的有监督学习(Supervised Learning, SL)。对于有监督学习而言,模型完整的训练pipline通常可以分成如图3所示:
图3 有监督学习示意图
-
输入标注好的数据labeled data(有标签ground truth+原始数据)
-
1.从标签数据中获取原始数据
-
2.把原始数据拿给模型训练(比如卷积神经网络CNN)
-
3.模型根据当前数据输出预测值predict
-
4.通过损失函数loss function计算预测值和真实值之间的loss
-
5.loss更新给模型
-
然后重复上述1-5步骤,训练模型。【优化目标:把loss变小】
-
输出训练好的模型
对于强化学习而言,模型训练的pipline也是类似的,如图4所示。
-
输入初始化的环境environment
-
1.从环境获取当前状态state
-
2.把当前state拿给智能体agent
-
3.agent根据环境的状态输出采取的动作action
-
4.action和环境进行交互,通过奖励函数reward function计算当前奖励
-
5.奖励和状态更新给智能体agent
-
然后重复上述1-5步骤,训练agent。【优化目标:把reward变大】
图4 有监督学习和强化学习对比示意图
由上面讲述可知,强化学习的基本组成主要由以下部分:
-
environment
-
agent
-
state
-
reward
-
action
-
policy: 策略。定义了agent如何根据当前的state来做出action。策略主要可以分为on-policy和off-policy。
-
On-policy: 学习到的agent以及和环境进行互动的agent是同一个agent ,比如PPO算法 (eg:你在打游戏,你在实战中变强。)
-
Off-policy: 学习到的agent以及和环境进行互动的agent是不同的agent,比如DQN算法(eg: 你在看直播,你在观摩中变强。)
RLHF框架
RLHF方法最早是在是2017年论文(Deep reinforcement learning from human preferences)提出。
- 在2020年的论文(Learning to summarize from human feedback)中RM训练使用了交叉熵损失。
- 在2023年3月OpenAI发表的论文(Training language models to follow instructions with human feedback)中进一步提供了RLHF实现的标准范式(论文中训练的模型为InstructGPT,ChatGPT是改进后的InstructGPT,比如InstructGPT是基于GPT-3训练,而ChatGPT是基于GPT-3.5训练),如图5所示。
如果想了解InstructGPT论文的详细内容,可以参考我之前的知乎文章:Text-to-SQL小白入门(九)InstructGPT论文:教你如何训练ChatGPT
图5 InstructGPT论文中的RLHF实现范式
RLHF主要流程有3步:
第一阶段:SFT
- Supervised Fine-tuning有监督微调,简称为SFT。这是InstructGPT(ChatGPT等)训练过程中的一个重要步骤,主要采用有监督的方式对与预训练的LLM进行微调,这个方法比较依赖于标注的数据,SFT数据集标注质量越高(质量不等同于数据),模型的效果越好。
之前听一个大学教授的讲座,有个观点很有意思:Open AI做大模型为什么比谷歌强,因为包括transformer在内的一些创新模型大多是谷歌研究的,那为什么Open AI在大模型领域为什么比谷歌强?答:因为Open AI在数据清洗,数据质量把控这方面做的很好。------所以数据是相当重要的!
第二阶段:RM
- Reward Model奖励模型训练,是InstructGPT训练过程的第二阶段,它的目标是训练一个模型来适应人类的偏好(这里主要是标注人员的偏好)。在RM训练阶段,输入prompt,会使LLM生成多个响应response,然后标注人员对这些响应进行排名,根据这些排名训练一个奖励模型。
第三阶段:RL
-
Reinforcement Learning,是InstructGPT训练中的最后步骤,主要是通过PPO策略(proximal policy optimization 近端策略优化)迭代,它通过引入奖励信号来调整模型的行为,使模型生成的内容更符合人类的偏好。
-
输入一个标注数据,模型经过PPO输出一个response
-
RM模型对response打分
-
根据打分score更新PPO策略。
RLHF+Text2SQL的实践探索
本章节主要结合DB-GPT-Hub项目代码以及一些RLHF代码对Text2SQL进行了实践探索。
SFT
SFT模块的实现主要参考DB-GPT-Hub,比如在Spider数据集上进行实现。
数据预处理
bash
sh dbgpt_hub/scripts/gen_train_eval_data.sh
经过数据预处理后,可以得到example_text2sql_train.json和example_text2sql_dev.json
数据格式
数据格式如下所示:
-
db_id-instruction-input-output-history
{ "db_id": "department_management", "instruction": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n\n", "input": "###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:", "output": "SELECT count(*) FROM head WHERE age > 56", "history": [] }
-
最终经过代码后会形成为这样的格式:prompt-output
{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age > 56"}
训练
bash
sh dbgpt_hub/scripts/train_sft.sh
训练的基础大模型为CodeLlama-13b-instruct,如果想了解该开源模型,可以参考论文讲解:Text-to-SQL小白入门(五)开源代码大模型Code Llama
训练的参数如下所示:
css
CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
--model_name_or_path /home/model/CodeLlama-13B-Instruct \
--do_train \
--dataset example_text2sql_train \
--max_source_length 2048 \
--max_target_length 512 \
--template llama2 \
--finetuning_type lora \
--lora_rank 64 \
--lora_alpha 32 \
--lora_target q_proj,v_proj \
--output_dir dbgpt_hub/output/adapter/CodeLlama-13B-Instruct-lora \
--overwrite_cache \
--overwrite_output_dir \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--lr_scheduler_type cosine_with_restarts \
--logging_steps 500 \
--save_steps 2000 \
--learning_rate 2e-4 \
--num_train_epochs 8 \
--plot_loss \
--bf16
预测
bash
sh dbgpt_hub/scripts/predict_sft.sh
预测完成后,会生成一个predict.sql文件,文件中存放了dev集合中1034个sql.
评估
测试的库为ts库
css
python dbgpt_hub/eval/evaluation.py --plug_value --input Your_model_pred_file
评估过程如下所示:会对每一个sql进行对比,对错误的sql进行打印输出展示。
最终对1034条sql验证完成后,可以得到EX、EM精度得分。
- EX-0.746
其他模型的一些baseline分数也可以通过DB-GPT-Hub获取。
RM
RM模型训练的模型以SFT阶段的模型为基础,参考微软代码进行训练(Hub项目近期也会增加RLHF功能,敬请期待),自行构建了少量Text2SQL的RM训练数据集用于测试训练。
数据格式
数据格式如下所示:
-
prompy-chosen-rejected
-
chosen就是在SFT阶段的ground truth
-
rejected就是模型的错误输出结果
{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","chosen": "SELECT count(*) FROM head WHERE age > 56","rejected":"SELECT COUNT(head_name) FROM head WHERE age > 56;"}
训练
-
比如训练10个epoch的训练结果如下:
deepspeed --num_gpus= <math xmlns="http://www.w3.org/1998/Math/MathML"> n g p u m a i n . p y − − d a t a p a t h n_gpu \ main.py \ --data_path </math>ngpu main.py −−datapathdata_path
--data_split 2,4,4
--model_name_or_path <math xmlns="http://www.w3.org/1998/Math/MathML"> m o d e l n a m e o r p a t h − − p e r d e v i c e t r a i n b a t c h s i z e 8 − − p e r d e v i c e e v a l b a t c h s i z e 8 − − m a x s e q l e n 1024 − − l e a r n i n g r a t e 9.65 e − 6 − − w e i g h t d e c a y 0.1 − − n u m p a d d i n g a t b e g i n n i n g 0 − − n u m t r a i n e p o c h s 10 − − g r a d i e n t a c c u m u l a t i o n s t e p s 1 − − l r s c h e d u l e r t y p e c o s i n e − − n u m w a r m u p s t e p s 0 − − s e e d 1234 − − g r a d i e n t c h e c k p o i n t i n g − − z e r o s t a g e model_name_or_path \ --per_device_train_batch_size 8 \ --per_device_eval_batch_size 8 \ --max_seq_len 1024 \ --learning_rate 9.65e-6 \ --weight_decay 0.1 \ --num_padding_at_beginning 0 \ --num_train_epochs 10 \ --gradient_accumulation_steps 1 \ --lr_scheduler_type cosine \ --num_warmup_steps 0 \ --seed 1234 \ --gradient_checkpointing \ --zero_stage </math>modelnameorpath −−perdevicetrainbatchsize8 −−perdeviceevalbatchsize8 −−maxseqlen1024 −−learningrate9.65e−6 −−weightdecay0.1 −−numpaddingatbeginning0 −−numtrainepochs10 −−gradientaccumulationsteps1 −−lrschedulertypecosine −−numwarmupsteps0 −−seed1234 −−gradientcheckpointing −−zerostageZERO_STAGE
--deepspeed
--offload
--lora_dim 128
--lora_module_name "layers."
--output_dir OUTPUT \ 2>&1 | tee OUTPUT/log.txt
结果
训练完成后,会在制定目前生成训练好的模型,比如有以下文件:
- config.json
- log.txt
- pytorch_model.bin
- tokenizer.model
RL
数据格式
RL阶段和SFT阶段的数据格式保持一致,以Text2SQL任务举例子,RL数据可以构造为(prompt,output}的二元组,如下所示:
-
prompt-otput
{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age > 56"}
训练
-
训练参数
-
SFT模型即为上面训练的SFT模型
-
RM模型即为上面训练的RM模型
-
训练10epoch
deepspeed --master_port 12346 main.py
--data_path <math xmlns="http://www.w3.org/1998/Math/MathML"> d a t a p a t h − − d a t a s p l i t 2 , 4 , 4 − − a c t o r m o d e l n a m e o r p a t h data_path \ --data_split 2,4,4 \ --actor_model_name_or_path </math>datapath −−datasplit2,4,4 −−actormodelnameorpathACTOR_MODEL_PATH
--critic_model_name_or_path <math xmlns="http://www.w3.org/1998/Math/MathML"> C R I T I C M O D E L P A T H − − n u m p a d d i n g a t b e g i n n i n g 1 − − p e r d e v i c e g e n e r a t i o n b a t c h s i z e 8 − − p e r d e v i c e t r a i n i n g b a t c h s i z e 8 − − g e n e r a t i o n b a t c h e s 1 − − p p o e p o c h s 1 − − m a x a n s w e r s e q l e n 256 − − m a x p r o m p t s e q l e n 1024 − − a c t o r l e a r n i n g r a t e CRITIC_MODEL_PATH \ --num_padding_at_beginning 1 \ --per_device_generation_batch_size 8 \ --per_device_training_batch_size 8 \ --generation_batches 1 \ --ppo_epochs 1 \ --max_answer_seq_len 256 \ --max_prompt_seq_len 1024 \ --actor_learning_rate </math>CRITICMODELPATH −−numpaddingatbeginning1 −−perdevicegenerationbatchsize8 −−perdevicetrainingbatchsize8 −−generationbatches1 −−ppoepochs1 −−maxanswerseqlen256 −−maxpromptseqlen1024 −−actorlearningrate{Actor_Lr}
--critic_learning_rate <math xmlns="http://www.w3.org/1998/Math/MathML"> C r i t i c L r − − a c t o r w e i g h t d e c a y 0.1 − − c r i t i c w e i g h t d e c a y 0.1 − − n u m t r a i n e p o c h s 10 − − l r s c h e d u l e r t y p e c o s i n e − − g r a d i e n t a c c u m u l a t i o n s t e p s 1 − − a c t o r g r a d i e n t c h e c k p o i n t i n g − − c r i t i c g r a d i e n t c h e c k p o i n t i n g − − o f f l o a d r e f e r e n c e m o d e l − − d i s a b l e a c t o r d r o p o u t − − n u m w a r m u p s t e p s 100 − − d e e p s p e e d − − s e e d 1234 − − a c t o r z e r o s t a g e {Critic_Lr} \ --actor_weight_decay 0.1 \ --critic_weight_decay 0.1 \ --num_train_epochs 10 \ --lr_scheduler_type cosine \ --gradient_accumulation_steps 1 \ --actor_gradient_checkpointing \ --critic_gradient_checkpointing \ --offload_reference_model \ --disable_actor_dropout \ --num_warmup_steps 100 \ --deepspeed --seed 1234 \ --actor_zero_stage </math>CriticLr −−actorweightdecay0.1 −−criticweightdecay0.1 −−numtrainepochs10 −−lrschedulertypecosine −−gradientaccumulationsteps1 −−actorgradientcheckpointing −−criticgradientcheckpointing −−offloadreferencemodel −−disableactordropout −−numwarmupsteps100 −−deepspeed−−seed1234 −−actorzerostageACTOR_ZERO_STAGE
--critic_zero_stage <math xmlns="http://www.w3.org/1998/Math/MathML"> C R I T I C Z E R O S T A G E − − e n a b l e h y b r i d e n g i n e − − a c t o r l o r a d i m 64 − − c r i t i c l o r a d i m 64 − − c r i t i c l o r a m o d u l e n a m e " l a y e r s . " − − a c t o r l o r a m o d u l e n a m e " l a y e r s . " − − o u t p u t d i r CRITIC_ZERO_STAGE \ --enable_hybrid_engine \ --actor_lora_dim 64 \ --critic_lora_dim 64 \ --critic_lora_module_name "layers." \ --actor_lora_module_name "layers." \ --output_dir </math>CRITICZEROSTAGE −−enablehybridengine −−actorloradim64 −−criticloradim64 −−criticloramodulename"layers." −−actorloramodulename"layers." −−outputdirOUTPUT
2>&1 | tee $OUTPUT/log.txt
-
训练结束
结果
训练结束会得到两个模型,actor模型即为需要的最终评测模型。
验证
- 验证得到的模型
- EX-0.752
- EM-0.717
可以发现的是,RLHF相比SFT方法,精度有轻微提升,主要是数据质量的问题,后续还可以进一步探索。
其他文章学习
Text-to-SQL小白入门(二)Transformer学习
Text-to-SQL小白入门(三)IRNet:引入中间表示SemQL
Text-to-SQL小白入门(四)指令进化大模型WizardLM
Text-to-SQL小白入门(五)开源代码大模型Code Llama
Text-to-SQL小白入门(六)Awesome-Text2SQL项目介绍
Text-to-SQL小白入门(七)PanGu-Coder2论文------RRTF