LLM微调(四)| 微调Llama 2实现Text-to-SQL,并使用LlamaIndex在数据库上进行推理

Llama 2是开源LLM发展的一个巨大里程碑。最大模型及其经过微调的变体位居Hugging Face Open LLM排行榜(https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard)前列。多个基准测试表明,就性能而言,它正在接近GPT-3.5(在某些情况下甚至超过它)。所有这些都意味着,对于从RAG系统到Agent的复杂LLM应用程序,开源LLM是一种越来越可行和可靠的选择。

一、Llama-2--7B不擅长从文本到SQL

最小的Llama 2模型(7B参数)有一个缺点是它不太擅长生成SQL,因此它不适用于结构化分析示例。例如,我们尝试在给定以下提示模板的情况下提示Llama 2生成正确的SQL语句:

复制代码
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. ​You must output the SQL query that answers the question.​### Input:{input}​### Context:{context}​### Response:

在这里,我们使用sqlcreatecontext数据集(https://huggingface.co/datasets/b-mc2/sql-create-context)的一个示例来测试一下效果:

复制代码
input: In 1981 which team picked overall 148?context: CREATE TABLE table_name_8 (team VARCHAR, year VARCHAR, overall_pick VARCHAR)

同时,这里是生成的输出与正确输出的对比:

复制代码
Generated output: SELECT * FROM `table_name_8` WHERE '1980' = YEAR AND TEAM = "Boston Celtics" ORDER BY OVERALL_PICK DESC LIMIT 1;​Correct output: SELECT team FROM table_name_8 WHERE year = 1981 AND overall_pick = "148"

这显然并不理想。与ChatGPT和GPT-4不同,原始的Llama 2不能生成期望的的格式和正确的SQL

这正是微调的作用所在------如果有一个合适的文本到SQL数据的语料库,我们可以教Llama 2更好地从自然语言生成SQL输出。微调有不同的方法,可以更新模型的所有参数(比如:全量微调),也可以冻结大模型参数仅微调附加参数(比如:LoRA)。

二、微调Ll**** ama-2--7B,使其可以从文本生成SQL

接下来,我们将展示如何在文本到SQL数据集上微调Llama 2,然后使用LlamaIndex的功能对任何SQL数据库进行结构化分析。

准备工作:

微调数据集:来自Hugging Face的b-mc2/sql-create-context(https://huggingface.co/datasets/b-mc2/sql-create-context)

base模型:OpenLLaMa 的open_lama_7b_v2(https://github.com/openlm-research/open_llama)

步骤1:加载微调LLaMa的训练数据

PS:1)以下代码来自doppel-bot:https://github.com/modal-labs/doppel-bot;2)许多Python代码都包含在src目录中;3)需要设置一个Modal帐户,并生成token。

复制代码
!pip install -r requirements.txt

首先,我们使用Modal加载b-mc2/sql-create-context数据集,并将其格式化为.jsonl文件。

复制代码
modal run src.load_data_sql --data-dir "data_sql"

结果如下所示:

复制代码
# Modal stubs allow our function to run [email protected](    retries=Retries(        max_retries=3,        initial_delay=5.0,        backoff_coefficient=2.0,    ),    timeout=60 * 60 * 2,    network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},    cloud="gcp",)def load_data_sql(data_dir: str = "data_sql"):    from datasets import load_dataset​    dataset = load_dataset("b-mc2/sql-create-context")​    dataset_splits = {"train": dataset["train"]}    out_path = get_data_path(data_dir)​    out_path.parent.mkdir(parents=True, exist_ok=True)​    for key, ds in dataset_splits.items():        with open(out_path, "w") as f:            for item in ds:                newitem = {                    "input": item["question"],                    "context": item["context"],                    "output": item["answer"],                }                f.write(json.dumps(newitem) + "\n")

步骤2:运行微调脚本

在微调数据集微调llama2模型,代码如下:

复制代码
modal run src.finetune_sql --data-dir "data_sql" --model-dir "model_sql"

微调脚本会执行以下步骤:

将数据集拆分为训练和验证拆分

复制代码
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)

将每个拆分为元组的格式(输入Prompt、标签):输入query和上下文被格式化为输入Prompt,然后对输入Prompt和标签进行 tokenize,模型采用自回归的方法预测下一个token来进行训练。

复制代码
def generate_and_tokenize_prompt(data_point):  full_prompt = generate_prompt_sql(      data_point["input"],      data_point["context"],      data_point["output"],  )  tokenized_full_prompt = tokenize(full_prompt)  if not train_on_inputs:      raise NotImplementedError("not implemented yet")  return tokenized_full_prompt

PS:输入Prompt与开始测试llama2的格式完全相同。

运行微调脚本时,模型将保存在model_dir指定的远程云目录中(如果未指定,则设置为默认值)。

步骤3:评估微调后模型

该模型已经进行了微调,可以从云端提供服务。下面我们使用b-mc2/sql-create-context中的示例数据进行一些基本评估,比较微调后模型与原始Llama 2模型的性能。

复制代码
modal run src.eval_sql::main

结果表明,微调后的模型有了巨大的改进:

复制代码
Input 1: {'input': 'Which region (year) has Abigail at number 7, Sophia at number 1 and Aaliyah at number 5?', 'context': 'CREATE TABLE table_name_12 (region__year_ VARCHAR, no_5 VARCHAR, no_7 VARCHAR, no_1 VARCHAR)', 'output': 'SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "sophia" ANDno_5 = "aaliyah"'}Output 1 (finetuned model): SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "aaliyah" AND no_5 = "sophia"Output 1 (base model): SELECT * FROM table_name_12 WHERE region__year = '2018' AND no_5 = 'Abigail' AND no_7 = 'Sophia' AND no_1 = 'Aaliyah';​​Input 2: {'input': 'Name the result/games for 54741', 'context': 'CREATE TABLE table_21436373_11 (result_games VARCHAR, attendance VARCHAR)', 'output': 'SELECT result_games FROM table_21436373_11 WHERE attendance = 54741'}Output 2 (finetuned model): SELECT result_games FROM table_21436373_11 WHERE attendance = "54741"Output 2 (base model): SELECT * FROM table_21436373_11 WHERE result_games = 'name' AND attendance > 0;

步骤4:将微调模型与LlamaIndex集成

我们现在可以在LlamaIndex中使用这个模型,在任何数据库上进行文本到SQL。

我们首先定义一个测试SQL数据库,然后可以使用该数据库来测试模型的推理能力。

我们创建了一个玩具city_stats表,其中包含城市名称、人口和国家信息,并用几个示例城市填充它。

复制代码
db_file = "cities.db"engine = create_engine(f"sqlite:///{db_file}")metadata_obj = MetaData()# create city SQL tabletable_name = "city_stats"city_stats_table = Table(    table_name,    metadata_obj,    Column("city_name", String(16), primary_key=True),    Column("population", Integer),    Column("country", String(16), nullable=False),)metadata_obj.create_all(engine)

这存储在cities.db文件中。

然后,我们可以使用Modal将微调后的模型和该数据库文件加载到LlamaIndex中的NLSQLTableQueryEngine中------该查询引擎允许用户轻松地开始在给定的数据库上执行文本到SQL。

复制代码
modal run src.inference_sql_llamaindex::main --query "Which city has the highest population?" --sqlite-file-path "nbs/cities.db" --model-dir "model_sql" --use-finetuned-model True

我们得到如下回复:

复制代码
SQL Query: SELECT MAX(population) FROM city_stats WHERE country = "United States"Response: [(2679000,)]

三、结论

本文提供了一种非常高级的方法来开始微调生成SQL语句的Llama 2模型,并展示了如何使用LlamaIndex将其端到端插入到文本到SQL工作流中。

参考文献:

1\] https://blog.llamaindex.ai/easily-finetune-llama-2-for-your-text-to-sql-applications-ecd53640e10d \[2\] https://github.com/run-llama/modal_finetune_sql \[3\] https://github.com/run-llama/modal_finetune_sql/blob/main/tutorial.ipynb

