Langgraph项目三 agent搭建

前面完成了元数据库的初始化同步,现在开始根据流程搭建agent。

这是对应的流程,将其转为node节点。

构建graph图

python 复制代码
# 构建图
graph_builder = StateGraph(state_schema=DataAgentState, context_schema=DataAgentContext)

# 增加点
graph_builder.add_node("extract_keywords", extract_keywords)
graph_builder.add_node("recall_column", recall_column)
graph_builder.add_node("recall_value", recall_value)
graph_builder.add_node("recall_metric", recall_metric)
graph_builder.add_node("merge_retrieved_info", merge_retrieved_info)
graph_builder.add_node("filter_metric", filter_metric)
graph_builder.add_node("filter_table", filter_table)
graph_builder.add_node("add_extra_context", add_extra_context)
graph_builder.add_node("generate_sql", generate_sql)
graph_builder.add_node("validate_sql", validate_sql)
graph_builder.add_node("correct_sql", correct_sql)
graph_builder.add_node("run_sql", run_sql)

# 添加关系
graph_builder.add_edge(START, "extract_keywords")
graph_builder.add_edge("extract_keywords", "recall_column")
graph_builder.add_edge("extract_keywords", "recall_value")
graph_builder.add_edge("extract_keywords", "recall_metric")
graph_builder.add_edge("recall_column", "merge_retrieved_info")
graph_builder.add_edge("recall_value", "merge_retrieved_info")
graph_builder.add_edge("recall_metric", "merge_retrieved_info")

graph_builder.add_edge("merge_retrieved_info", "filter_table")
graph_builder.add_edge("merge_retrieved_info", "filter_metric")

graph_builder.add_edge("filter_table", "add_extra_context")
graph_builder.add_edge("filter_metric", "add_extra_context")

graph_builder.add_edge("add_extra_context", "generate_sql")
graph_builder.add_edge("generate_sql", "validate_sql")

# validate_sql到run_sql 和 validate_sql到correct_sql 已经由Command指定了
# 如果没有Command,就需要指定条件边
graph_builder.add_conditional_edges("validate_sql",
                                    lambda state: "run_sql" if state["error"] is None else "correct_sql", {
                                        "run_sql": "run_sql",
                                        "correct_sql": "correct_sql",
                                    })

graph_builder.add_edge("correct_sql", "run_sql")
graph_builder.add_edge("run_sql", END)

app = graph_builder.compile()

if __name__ == "__main__":
    # 打印图结构(Mermaid 格式)
    async def test():
        state = DataAgentState(query="统计去年各地区的销售总额")
        embedding_client_manager.init()
        print(app.get_graph().draw_mermaid())

        embedding_client = embedding_client_manager.client
        context = DataAgentContext(embedding_client=embedding_client)
        async for chunk in app.astream(input=state, context=context):
            print(chunk)


    asyncio.run(test())

校验sql节点的流向有两种做法,一种是add_conditional_edges,条件边处理,一种是在节点里面使用Command对象,手动指定下一个节点流向何处。

可以看到,完全满足需求。

测试
python 复制代码
if __name__ == "__main__":
    # 打印图结构(Mermaid 格式)
    async def test():
        state = DataAgentState(query="统计去年各地区的销售总额")
        dw_mysql_client_manager.init()
        embedding_client_manager.init()
        meta_mysql_client_manager.init()
        qdrant_client_manager.init()
        es_client_manager.init()

        async with meta_mysql_client_manager.session_factory() as meta_session, dw_mysql_client_manager.session_factory() as dw_session:
            # 这里操作数据库依然是repository层
            meta_mysql_repository = MetaMySQLRepository(session=meta_session)
            embedding_client = embedding_client_manager.client
            column_qdrant_repository = ColumnQdrantRepository(client=qdrant_client_manager.client)
            metric_qdrant_repository = MetricQdrantRepository(client=qdrant_client_manager.client)
            value_es_repository = ValueEsRepository(client=es_client_manager.client)
            dw_mysql_repository = DwMySQLRepository(session=dw_session)
            context = DataAgentContext(embedding_client=embedding_client, meta_mysql_repository=meta_mysql_repository,
                                       column_qdrant_repository=column_qdrant_repository,
                                       metric_qdrant_repository=metric_qdrant_repository,
                                       value_es_repository=value_es_repository,
                                       dw_mysql_repository=dw_mysql_repository)
            # stream_mode 打印自定义输出
            async for chunk in app.astream(input=state, context=context, stream_mode=["custom"]):
                print(chunk)

        await meta_mysql_client_manager.close()
        await qdrant_client_manager.close()
        await es_client_manager.close()
        await dw_mysql_client_manager.close()


    asyncio.run(test())

