深度学习系列75:sql大模型工具vanna

1. 概述

vanna是一个可以将自然语言转为sql的工具。简单的demo如下:

复制代码
!pip install vanna
import vanna
from vanna.remote import VannaDefault
vn = VannaDefault(model='chinook', api_key=vanna.get_api_key('my-email@example.com'))
vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
vn.ask("What are the top 10 albums by sales?")

执行下面的代码运行图形界面

复制代码
from vanna.flask import VannaFlaskApp
VannaFlaskApp(vn).run()

2. 配置

数据库可以是任何数据库,比如mysql如下:

复制代码
import pandas as pd
import psycopg2

def run_sql(sql):
    conn = psycopg2.connect(
        host="localhost",
        database="my_database",
        user="my_user",
        password="my_password"
    )
    return pd.read_sql(sql, conn)

vn.run_sql = run_sql
vn.run_sql_is_set = True

向量数据库稍微麻烦一些,目前支持的包括:

参考代码如下:

复制代码
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
class MyVanna(ChromaDB_VectorStore):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)

vn = MyVanna(config={'path': '/path/to/chromadb'})

3. 训练

训练数据可以是:DDL、documentation、sql以及Question-SQL Pairs

复制代码
vn.train(ddl="CREATE TABLE my_table (id INT, name TEXT)")
vn.train(documentation="Our business defines XYZ as ABC")
vn.train(sql="SELECT col1, col2, col3 FROM my_table")

可以设置auto_train = True

4. 询问

复制代码
vn.ask("What are the top 10 customers by sales?")

它包含下列几个函数:

复制代码
vn.generate_sql
vn.run_sql
vn.generate_plotly_code
vn.get_plotly_figure

visualize=False

5. 启用服务

参考https://github.com/vanna-ai/vanna-flask,将LLM、embedding、vectorStore都改造成自己的代码。

首先是LLM,改造框架为:

复制代码
from vanna.base import VannaBase
class MyLLM(VannaBase):
    def __init__(self,config=None):
        VannaBase.__init__(self, config=config)
        ...
   def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
    	...

然后是embedding,需要定义encode_documents和encode_queries两个函数,例如:

复制代码
class BgeM3:
    def __init__(self, url):
        self.url = url
    def encode_documents(self, docs):
        ....
    def encode_queries(self, queries):
        ....

接下来是vectorStore,我们使用milvus,它会自动调用config中的embedding_function,我们把它定义成上面的BegM3即可:

复制代码
class MyVanna(Milvus_VectorStore, QwenLLM):
    def __init__(self, config=None):
        Milvus_VectorStore.__init__(self, config=config)
        QwenLLM.__init__(self, config=config)

vn = MyVanna(config={'milvus_client': MilvusClient(...),'embedding_function':BgeM3(...)})

然后定义连接的数据库,可以换成任意的其他数据库:

复制代码
def run_sql(sql: str) -> pd.DataFrame:
    cnx = mysql.connector.connect(...)
    cursor = cnx.cursor()
    cursor.execute(sql)
    result = cursor.fetchall()
    columns = cursor.column_names
    df = pd.DataFrame(result, columns=columns)
    return df
    
vn.run_sql = run_sql
vn.run_sql_is_set = True 

接着执行python app.py即可启用服务,访问localhost:5000可以打开页面:

同时也可以调用接口:

复制代码
import requests
response = requests.get(url+'/api/v0/get_training_data',headers={'Content-Type':'application/json'})
response.json()

所有可用的接口清单可以参考app.py

相关推荐
九年义务漏网鲨鱼29 分钟前
BLIP2 工业实战(一):从零实现 LAVIS 跌倒检测 (微调与“踩坑”指南)
人工智能·pytorch·深度学习·语言模型
CoookeCola2 小时前
开源图像与视频过曝检测工具:HSV色彩空间分析与时序平滑处理技术详解
人工智能·深度学习·算法·目标检测·计算机视觉·开源·音视频
CoovallyAIHub2 小时前
万字详解:多目标跟踪(MOT)终极指南
深度学习·算法·计算机视觉
java1234_小锋2 小时前
PyTorch2 Python深度学习 - 初识PyTorch2,实现一个简单的线性神经网络
开发语言·python·深度学习·pytorch2
高洁012 小时前
大模型-模型压缩:量化、剪枝、蒸馏、二值化 (4)
人工智能·python·深度学习·aigc·transformer
CoovallyAIHub3 小时前
Arm重磅加码边缘AI!Flexible Access开放v9平台,实现高端算力普惠
深度学习·算法·计算机视觉
小白狮ww4 小时前
dots.ocr 基于 1.7B 参数实现多语言文档处理,性能达 SOTA
人工智能·深度学习·机器学习·自然语言处理·ocr·小红书·文档处理
无风听海13 小时前
神经网络之窗口大小对词语义向量的影响
人工智能·深度学习·神经网络
Tiandaren14 小时前
自用提示词02 || Prompt Engineering || RAG数据切分 || 作用:通过LLM将文档切分成chunks
数据库·pytorch·深度学习·oracle·prompt·rag
阿水实证通15 小时前
面向社科研究者:用深度学习做因果推断(二)
深度学习·1024程序员节·因果推断·实证分析·科研创新