相关推荐
傻啦嘿哟几秒前
Python 数据分析与可视化实战:从数据清洗到图表呈现
大数据·数据库·人工智能
cookqq27 分钟前
mongodb源码分析session异步接受asyncSourceMessage()客户端流变Message对象
数据库·sql·mongodb·nosql
呼拉拉呼拉39 分钟前
Redis故障转移
数据库·redis·缓存·高可用架构
什么都想学的阿超42 分钟前
【Redis系列 04】Redis高可用架构实战:主从复制与哨兵模式从零到生产
数据库·redis·架构
pp-周子晗(努力赶上课程进度版)1 小时前
【MySQL】视图、用户管理、MySQL使用C\C++连接
数据库·mysql
斯特凡今天也很帅1 小时前
clickhouse常用语句汇总——持续更新中
数据库·sql·clickhouse
超级小忍3 小时前
如何配置 MySQL 允许远程连接
数据库·mysql·adb
吹牛不交税3 小时前
sqlsugar WhereIF条件的大于等于和等于查出来的坑
数据库·mysql
hshpy3 小时前
setting up Activiti BPMN Workflow Engine with Spring Boot
数据库·spring boot·后端
文牧之4 小时前
Oracle 审计参数:AUDIT_TRAIL 和 AUDIT_SYS_OPERATIONS
运维·数据库·oracle