这里我们要初始化repository这些,然后将其作为上下文传入。流式返回我们使用custom格式化的数据,自己定义数据返回。

节点编写

抽取关键词

第一个节点:抽取关键词,这里使用了jieba中文分词器。

我们使用基于TF-IDF算法的关键词抽取

python 复制代码
async def extract_keywords(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "抽取关键词"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))
        query = state["query"]
        # 对查询进行分词,只提取指定词性的词
        allow_pos = (
            "n",  # 名词: 数据、服务器、表格
            "nr",  # 人名: 张三、李四
            "ns",  # 地名: 北京、上海
            "nt",  # 机构团体名: 政府、学校、某公司
            "nz",  # 其他专有名词: Unicode、哈希算法、诺贝尔奖
            "v",  # 动词: 运行、开发
            "vn",  # 名动词: 工作、研究
            "a",  # 形容词: 美丽、快速
            "an",  # 名形词: 难度、合法性、复杂度
            "eng",  # 英文
            "i",  # 成语
            "l",  # 常用固定短语
        )
        # 基于TF-IDF算法的关键词抽取
        keywords = jieba.analyse.extract_tags(query, withWeight=False, topK=20, allowPOS=allow_pos)
        # 去重
        keywords = list(set(keywords + [query]))
        writer(StreamInfo(type="progress", step=step, status="success"))
        logger.info(f"抽取的关键词是 {keywords}")
        return {"keywords": keywords}
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"抽取关键词报错: {str(e)}")
        # 需要抛错误去
        raise e

对query进行分词得到关键词列表,去重。

可以看到正常处理。

召回字段信息

召回字段信息,主要就是根据对关键词进行向量化,然后去column_qdrant向量数据集中检索,得到相似度较高的数据,取其payload,也就是ColumnInfo,然后返回,最终得到一组list[ColumnInfo]

python 复制代码
# 根据关键词向量检索召回,获取得到的ColumnInfo数组
async def recall_column(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "召回字段信息"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))

        keywords = state["keywords"]
        query = state["query"]

        embedding_client = runtime.context["embedding_client"]
        column_qdrant_repository = runtime.context["column_qdrant_repository"]

        # 使用llm扩展关键词
        prompt = PromptTemplate(template=load_prompt_template("extend_keywords_for_column_recall"),
                                input_variables=["query"])
        output_str = JsonOutputParser()

        chain = prompt | llm | output_str

        async with llm_semaphore:
            result = await chain.ainvoke({"query": query})

        keywords = list(set(result + keywords))

        # 使用扩展后的关键词召回字段信息
        logger.info(f"召回信息扩展关键词: {keywords}")

        # 去重
        retrieved_columns_map: dict[str, ColumnInfo] = {}
        for keyword in keywords:
            # 转为向量,检索
            embeddings = await embedding_client.aembed_query(keyword)
            payloads: list[ColumnInfo] = await column_qdrant_repository.search(embeddings=embeddings)
            for payload in payloads:
                column_id = payload.id
                if column_id not in retrieved_columns_map:
                    retrieved_columns_map[column_id] = payload
        # 汇总 values返回的是远足
        retrieved_columns = list(retrieved_columns_map.values())

        logger.info(f"召回字段信息: {list(retrieved_columns_map.keys())}")
        writer(StreamInfo(type="progress", step=step, status="success"))
        return {
            "retrieved_columns": retrieved_columns,
        }
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"召回字段信息报错: {str(e)}")
        raise e

这里我们得到第一步抽取的关键词,然后再通过llm对问题再度进行关键词抽取,

这里的prompt大概是:

让其返回数组,然后通过链组成一个chain。得到关键词后进行去重。

接着,遍历关键词列表,对每一个关键词进行向量化,得到向量数组,然后调用column_qdrant_client.search

