Llama 2是开源LLM发展的一个巨大里程碑。最大模型及其经过微调的变体位居Hugging Face Open LLM排行榜(https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard)前列。多个基准测试表明,就性能而言,它正在接近GPT-3.5(在某些情况下甚至超过它)。所有这些都意味着,对于从RAG系统到Agent的复杂LLM应用程序,开源LLM是一种越来越可行和可靠的选择。
一、Llama-2--7B不擅长从文本到SQL
最小的Llama 2模型(7B参数)有一个缺点是它不太擅长生成SQL,因此它不适用于结构化分析示例。例如,我们尝试在给定以下提示模板的情况下提示Llama 2生成正确的SQL语句:
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.### Input:{input}### Context:{context}### Response:
在这里,我们使用sqlcreatecontext数据集(https://huggingface.co/datasets/b-mc2/sql-create-context)的一个示例来测试一下效果:
input: In 1981 which team picked overall 148?context: CREATE TABLE table_name_8 (team VARCHAR, year VARCHAR, overall_pick VARCHAR)
同时,这里是生成的输出与正确输出的对比:
Generated output: SELECT * FROM `table_name_8` WHERE '1980' = YEAR AND TEAM = "Boston Celtics" ORDER BY OVERALL_PICK DESC LIMIT 1;Correct output: SELECT team FROM table_name_8 WHERE year = 1981 AND overall_pick = "148"
这显然并不理想。与ChatGPT和GPT-4不同,原始的Llama 2不能生成期望的的格式和正确的SQL。
这正是微调的作用所在------如果有一个合适的文本到SQL数据的语料库,我们可以教Llama 2更好地从自然语言生成SQL输出。微调有不同的方法,可以更新模型的所有参数(比如:全量微调),也可以冻结大模型参数仅微调附加参数(比如:LoRA)。
二、微调Ll**** ama-2--7B,使其可以从文本生成SQL
接下来,我们将展示如何在文本到SQL数据集上微调Llama 2,然后使用LlamaIndex的功能对任何SQL数据库进行结构化分析。
准备工作:
微调数据集:来自Hugging Face的b-mc2/sql-create-context(https://huggingface.co/datasets/b-mc2/sql-create-context)
base模型:OpenLLaMa 的open_lama_7b_v2(https://github.com/openlm-research/open_llama)
步骤1:加载微调LLaMa的训练数据
PS:1)以下代码来自doppel-bot:https://github.com/modal-labs/doppel-bot;2)许多Python代码都包含在src目录中;3)需要设置一个Modal帐户,并生成token。
!pip install -r requirements.txt
首先,我们使用Modal加载b-mc2/sql-create-context数据集,并将其格式化为.jsonl文件。
modal run src.load_data_sql --data-dir "data_sql"
结果如下所示:
# Modal stubs allow our function to run remotely@stub.function( retries=Retries( max_retries=3, initial_delay=5.0, backoff_coefficient=2.0, ), timeout=60 * 60 * 2, network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol}, cloud="gcp",)def load_data_sql(data_dir: str = "data_sql"): from datasets import load_dataset dataset = load_dataset("b-mc2/sql-create-context") dataset_splits = {"train": dataset["train"]} out_path = get_data_path(data_dir) out_path.parent.mkdir(parents=True, exist_ok=True) for key, ds in dataset_splits.items(): with open(out_path, "w") as f: for item in ds: newitem = { "input": item["question"], "context": item["context"], "output": item["answer"], } f.write(json.dumps(newitem) + "\n")
步骤2:运行微调脚本
在微调数据集微调llama2模型,代码如下:
modal run src.finetune_sql --data-dir "data_sql" --model-dir "model_sql"
微调脚本会执行以下步骤:
将数据集拆分为训练和验证拆分
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
将每个拆分为元组的格式(输入Prompt、标签):输入query和上下文被格式化为输入Prompt,然后对输入Prompt和标签进行 tokenize,模型采用自回归的方法预测下一个token来进行训练。
def generate_and_tokenize_prompt(data_point): full_prompt = generate_prompt_sql( data_point["input"], data_point["context"], data_point["output"], ) tokenized_full_prompt = tokenize(full_prompt) if not train_on_inputs: raise NotImplementedError("not implemented yet") return tokenized_full_prompt
PS:输入Prompt与开始测试llama2的格式完全相同。
运行微调脚本时,模型将保存在model_dir指定的远程云目录中(如果未指定,则设置为默认值)。
步骤3:评估微调后模型
该模型已经进行了微调,可以从云端提供服务。下面我们使用b-mc2/sql-create-context中的示例数据进行一些基本评估,比较微调后模型与原始Llama 2模型的性能。
modal run src.eval_sql::main
结果表明,微调后的模型有了巨大的改进:
Input 1: {'input': 'Which region (year) has Abigail at number 7, Sophia at number 1 and Aaliyah at number 5?', 'context': 'CREATE TABLE table_name_12 (region__year_ VARCHAR, no_5 VARCHAR, no_7 VARCHAR, no_1 VARCHAR)', 'output': 'SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "sophia" ANDno_5 = "aaliyah"'}Output 1 (finetuned model): SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "aaliyah" AND no_5 = "sophia"Output 1 (base model): SELECT * FROM table_name_12 WHERE region__year = '2018' AND no_5 = 'Abigail' AND no_7 = 'Sophia' AND no_1 = 'Aaliyah';Input 2: {'input': 'Name the result/games for 54741', 'context': 'CREATE TABLE table_21436373_11 (result_games VARCHAR, attendance VARCHAR)', 'output': 'SELECT result_games FROM table_21436373_11 WHERE attendance = 54741'}Output 2 (finetuned model): SELECT result_games FROM table_21436373_11 WHERE attendance = "54741"Output 2 (base model): SELECT * FROM table_21436373_11 WHERE result_games = 'name' AND attendance > 0;
步骤4:将微调模型与LlamaIndex集成
我们现在可以在LlamaIndex中使用这个模型,在任何数据库上进行文本到SQL。
我们首先定义一个测试SQL数据库,然后可以使用该数据库来测试模型的推理能力。
我们创建了一个玩具city_stats表,其中包含城市名称、人口和国家信息,并用几个示例城市填充它。
db_file = "cities.db"engine = create_engine(f"sqlite:///{db_file}")metadata_obj = MetaData()# create city SQL tabletable_name = "city_stats"city_stats_table = Table( table_name, metadata_obj, Column("city_name", String(16), primary_key=True), Column("population", Integer), Column("country", String(16), nullable=False),)metadata_obj.create_all(engine)
这存储在cities.db文件中。
然后,我们可以使用Modal将微调后的模型和该数据库文件加载到LlamaIndex中的NLSQLTableQueryEngine中------该查询引擎允许用户轻松地开始在给定的数据库上执行文本到SQL。
modal run src.inference_sql_llamaindex::main --query "Which city has the highest population?" --sqlite-file-path "nbs/cities.db" --model-dir "model_sql" --use-finetuned-model True
我们得到如下回复:
SQL Query: SELECT MAX(population) FROM city_stats WHERE country = "United States"Response: [(2679000,)]
三、结论
本文提供了一种非常高级的方法来开始微调生成SQL语句的Llama 2模型,并展示了如何使用LlamaIndex将其端到端插入到文本到SQL工作流中。
参考文献:
[1] https://blog.llamaindex.ai/easily-finetune-llama-2-for-your-text-to-sql-applications-ecd53640e10d
[2] https://github.com/run-llama/modal_finetune_sql
[3] https://github.com/run-llama/modal_finetune_sql/blob/main/tutorial.ipynb