Text-to-SQL小白入门(十)RLHF在Text2SQL领域的探索实践

本文内容主要基于以下开源项目探索实践,

开源不易,希望大家给个star支持一下,感谢!

摘要

本文主要介绍了Text2SQL的基本概念,以及RLHF的概念和框架,最后结合DB-GPT-Hub项目,将RLHF方法应用于Text2SQL任务进行实践探索。

Text2SQL简介

本章主要对Text2SQL的基本定义、使用的开源数据集和评测指标做了介绍,同时也介绍了一些实践项目,供大家参考。

定义

Text-to-SQL(简写为Text2SQL),顾名思义就是把文本转化为SQL语言,更学术一点的定义 是:把数据库领域下的自然语言(Natural Language,简写为NL )问题,转化为在关系型数据库中可以执行的结构化查询语言(Structured Query Language,简写为SQL ),因此Text2SQL也可以被简写为NL2SQL

举个例子比较直观:

  • 输入:自然语言问题。

    查询表t_user的所有信息,结果按id降序排序,只保留前10个数据

  • 输出:SQL语句。

    SELECT * FROM t_user ORDER BY id DESC LIMIT 10

  • 实验:如图1所示,在DB-GPT项目中,直接使用原生对话,使用Proxy LLM(GPT-3.5)提问上述问题,大模型可以准确给出SQL答案,这也是因为LLM本身语言理解能力强大,同时提问的自然语言问题比较easy。

图1 DB-GPT项目原生对话示意图

数据集

公开的Text2SQL数据集比较多,这里仅介绍目前使用较多的几个数据集:

  • WikiSQL [paper] [code] [dataset]

  • 2017年9月,Salesforce提出的一个大型的Text-to-SQL数据集,数据来源于Wikipedia,属于单领域,包含了80654个自然语言问题,77840个SQL语句,SQL语句形式比较简单,不包含排序、分组、子查询等复杂操作。

  • Spider [paper] [code] [dataset]

  • 2018年9月,耶鲁大学提出的多数据库、多表、单轮查询的Text-to-SQL数据集,也是业界公认难度最大的大规模跨领域评测榜单,包含了10181个自然语言问题,5693个SQL语句,涉及138个不同领域的200多个数据库,难易程度分为:简单、中等、困难、特别困难。

  • CoSQL [paper] [code] [dataset]

  • 2019/09, 耶鲁大学和Salesforce Research提出了一种跨域数据库CoSQL,它由30k+轮次和10k+带注释的SQL查询组成,这些查询是从Wizard-of-Oz (WOZ)集合中获得的,该集合包含3k个对话,查询跨越 138个域的200个复杂数据库。

  • CHASE [paper] [code] [dataset]

  • 2021年8月,西安交通大学和微软等提出了首个跨领域、多轮Text-to-SQL中文数据集,包含了5459个多轮问题组成的列表,17940个<query, SQL>二元组。

  • BIRD-SQL [paper] [code] [dataset]

  • 2023年5月,香港大学和阿里巴巴提出了一个大规模跨域数据集BIRD,其中包含超过12751个独特的问题 SQL、95个大数据库,总大小为33.4GB。它还涵盖区块链、曲棍球、医疗保健和教育等超过37个专业领域。

如何还想了解更多数据集以及Text2SQL的基本知识,可以查看我之前知乎的Text2QL综述文章:Text-to-SQL小白入门(一)综述文章学习

评测指标

以Spider数据集为例:主要有两个指标,分别是执行准确率(Execution Accuracy,简称EX)和逻辑形式准确率(Exact Match,简称EM)

  • EX

  • 计算SQL执行结果正确的数量在数据集中的比例,结果存在高估的可能。

  • EM

  • 计算模型生成的SQL和标注SQL的匹配程度,结果存在低估的可能。

Awesome-Text2SQL项目中,列举了常见的数据以及对应的指标榜单,如图2所示,比如Spider数据集上,目前EX得分第一是MiniSeek组织提交的91.2,EM得分第一也是MiniSeek提交的81.5,因为运用了GPT-4以及一些其他的trick,所以得分最高。

图2 Awesome-Text2SQL项目数据集得分榜单

实验方法