python 复制代码
 async def search(self, embeddings: list[float], score_threshold: float = 0.6, limit=5):
        result = await self.client.query_points(
            collection_name=self.collection_name,
            query=embeddings,
            # 相似度
            score_threshold=score_threshold,
            limit=limit,
        )
        # column向量检索数据存储 id,embeddings,payload,payload转为ColumnInfo格式
        return [ColumnInfo(**point.payload) for point in result.points]

通过query_points检索,score_threshold是相似度分数,这里大于0.6我们就返回,默认限制返回5条,将返回的数据取其payload然后拼成数组返回。

最后对column去重。得到

如上,关键词通过llm进行了扩展,然后通过向量检索得到了四个字段。

召回指标信息

跟召回字段信息类似

python 复制代码
async def recall_metric(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "召回指标信息"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))

        query = state["query"]
        keywords = state["keywords"]

        embedding_client = runtime.context["embedding_client"]
        metric_qdrant_repository = runtime.context["metric_qdrant_repository"]

        prompt = PromptTemplate(template=load_prompt_template("extend_keywords_for_metric_recall"),
                                input_variables=["query"])

        output_str = JsonOutputParser()

        chain = prompt | llm | output_str

        # 扩展关键词
        async with llm_semaphore:
            result = await chain.ainvoke({"query": query})

        keywords = list(set(result + keywords))

        retrieved_metric_map: dict[str, MetricInfo] = {}
        for keyword in keywords:
            # 对关键词进行向量检索
            embeddings = await embedding_client.aembed_query(keyword)
            payloads: list[MetricInfo] = await metric_qdrant_repository.search(embeddings=embeddings)
            for payload in payloads:
                metric_id = payload.id
                if metric_id not in retrieved_metric_map:
                    retrieved_metric_map[metric_id] = payload

        retrieved_metric_infos = list(retrieved_metric_map.values())

        logger.info(f"召回指标信息: {list(retrieved_metric_map.keys())}")
        writer(StreamInfo(type="progress", step=step, status="success"))
        return {
            "retrieved_metrics": retrieved_metric_infos
        }
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"召回指标信息报错: {str(e)}")
        raise e

先通过llm扩展关键词,然后向量召回,最后整合返回。

召回指标值数据

跟上面两个节点类似,只不过不是向量检索,是es检索

python 复制代码
async def recall_value(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "召回字段取值"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))

        query = state["query"]
        keywords = state["keywords"]

        value_es_repository = runtime.context["value_es_repository"]

        prompt = PromptTemplate(template=load_prompt_template("extend_keywords_for_value_recall"),
                                input_variables=["query"])

        output_str = JsonOutputParser()

        chain = prompt | llm | output_str

        # 扩展关键词
        async with llm_semaphore:
            result = await chain.ainvoke({"query": query})

        keywords = list(set(result + keywords))

        logger.info(f"keywords==: {keywords}")
        values_map: dict[str, ValueInfo] = {}
        for keyword in keywords:
            values = await value_es_repository.search(keyword=keyword)
            for value in values:
                if value.id not in values_map:
                    values_map[value.id] = value
        retrieved_values: list[ValueInfo] = list(values_map.values())

        logger.info(f"召回字段取值:{list(values_map.keys())}")
        writer(StreamInfo(type="progress", step=step, status="success"))
        return {
            "retrieved_values": retrieved_values
        }
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"召回字段取值报错: {str(e)}")
        raise e


 async def search(self, keyword: str, limit=5, score_threshold: float = 0.6):
        # es 存储的数据是 id value column_id
        resp = await self.client.search(
            index=self.index_name,
            query={
                "match": {
                    "value": keyword,
                }
            },
            size=limit,
            min_score=score_threshold,
        )
        return [ValueInfo(**hit["_source"]) for hit in resp["hits"]["hits"]]

通过llm扩展后,根据es搜索召回

合并召回信息

合并召回信息,主要是将得到的column_infos,指标信息metric_infos,以及es取到的value_infos

python 复制代码
# 合并找回信息,首先,将得到的ColumnInfo,ValueInfo数组,合并为table_infos表信息,将MetricInfo合并为metric_infos表信息
# columnInfo 有多少表格,然后整合,valueInfo,判断是哪个表格哪个字段的,将value放到examples中,如果没有的话,就得新增table或者column,metricInfo,GMV/AOV都有对应的column.id,也需要同步
# metricInfo整合得到对应元数据库中metric_info的数据
async def merge_retrieved_info(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "合并召回信息"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))

        retrieved_columns = state["retrieved_columns"]
        retrieved_metrics = state["retrieved_metrics"]
        retrieved_values = state["retrieved_values"]

        meta_mysql_repository = runtime.context["meta_mysql_repository"]

        retrieved_columns_map = {retrieved_column.id: retrieved_column for retrieved_column in retrieved_columns}

        # 合并指标信息到column中
        for retrieved_metric in retrieved_metrics:
            for column_id in retrieved_metric.relevant_columns:
                if column_id not in retrieved_columns_map:
                    # 不存在已有的column列表,需要添加,meta元数据库的column_info的id就是表名.列名
                    column_info = await meta_mysql_repository.get_column_info_by_id(column_id)
                    retrieved_columns_map[retrieved_metric.id] = column_info

        # 将字段取值同步到字段信息列表中examples字段
        for retrieved_value in retrieved_values:
            column_id = retrieved_value.column_id
            if column_id not in retrieved_columns_map:
                column_info = await meta_mysql_repository.get_column_info_by_id(column_id)
                retrieved_columns_map[column_id] = column_info
            if retrieved_value.value not in retrieved_columns_map[column_id].examples:
                retrieved_columns_map[column_id].examples.append(retrieved_value.value)

        # 整合column为table_infos
        table_to_column_map: dict[str, list[ColumnInfo]] = {}
        for column in retrieved_columns_map.values():
            table_id = column.table_id
            if table_id not in table_to_column_map:
                table_to_column_map[table_id] = []
            table_to_column_map[table_id].append(column)

        # 每个表的主外键 描述可能不够清晰,不可以很好的召回,给用到的每个表都显式添加主外建信息,后面过滤信息的时候再过滤
        for table_id in table_to_column_map.keys():
            # 查询主键字段
            column_infos = await meta_mysql_repository.get_key_columns_by_table_id(table_id=table_id)

            # 当前表已有的所有列的ID
            column_ids = [column.id for column in table_to_column_map[table_id]]

            for column_info in column_infos:
                if column_info.id not in column_ids:  # 主外建不存在
                    table_to_column_map[table_id].append(column_info)

        table_infos: list[TableInfoGraphState] = []
        # 将table_id -> columnInfo 映射,专程我们的 list[TableInfoGraphState]表数组数据
        for table_id, column_infos in table_to_column_map.items():
            table: TableInfo = await meta_mysql_repository.get_table_info_by_id(table_id=table_id)
            # 收集table的columns
            columns = [
                # 不需要id和table_id
                ColumnInfoGraphState(
                    name=column.name,
                    type=column.type,
                    role=column.role,
                    examples=column.examples,
                    description=column.description,
                    alias=column.alias,
                )
                for column in column_infos
            ]
            table_infos.append(
                # 不需要id和新增columns列表
                TableInfoGraphState(
                    columns=columns,
                    name=table.name,
                    role=table.role,
                    description=table.description,
                )
            )
            # 至此,table_infos的数据整合完毕

        # 整合指标信息
        metric_infos: list[MetricInfoGraphState] = [
            MetricInfoGraphState(
                name=metric_info.name,
                description=metric_info.description,
                relevant_columns=metric_info.relevant_columns,
                alias=metric_info.alias
            ) for metric_info in retrieved_metrics
        ]

        logger.info(f"合并召回信息,收集到表数据: {[table['name'] for table in table_infos]}")
        logger.info(f"合并召回信息,收集到指标表数据: {[metric_info['name'] for metric_info in metric_infos]}")
        writer(StreamInfo(type="progress", step=step, status="success"))
        return {
            "table_infos": table_infos,
            "metric_infos": metric_infos,
        }
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"合并召回信息报错: {str(e)}")
        raise e

其中,主外键的描述可能不够清晰,会导致召回不到,所以这里先显式加上所有的主外键,后续过滤的时候过滤掉没用的。

过滤表信息节点