Text2SQL研究主要有基于模版和匹配的方法、基于Seq2Seq框架的方法和基于模型预训练的方法,随着LLM的崛起,如今利用LLM微调完成Text2SQL任务也越来越常见,比如在DB-GPT-Hub项目中,就实现了利用各种开源模型在Spider数据集上进行lora和qlora方法微调,亲测好用!(方法详情可以参考代码仓库)

RLHF简介

本章主要介绍了RLHF的基本定义,以及介绍了强化学习的基础概念和RLHF框架。

定义

RLHF:R einforcement L earning from H uman Feedback,通过强化学习方式方式根据人类反馈优化语言模型,使得在一般文本数据语料库的语言模型能够和复杂人类价值观对齐。

强化学习基础概念

RL:指的是Reinforcement learning。

  • 强化学习是一种机器学习方法,旨在通过智能体(agent )与环境(environment )的交互学习如何在动态环境中做出决策(action )以最大化累积回报(reward )。在强化学习中,智能体通过观察环境的状态、采取行动和接收奖励来学习与环境的交互。智能体的目标是通过学习最优的策略(policy),在不断尝试和调整中,使得长期累积的奖励最大化。
  • 强化学习最早在游戏中应用比较多。

为了更好理解强化学习,我们可以先了解一下比较常见的有监督学习(Supervised Learning, SL)。对于有监督学习而言,模型完整的训练pipline通常可以分成如图3所示:

图3 有监督学习示意图

  • 输入标注好的数据labeled data(有标签ground truth+原始数据)

  • 1.从标签数据中获取原始数据

  • 2.把原始数据拿给模型训练(比如卷积神经网络CNN)

  • 3.模型根据当前数据输出预测值predict

  • 4.通过损失函数loss function计算预测值和真实值之间的loss

  • 5.loss更新给模型

  • 然后重复上述1-5步骤,训练模型。【优化目标:把loss变小】

  • 输出训练好的模型

对于强化学习而言,模型训练的pipline也是类似的,如图4所示。

  • 输入初始化的环境environment

  • 1.从环境获取当前状态state

  • 2.把当前state拿给智能体agent

  • 3.agent根据环境的状态输出采取的动作action

  • 4.action和环境进行交互,通过奖励函数reward function计算当前奖励

  • 5.奖励和状态更新给智能体agent

  • 然后重复上述1-5步骤,训练agent。【优化目标:把reward变大】

图4 有监督学习和强化学习对比示意图