借助LLM将合并到的数据进行过滤。

python 复制代码
async def filter_table(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "过滤表格信息"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))

        table_infos = state["table_infos"]
        query = state["query"]

        prompt = PromptTemplate(template=load_prompt_template("filter_table_info"),
                                input_variables=["query", "table_infos"])

        output_parser = JsonOutputParser()

        chain = prompt | llm | output_parser

        async with llm_semaphore:
            result = await chain.ainvoke({
                "query": query,
                "table_infos": yaml.dump(table_infos, allow_unicode=True, sort_keys=False)  # 将list数组转为yaml格式,方便llm查看
            })

        # 利用模型输出过滤table_infos
        # {
        #   'fact_order':['order_amount', 'region_id'],
        #   'dim_region':['region_id', 'region_name']
        # }

        # 将大模型觉得不合理的去掉
        for table_info in table_infos[:]:
            if table_info["name"] not in result:
                table_infos.remove(table_info)
            else:
                for column in table_info["columns"][:]:
                    if column["name"] not in result[table_info["name"]]:
                        table_info["columns"].remove(column)

        logger.info(f"过滤后的表格信息: {[table["name"] for table in table_infos]}")
        writer(StreamInfo(type="progress", step=step, status="success"))
        return {
            "table_infos": table_infos
        }
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"过滤表格信息报错: {str(e)}")
        raise e

让大模型过滤一些没有用的字段,然后再去掉。

过滤指标信息

过滤指标信息跟过滤表格信息一样

python 复制代码
async def filter_metric(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "过滤指标信息"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))

        metric_infos = state["metric_infos"]
        query = state["query"]

        prompt = PromptTemplate(template=load_prompt_template("filter_metric_info"),
                                input_variables=["query", "metric_infos"])

        output_parser = JsonOutputParser()

        chain = prompt | llm | output_parser

        async with llm_semaphore:
            result = await chain.ainvoke({
                "query": query,
                "metric_infos": yaml.dump(metric_infos, allow_unicode=True, sort_keys=False)  # 将list数组转为yaml格式,方便llm查看
            })

        # 利用模型输出过滤metric_infos ["x"]

        # 将大模型觉得不合理的去掉
        for metric_info in metric_infos[:]:
            if metric_info["name"] not in result:
                metric_infos.remove(metric_info)

        logger.info(f"过滤后的指标信息: {[metric_info["name"] for metric_info in metric_infos]}")
        writer(StreamInfo(type="progress", step=step, status="success"))
        return {
            "metric_infos": metric_infos
        }
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"过滤指标信息报错: {str(e)}")
        raise e

只不过大模型返回的指标数据是数组。

增加额外上下文 节点

只需要增加数据仓库的信息+当前询问的日期信息

python 复制代码
async def add_extra_context(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "增加额外上下文"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))

        dw_mysql_repository = runtime.context["dw_mysql_repository"]
        # 增加当前时间/季度,数据库相关信息
        today = datetime.today()

        date = today.strftime("%Y-%m-%d")
        weekday = today.strftime("%A")
        quarter = f"Q{(today.month - 1) // 3 + 1}"

        date_info = DateInfoState(date=date, weekday=weekday, quarter=quarter)

        # 数仓环境信息
        db_info = await dw_mysql_repository.get_db_info()

        logger.info(f"额外上下文信息: 数据库信息-{db_info}, 日期信息-{date_info}")
        writer(StreamInfo(type="progress", step=step, status="success"))
        return {
            "db_info": db_info,
            "date_info": date_info
        }
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"增加额外上下文报错: {str(e)}")
        raise e
生成sql 节点

我们已经具备所有信息了,可以将其交给大模型生成sql了

python 复制代码
async def generate_sql(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "生成SQL"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))

        query = state["query"]
        table_infos = state["table_infos"]
        metric_infos = state["metric_infos"]
        db_info = state["db_info"]
        date_info = state["date_info"]

        prompt = PromptTemplate(template=load_prompt_template("generate_sql"),
                                input_variables=["query", "table_infos", "metric_infos", "db_info", "date_info"])

        output_parser = StrOutputParser()

        chain = prompt | llm | output_parser

        async with llm_semaphore:
            sql = await chain.ainvoke({
                "query": query,
                "table_infos": yaml.dump(table_infos),
                "metric_infos": yaml.dump(metric_infos),
                "db_info": yaml.dump(db_info),
                "date_info": yaml.dump(date_info),
            })

        logger.info(f"生成的SQL: {sql}")
        writer(StreamInfo(type="progress", step=step, status="success"))
        return {
            "sql": sql
        }
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"生成SQL报错: {str(e)}")
        raise e

如上,将query,召回的表信息table_infos,召回的metric_infos,还有数仓信息,当前日期信息一起交给大模型。生成sql返回。

校验sql 节点
python 复制代码
async def validate_sql(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "校验SQL"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))

        sql = state["sql"]
        dw_mysql_repository = runtime.context["dw_mysql_repository"]

        try:
            await dw_mysql_repository.validate_sql(sql)
            logger.info(f"SQL验证成功")
            writer(StreamInfo(type="progress", step=step, status="success"))
            return {"error": None}
        except Exception as ex:
            logger.error(f"SQL验证失败: {str(ex)}")
            writer(StreamInfo(type="progress", step=step, status="success"))
            return {"error": str(ex)}
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"校验SQL报错: {str(e)}")
        raise e

直接使用explain spl来判断当前sql有没有问题。

有问题,直接返回error:xx到重新生成sql节点,没问题的话,可以直接到执行sql节点。

修正sql
python 复制代码
async def correct_sql(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "修正SQL"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))

        sql = state["sql"]
        error = state["error"]

        query = state["query"]
        table_infos = state["table_infos"]
        metric_infos = state["metric_infos"]
        db_info = state["db_info"]
        date_info = state["date_info"]

        prompt = PromptTemplate(template=load_prompt_template("correct_sql"),
                                input_variables=["sql", "error", "query", "table_infos", "metric_infos", "db_info",
                                                 "date_info"])

        output_parser = StrOutputParser()
        chain = prompt | llm | output_parser
        async with llm_semaphore:
            result = await chain.ainvoke({
                "sql": sql,
                "query": query,
                "error": error,
                "table_infos": yaml.dump(table_infos),
                "metric_infos": yaml.dump(metric_infos),
                "db_info": yaml.dump(db_info),
                "date_info": yaml.dump(date_info)
            })

        logger.info(f"修正后的SQL: {result}")
        writer(StreamInfo(type="progress", step=step, status="success"))
        return {
            "sql": result
        }
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"修正SQL报错: {str(e)}")
        raise e

修正sql节点主要是将生成的sql+错误信息+所有数据,都交给大模型让他重新生成。

执行sql节点
python 复制代码
async def run_sql(state: DataAgentState, runtime: Runtime[DataAgentContext]):
    writer = runtime.stream_writer
    step = "执行SQL"

    try:
        writer(StreamInfo(type="progress", step=step, status="running"))

        sql = state["sql"]
        dw_mysql_repository = runtime.context["dw_mysql_repository"]

        result = await dw_mysql_repository.execute_sql(sql)

        logger.info(f"执行SQL结果: {result}")
        writer(StreamInfo(type="progress", step=step, status="success"))
        return {"result": result}
    except Exception as e:
        writer({"type": "progress", "step": step, "status": "error"})
        logger.error(f"执行SQL报错: {str(e)}")
        raise e


 async def execute_sql(self, sql):
        result = await self.session.execute(text(sql))
        # mappings将每一行转为{id:xx,name:xx,}这种数据,否则默认是 (xx,xx,xx)
        # 将返回的数据转为dict数据
        return [dict(row) for row in result.mappings().fetchall()]

执行sql就是直接exeute(text(sql)),然后将结果返回即可。

执行结果:

集成FastApi

fastapi,中间件,生命周期跟express差不多。

也有DI依赖注入的概念,只不过fastapi的依赖注入不是单例,他只是在一个请求里面是单例。

python 复制代码
# 这是一个示例 Python 脚本。
import uuid

from fastapi import FastAPI, Request

from app.api.routers.query_router import query_router
from app.core.context import request_id_ctx_var
from app.core.lefespan import lifespan

# 添加生命周期函数
app = FastAPI(lifespan=lifespan)

# 注册路由,类似于app.use
app.include_router(query_router)