由上面讲述可知,强化学习的基本组成主要由以下部分:

  • environment

  • agent

  • state

  • reward

  • action

  • policy: 策略。定义了agent如何根据当前的state来做出action。策略主要可以分为on-policy和off-policy。

  • On-policy: 学习到的agent以及和环境进行互动的agent是同一个agent ,比如PPO算法 (eg:你在打游戏,你在实战中变强。

  • Off-policy: 学习到的agent以及和环境进行互动的agent是不同的agent,比如DQN算法(eg: 你在看直播,你在观摩中变强。)

RLHF框架

RLHF方法最早是在是2017年论文(Deep reinforcement learning from human preferences)提出。

  • 在2020年的论文(Learning to summarize from human feedback)中RM训练使用了交叉熵损失。
  • 在2023年3月OpenAI发表的论文(Training language models to follow instructions with human feedback)中进一步提供了RLHF实现的标准范式(论文中训练的模型为InstructGPT,ChatGPT是改进后的InstructGPT,比如InstructGPT是基于GPT-3训练,而ChatGPT是基于GPT-3.5训练),如图5所示。

如果想了解InstructGPT论文的详细内容,可以参考我之前的知乎文章:Text-to-SQL小白入门(九)InstructGPT论文:教你如何训练ChatGPT

图5 InstructGPT论文中的RLHF实现范式

RLHF主要流程有3步:

第一阶段:SFT

  • Supervised Fine-tuning有监督微调,简称为SFT。这是InstructGPT(ChatGPT等)训练过程中的一个重要步骤,主要采用有监督的方式对与预训练的LLM进行微调,这个方法比较依赖于标注的数据,SFT数据集标注质量越高(质量不等同于数据),模型的效果越好。

之前听一个大学教授的讲座,有个观点很有意思:Open AI做大模型为什么比谷歌强,因为包括transformer在内的一些创新模型大多是谷歌研究的,那为什么Open AI在大模型领域为什么比谷歌强?答:因为Open AI在数据清洗,数据质量把控这方面做的很好。------所以数据是相当重要的!

第二阶段:RM

  • Reward Model奖励模型训练,是InstructGPT训练过程的第二阶段,它的目标是训练一个模型来适应人类的偏好(这里主要是标注人员的偏好)。在RM训练阶段,输入prompt,会使LLM生成多个响应response,然后标注人员对这些响应进行排名,根据这些排名训练一个奖励模型。

第三阶段:RL

  • Reinforcement Learning,是InstructGPT训练中的最后步骤,主要是通过PPO策略(proximal policy optimization 近端策略优化)迭代,它通过引入奖励信号来调整模型的行为,使模型生成的内容更符合人类的偏好。

  • 输入一个标注数据,模型经过PPO输出一个response

  • RM模型对response打分

  • 根据打分score更新PPO策略。

RLHF+Text2SQL的实践探索

本章节主要结合DB-GPT-Hub项目代码以及一些RLHF代码对Text2SQL进行了实践探索。

SFT

SFT模块的实现主要参考DB-GPT-Hub,比如在Spider数据集上进行实现。

数据预处理

bash 复制代码
sh dbgpt_hub/scripts/gen_train_eval_data.sh

经过数据预处理后,可以得到example_text2sql_train.json和example_text2sql_dev.json

数据格式

数据格式如下所示:

  • db_id-instruction-input-output-history

    { "db_id": "department_management", "instruction": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n\n", "input": "###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:", "output": "SELECT count(*) FROM head WHERE age > 56", "history": [] }

  • 最终经过代码后会形成为这样的格式:prompt-output

    {"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age > 56"}

训练

bash 复制代码
sh dbgpt_hub/scripts/train_sft.sh

训练的基础大模型为CodeLlama-13b-instruct,如果想了解该开源模型,可以参考论文讲解:Text-to-SQL小白入门(五)开源代码大模型Code Llama

训练的参数如下所示:

css 复制代码
CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
    --model_name_or_path /home/model/CodeLlama-13B-Instruct \
    --do_train \
    --dataset example_text2sql_train \
    --max_source_length 2048 \
    --max_target_length 512 \
    --template llama2 \
    --finetuning_type lora \
    --lora_rank 64 \
    --lora_alpha 32 \
    --lora_target q_proj,v_proj \
    --output_dir dbgpt_hub/output/adapter/CodeLlama-13B-Instruct-lora \
    --overwrite_cache \
    --overwrite_output_dir \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --lr_scheduler_type cosine_with_restarts \
    --logging_steps 500 \
    --save_steps 2000 \
    --learning_rate 2e-4 \
    --num_train_epochs 8 \
    --plot_loss \
    --bf16

预测

bash 复制代码
sh dbgpt_hub/scripts/predict_sft.sh

预测完成后,会生成一个predict.sql文件,文件中存放了dev集合中1034个sql.

评估

测试的库为ts库

css 复制代码
 python dbgpt_hub/eval/evaluation.py --plug_value --input Your_model_pred_file

评估过程如下所示:会对每一个sql进行对比,对错误的sql进行打印输出展示。

最终对1034条sql验证完成后,可以得到EX、EM精度得分。

  • EX-0.746

其他模型的一些baseline分数也可以通过DB-GPT-Hub获取。

RM

RM模型训练的模型以SFT阶段的模型为基础,参考微软代码进行训练(Hub项目近期也会增加RLHF功能,敬请期待),自行构建了少量Text2SQL的RM训练数据集用于测试训练。

数据格式

数据格式如下所示:

  • prompy-chosen-rejected

  • chosen就是在SFT阶段的ground truth

  • rejected就是模型的错误输出结果

    {"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","chosen": "SELECT count(*) FROM head WHERE age > 56","rejected":"SELECT COUNT(head_name) FROM head WHERE age > 56;"}

训练

  • 比如训练10个epoch的训练结果如下:

    deepspeed --num_gpus= <math xmlns="http://www.w3.org/1998/Math/MathML"> n g p u m a i n . p y − − d a t a p a t h n_gpu \ main.py \ --data_path </math>ngpu main.py −−datapathdata_path

    --data_split 2,4,4

    --model_name_or_path <math xmlns="http://www.w3.org/1998/Math/MathML"> m o d e l n a m e o r p a t h − − p e r d e v i c e t r a i n b a t c h s i z e 8 − − p e r d e v i c e e v a l b a t c h s i z e 8 − − m a x s e q l e n 1024 − − l e a r n i n g r a t e 9.65 e − 6 − − w e i g h t d e c a y 0.1 − − n u m p a d d i n g a t b e g i n n i n g 0 − − n u m t r a i n e p o c h s 10 − − g r a d i e n t a c c u m u l a t i o n s t e p s 1 − − l r s c h e d u l e r t y p e c o s i n e − − n u m w a r m u p s t e p s 0 − − s e e d 1234 − − g r a d i e n t c h e c k p o i n t i n g − − z e r o s t a g e model_name_or_path \ --per_device_train_batch_size 8 \ --per_device_eval_batch_size 8 \ --max_seq_len 1024 \ --learning_rate 9.65e-6 \ --weight_decay 0.1 \ --num_padding_at_beginning 0 \ --num_train_epochs 10 \ --gradient_accumulation_steps 1 \ --lr_scheduler_type cosine \ --num_warmup_steps 0 \ --seed 1234 \ --gradient_checkpointing \ --zero_stage </math>modelnameorpath −−perdevicetrainbatchsize8 −−perdeviceevalbatchsize8 −−maxseqlen1024 −−learningrate9.65e−6 −−weightdecay0.1 −−numpaddingatbeginning0 −−numtrainepochs10 −−gradientaccumulationsteps1 −−lrschedulertypecosine −−numwarmupsteps0 −−seed1234 −−gradientcheckpointing −−zerostageZERO_STAGE

    --deepspeed

    --offload

    --lora_dim 128

    --lora_module_name "layers."

    --output_dir OUTPUT \ 2>&1 | tee OUTPUT/log.txt

结果

训练完成后,会在制定目前生成训练好的模型,比如有以下文件:

  • config.json
  • log.txt
  • pytorch_model.bin
  • tokenizer.model

RL

数据格式

RL阶段和SFT阶段的数据格式保持一致,以Text2SQL任务举例子,RL数据可以构造为(prompt,output}的二元组,如下所示:

  • prompt-otput

    {"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age > 56"}

训练

  • 训练参数

  • SFT模型即为上面训练的SFT模型

  • RM模型即为上面训练的RM模型

  • 训练10epoch

    deepspeed --master_port 12346 main.py

    --data_path <math xmlns="http://www.w3.org/1998/Math/MathML"> d a t a p a t h − − d a t a s p l i t 2 , 4 , 4 − − a c t o r m o d e l n a m e o r p a t h data_path \ --data_split 2,4,4 \ --actor_model_name_or_path </math>datapath −−datasplit2,4,4 −−actormodelnameorpathACTOR_MODEL_PATH

    --critic_model_name_or_path <math xmlns="http://www.w3.org/1998/Math/MathML"> C R I T I C M O D E L P A T H − − n u m p a d d i n g a t b e g i n n i n g 1 − − p e r d e v i c e g e n e r a t i o n b a t c h s i z e 8 − − p e r d e v i c e t r a i n i n g b a t c h s i z e 8 − − g e n e r a t i o n b a t c h e s 1 − − p p o e p o c h s 1 − − m a x a n s w e r s e q l e n 256 − − m a x p r o m p t s e q l e n 1024 − − a c t o r l e a r n i n g r a t e CRITIC_MODEL_PATH \ --num_padding_at_beginning 1 \ --per_device_generation_batch_size 8 \ --per_device_training_batch_size 8 \ --generation_batches 1 \ --ppo_epochs 1 \ --max_answer_seq_len 256 \ --max_prompt_seq_len 1024 \ --actor_learning_rate </math>CRITICMODELPATH −−numpaddingatbeginning1 −−perdevicegenerationbatchsize8 −−perdevicetrainingbatchsize8 −−generationbatches1 −−ppoepochs1 −−maxanswerseqlen256 −−maxpromptseqlen1024 −−actorlearningrate{Actor_Lr}

    --critic_learning_rate <math xmlns="http://www.w3.org/1998/Math/MathML"> C r i t i c L r − − a c t o r w e i g h t d e c a y 0.1 − − c r i t i c w e i g h t d e c a y 0.1 − − n u m t r a i n e p o c h s 10 − − l r s c h e d u l e r t y p e c o s i n e − − g r a d i e n t a c c u m u l a t i o n s t e p s 1 − − a c t o r g r a d i e n t c h e c k p o i n t i n g − − c r i t i c g r a d i e n t c h e c k p o i n t i n g − − o f f l o a d r e f e r e n c e m o d e l − − d i s a b l e a c t o r d r o p o u t − − n u m w a r m u p s t e p s 100 − − d e e p s p e e d − − s e e d 1234 − − a c t o r z e r o s t a g e {Critic_Lr} \ --actor_weight_decay 0.1 \ --critic_weight_decay 0.1 \ --num_train_epochs 10 \ --lr_scheduler_type cosine \ --gradient_accumulation_steps 1 \ --actor_gradient_checkpointing \ --critic_gradient_checkpointing \ --offload_reference_model \ --disable_actor_dropout \ --num_warmup_steps 100 \ --deepspeed --seed 1234 \ --actor_zero_stage </math>CriticLr −−actorweightdecay0.1 −−criticweightdecay0.1 −−numtrainepochs10 −−lrschedulertypecosine −−gradientaccumulationsteps1 −−actorgradientcheckpointing −−criticgradientcheckpointing −−offloadreferencemodel −−disableactordropout −−numwarmupsteps100 −−deepspeed−−seed1234 −−actorzerostageACTOR_ZERO_STAGE

    --critic_zero_stage <math xmlns="http://www.w3.org/1998/Math/MathML"> C R I T I C Z E R O S T A G E − − e n a b l e h y b r i d e n g i n e − − a c t o r l o r a d i m 64 − − c r i t i c l o r a d i m 64 − − c r i t i c l o r a m o d u l e n a m e " l a y e r s . " − − a c t o r l o r a m o d u l e n a m e " l a y e r s . " − − o u t p u t d i r CRITIC_ZERO_STAGE \ --enable_hybrid_engine \ --actor_lora_dim 64 \ --critic_lora_dim 64 \ --critic_lora_module_name "layers." \ --actor_lora_module_name "layers." \ --output_dir </math>CRITICZEROSTAGE −−enablehybridengine −−actorloradim64 −−criticloradim64 −−criticloramodulename"layers." −−actorloramodulename"layers." −−outputdirOUTPUT

    2>&1 | tee $OUTPUT/log.txt

  • 训练结束

结果

训练结束会得到两个模型,actor模型即为需要的最终评测模型。

验证

  • 验证得到的模型
  • EX-0.752
  • EM-0.717

可以发现的是,RLHF相比SFT方法,精度有轻微提升,主要是数据质量的问题,后续还可以进一步探索。

其他文章学习

xt-to-SQL小白入门(一)综述文章学习

Text-to-SQL小白入门(二)Transformer学习

Text-to-SQL小白入门(三)IRNet:引入中间表示SemQL

Text-to-SQL小白入门(四)指令进化大模型WizardLM

Text-to-SQL小白入门(五)开源代码大模型Code Llama

Text-to-SQL小白入门(六)Awesome-Text2SQL项目介绍

Text-to-SQL小白入门(七)PanGu-Coder2论文------RRTF

Text-to-SQL小白入门(八)RLAIF论文:AI代替人类反馈的强化学习

Text-to-SQL小白入门(九)InstructGPT论文:教你如何训练ChatGPT

相关推荐
不知几秋5 小时前
sqlilab-Less-18
sql
Amctwd6 小时前
【SQL】如何在 SQL 中统计结构化字符串的特征频率
数据库·sql
有梦想的攻城狮10 小时前
大语言模型与多模态模型比较
人工智能·语言模型·自然语言处理·llm·大语言模型
lqlj223312 小时前
Spark SQL 读取 CSV 文件,并将数据写入 MySQL 数据库
数据库·sql·spark
遗憾皆是温柔13 小时前
MyBatis—动态 SQL
java·数据库·ide·sql·mybatis
未来之窗软件服务13 小时前
Cacti 未经身份验证SQL注入漏洞
android·数据库·sql·服务器安全
_星辰大海乀15 小时前
表的设计、聚合函数
java·数据结构·数据库·sql·mysql·数据库开发
幸福回头19 小时前
ms-swift 代码推理数据集
llm·swift
亚里随笔19 小时前
AlphaEvolve:LLM驱动的算法进化革命与科学发现新范式
人工智能·算法·llm·大语言模型
tebukaopu14821 小时前
官方 Elasticsearch SQL NLPChina Elasticsearch SQL
大数据·sql·elasticsearch