# 中间件
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    # 调用路径函数之前,往每个请求注入独立变量
    request_id_ctx_var.set(uuid.uuid4())
    # 调用路径函数
    response = await call_next(request)
    # 调用路径函数之后
    return response


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

我们需要使用FastAPI创建一个app,然后在app启动之前和之后做一些实例的初始化和关闭。

python 复制代码
@asynccontextmanager
async def lifespan(app: FastAPI):
    # FastAPI启动前执行
    embedding_client_manager.init()
    qdrant_client_manager.init()
    es_client_manager.init()
    meta_mysql_client_manager.init()
    dw_mysql_client_manager.init()

    # 启动状态
    yield

    # FastApi关闭前运行
    await qdrant_client_manager.close()
    await es_client_manager.close()
    await meta_mysql_client_manager.close()
    await dw_mysql_client_manager.close()

然后中间件主要是在调用函数之前,为每个请求注入唯一的id,在py中,每个请求就是一个协程,类似于java每个请求是一个线程一样,py可以使用

python 复制代码
from contextvars import ContextVar

request_id_ctx_var = ContextVar("request_id", default="1")

ContextVar为每个请求注入单独的request_id,类似于java的Thread。

最后则是注册路由

python 复制代码
@query_router.post("/api/query")
async def query(
        query: QueryDTO,
        # Depends依赖注入,每一个请求都会通过get_query_service创建一个新的实例
        query_service: QueryService = Depends(get_query_service),
):
    # 流式返回
    return StreamingResponse(
        query_service.query(query.query), media_type="text/event-stream",
    )

这里使用StreamingResponse流式返回,然后使用Depends依赖注入,每个单独的请求,都会调用get_query_service创建QuerySerivce实例,java则是单一实例。

甚至可以使用嵌套注入

python 复制代码
async def get_metric_qdrant_repository():
    return MetricQdrantRepository(qdrant_client_manager.client)


async def get_meta_mysql_repository(session: AsyncSession = Depends(get_meta_session)):
    return MetaMySQLRepository(session)


async def get_dw_mysql_repository(session: AsyncSession = Depends(get_dw_session)):
    return DwMySQLRepository(session)


async def get_query_service(
        # # 嵌套依赖注入
        embedding_client: HuggingFaceEndpointEmbeddings = Depends(get_embedding_client),
        column_qdrant_repository: ColumnQdrantRepository = Depends(get_column_qdrant_repository),
        value_es_repository: ValueEsRepository = Depends(get_value_es_repository),
        metric_qdrant_repository: MetricQdrantRepository = Depends(get_metric_qdrant_repository),
        meta_mysql_repository: MetaMySQLRepository = Depends(get_meta_mysql_repository),
        dw_mysql_repository: DwMySQLRepository = Depends(get_dw_mysql_repository)
) -> QueryService:
    return QueryService(
        embedding_client=embedding_client,
        column_qdrant_repository=column_qdrant_repository,
        value_es_repository=value_es_repository,
        metric_qdrant_repository=metric_qdrant_repository,
        meta_mysql_repository=meta_mysql_repository,
        dw_mysql_repository=dw_mysql_repository,
    )

如上,get_query_service又依赖get_embedding_client这些。

这里实现的是controller层,我们还需要实现逻辑Service层。

python 复制代码
class QueryService:

    def __init__(self,
                 embedding_client: HuggingFaceEndpointEmbeddings,
                 column_qdrant_repository: ColumnQdrantRepository,
                 value_es_repository: ValueEsRepository,
                 metric_qdrant_repository: MetricQdrantRepository,
                 meta_mysql_repository: MetaMySQLRepository,
                 dw_mysql_repository: DwMySQLRepository):
        self.embedding_client = embedding_client
        self.column_qdrant_repository = column_qdrant_repository
        self.value_es_repository = value_es_repository
        self.metric_qdrant_repository = metric_qdrant_repository
        self.meta_mysql_repository = meta_mysql_repository
        self.dw_mysql_repository = dw_mysql_repository

    async def query(self, query: str):
        context = DataAgentContext(
            embedding_client=self.embedding_client,
            column_qdrant_repository=self.column_qdrant_repository,
            value_es_repository=self.value_es_repository,
            metric_qdrant_repository=self.metric_qdrant_repository,
            meta_mysql_repository=self.meta_mysql_repository,
            dw_mysql_repository=self.dw_mysql_repository
        )
        state = DataAgentState(query=query)

        try:
            async for chunk in app.astream(input=state, context=context, stream_mode=["custom"]):
                # SSE发送格式: \n\n,  yield每次都会往前端推数据
                print("chunk==", chunk)
                yield f"data: {json.dumps(chunk, ensure_ascii=False, default=str)}\n\n"
        except Exception as e:
            # 接收每个节点的报错
            yield f"data: {json.dumps({"type": "error", "message": str(e)}, ensure_ascii=False, default=str)}\n\n"

因为我们是流式输出,这里需要使用yield返回数据,约定数据格式为{type:xx}这样。

这样一个简单的接口就做好了。

前端实现

因为使用sse

SSE 原理

SSE(Server-Sent Events)是一种单向的服务器推送协议,建立在普通 HTTP 之上。

关键特性:

连接由客户端发起,之后服务端持续往里写数据,连接不关闭

基于 text/event-stream MIME 类型

天然支持断线重连(浏览器自动重试)

只能服务端 → 客户端,不能反向

前端的实现方式:

原生 EventSource(只支持 GET)
js 复制代码
const es = new EventSource('/api/query')

es.addEventListener('custom', (e) => {
  const event = JSON.parse(e.data)
  // { type: 'progress', step: '...', status: '...' }
})

es.onerror = () => es.close()

局限:只能 GET,不能携带 body,无法传查询参数,实际项目中用得少。

方式二:fetch + ReadableStream(本项目用的)
js 复制代码
const response = await fetch('/api/query', {
  method: 'POST',
  body: JSON.stringify({ query: text }),
})

const reader = response.body!.getReader()
const decoder = new TextDecoder()
let buffer = ''

while (true) {
  const { done, value } = await reader.read()  // ① 分块读取二进制
  if (done) break
  buffer += decoder.decode(value, { stream: true })  // ② 解码,stream:true 处理跨块 UTF-8
  const lines = buffer.split('\n')
  buffer = lines.pop() ?? ''  // ③ 最后一行可能不完整,留到下次拼接

  for (const line of lines) {
    if (!line.startsWith('data:')) continue  // ④ 跳过 event:/id:/空行
    const raw = line.slice(5).trim()
    const parsed = JSON.parse(raw)           // ⑤ 解析 JSON
    yield parsed[1]                          // ⑥ 本项目后端包了一层数组
  }
}

stream: true 是关键:TCP 包边界不一定在字符边界,一个中文字可能被截成两个 chunk,TextDecoder 需要知道"还没结束"才能正确解码。

效果:

下面是一条一条推送的,所以我们展示的时候也是一步骤一步骤展示的。

这个项目还需要做很多扩展,比如

  • 记忆模式
  • 节点错误重试
  • 问题召回
  • ...
    后续有空继续研究
相关推荐
xyx-3v1 小时前
信号量(二进制/计数)
java·linux·数据库
u0110225121 小时前
HTML5多媒体资源动态替换Source标签的刷新机制
jvm·数据库·python
AI人工智能+电脑小能手1 小时前
【大白话说Java面试题】【Java基础篇】第18题:HashMap底层是如何扩容的
java·开发语言·面试·散列表·hash-index·hash
云祺vinchin1 小时前
“十五五”引领灾备升级,数字化安全建设如何合规落地?
网络·数据库·安全·kubernetes·数据安全·容灾备份
当战神遇到编程1 小时前
关系型数据库设计基础:约束、三大范式、表关系与表设计流程
数据库
想躺平的小羊1 小时前
IDEA 如何显示或关闭项目类的结构(类的方法)
java·ide·intellij-idea
其实防守也摸鱼1 小时前
《SQL注入进阶实验:基于sqli-Labs的报错注入(Error-Based Injection)实战解析》
网络·数据库·sql·安全·网络安全·sql注入·报错注入
A-Jie-Y1 小时前
JAVA设计模式-建造者模式
java·设计模式
Ting.~1 小时前
软件设计师备考笔记【day3】-数据库
数据库·笔记