基于SQL_Agent实现SQL助理
如何使用渐进式披露(一种上下文管理技术,其中Agent按需加载信息而非预先加载)来实现技能(基于提示的专用指令)。智能体通过工具调用加载技能,而不是动态更改系统提示,从而发现并加载每个任务所需的技能。尤其是在一个企业的SQL查询当中,企业可能会有上千张表,预先加载所有模式都会使上下文窗口不堪重负。渐进式披露通过仅在需要时加载相关模式来解决这个问题。这种架构还允许不同的产品负责人和利益相关者独立贡献并维护其特定业务部门的技能。
1.选择聊天模型(chat_model)
我是基于ollama平台调用大模型,比较方便:
from langchain_ollama import ChatOllama
model = ChatOllama(
model="qwen3:1.7b",
temperature=0,
reasoning = False
)
2.定义技能
首先,定义技能结构。每个技能都有一个名称、一个简短描述(显示在系统提示符中)和完整内容(按需加载)。
from typing import TypedDict
class Skill(TypedDict):
"""A skill that can be progressively disclosed to the agent."""
name: str # Unique identifier for the skill
description: str # 1-2 sentence description to show in system prompt
content: str # Full skill content with detailed instructions
SKILLS: list[Skill] = [
{
"name": "sales_analytics",
"description": "数据库架构和销售数据分析业务逻辑,包括客户、订单和收入。",
"content": """# 销售分析架构
## 表
### 顾客表(customers)
- customer_id (PRIMARY KEY)
- name
- email
- signup_date
- status (active/inactive)
- customer_tier (bronze/silver/gold/platinum)
### 订单表(orders)
- order_id (PRIMARY KEY)
- customer_id (FOREIGN KEY -> customers)
- order_date
- status (pending/completed/cancelled/refunded)
- total_amount
- sales_region (north/south/east/west)
### 订单项目表(order_items)
- item_id (PRIMARY KEY)
- order_id (FOREIGN KEY -> orders)
- product_id
- quantity
- unit_price
- discount_percent
## 业务逻辑
**活跃客户**: status = 'active' AND signup_date <= CURRENT_DATE - INTERVAL '90 days'
**收入计算**: Only count orders with status = 'completed'. Use total_amount from orders table, which already accounts for discounts.
**客户终生价值 (CLV)**: Sum of all completed order amounts for a customer.
**High-value orders**: Orders with total_amount > 1000
## 示例Query
-- 找出上个季度收入排名前十的客户
SELECT
c.customer_id,
c.name,
c.customer_tier,
SUM(o.total_amount) as total_revenue
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
WHERE o.status = 'completed'
AND o.order_date >= CURRENT_DATE - INTERVAL '3 months'
GROUP BY c.customer_id, c.name, c.customer_tier
ORDER BY total_revenue DESC
LIMIT 10;
""",
},
{
"name": "inventory_management",
"description": "用于库存跟踪的数据库架构和业务逻辑,包括产品、仓库和库存水平.",
"content": """# 库存管理方案
## Tables
### products(产品)
- product_id (PRIMARY KEY)
- product_name
- sku
- category
- unit_cost
- reorder_point (minimum stock level before reordering)
- discontinued (boolean)
### warehouses(仓库)
- warehouse_id (PRIMARY KEY)
- warehouse_name
- location
- capacity
### inventory(存货)
- inventory_id (PRIMARY KEY)
- product_id (FOREIGN KEY -> products)
- warehouse_id (FOREIGN KEY -> warehouses)
- quantity_on_hand
- last_updated
### stock_movements(存货)
- movement_id (PRIMARY KEY)
- product_id (FOREIGN KEY -> products)
- warehouse_id (FOREIGN KEY -> warehouses)
- movement_type (inbound/outbound/transfer/adjustment)
- quantity (positive for inbound, negative for outbound)
- movement_date
- reference_number
## Business Logic(业务逻辑)
**可用库存**: quantity_on_hand from inventory table where quantity_on_hand > 0
**需要重新订购的产品**: Products where total quantity_on_hand across all warehouses is less than or equal to the product's reorder_point
**仅有效产品**: Exclude products where discontinued = true unless specifically analyzing discontinued items
**股票估值**: quantity_on_hand * unit_cost for each product
## Example Query
-- 查找所有仓库中低于重新订购点的产品。
SELECT
p.product_id,
p.product_name,
p.reorder_point,
SUM(i.quantity_on_hand) as total_stock,
p.unit_cost,
(p.reorder_point - SUM(i.quantity_on_hand)) as units_to_reorder
FROM products p
JOIN inventory i ON p.product_id = i.product_id
WHERE p.discontinued = false
GROUP BY p.product_id, p.product_name, p.reorder_point, p.unit_cost
HAVING SUM(i.quantity_on_hand) <= p.reorder_point
ORDER BY units_to_reorder DESC;
""",
},
]
3.创建技能加载工具
创建一个按需加载完整技能内容的工具,该工具会以字符串的形式返回完整的技能内容,作为工具消息(ToolMessage)成为对话的一部分:
from langchain.tools import tool
@tool
def load_skill(skill_name: str) -> str:
"""将技能的完整内容加载到代理的上下文中。
当您需要了解如何处理特定类型的请求的详细信息时,请使用此功能。
这将为您提供该技能领域的全面说明、策略和指南。
Args:
skill_name: The name of the skill to load (e.g., "expense_reporting", "travel_booking")
"""
# Find and return the requested skill
for skill in SKILLS:
if skill["name"] == skill_name:
return f"Loaded skill: {skill_name}\n\n{skill['content']}"
# Skill not found
available = ", ".join(s["name"] for s in SKILLS)
return f"Skill '{skill_name}' not found. Available skills: {available}"
4.构建技能中间件
创建自定义中间件,将技能描述注入系统提示符。该中间件使技能无需预先加载全部内容即可被用户发现。
-
在不修改 system prompt 原文的情况下,通过中间件(Middleware)把"可用技能目录"动态注入到系统提示中,并引导 Agent 在需要时通过工具加载具体技能。
-
a.AgentMiddleware
- Agent 调用 LLM 前的拦截器 可以:改 request 注入 prompt 增加约束
-
ModelRequest
- a.表示一次 模型调用请求,包含:system message messages tools state
-
ModelResponse : 模型返回的结果(这里你只是透传)
from langchain.agents.middleware import ModelRequest, ModelResponse, AgentMiddleware
from langchain.messages import SystemMessage
from typing import Callableclass SkillMiddleware(AgentMiddleware):
"""将技能描述注入系统提示符的中间件."""# 将 load_skill 工具注册为类变量 tools = [load_skill] #load_skill 工具会自动注册到 Agent 可用工具中 def __init__(self): """从 SKILLS 初始化并生成技能提示.""" # 根据技能列表构建技能提示 skills_list = [] for skill in SKILLS: skills_list.append( f"- **{skill['name']}**: {skill['description']}" #没有技能细节 只是"你可以用哪些技能" ) self.skills_prompt = "\n".join(skills_list) def wrap_model_call( #这个方法什么时候执行? --> 每一次 Agent 调用模型之前 self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], ) -> ModelResponse: """Sync: 将技能描述注入系统提示.""" # 技能提升附录 skills_addendum = ( f"\n\n## Available Skills\n\n{self.skills_prompt}\n\n" "当您需要有关处理特定类型请求的详细信息时,请使用 load_skill 工具" ) # Append to system message content blocks 真正的"注入系统提示" new_content = list(request.system_message.content_blocks) + [ {"type": "text", "text": skills_addendum} ] new_system_message = SystemMessage(content=new_content) modified_request = request.override(system_message=new_system_message) return handler(modified_request)
5.创建具有技能支持的代理
现在创建包含技能中间件和用于状态持久化的检查指针的代理:
from langchain.agents import create_agent
from langgraph.checkpoint.memory import InMemorySaver
agent = create_agent(
model,
system_prompt=(
"您是 SQL 查询助手,可以帮助用户编写针对业务数据库的查询。"
),
middleware=[SkillMiddleware()],
checkpointer=InMemorySaver(),
)
6.测试渐进式披露
通过一个需要特定技能知识的问题来测试智能体:
import uuid
# 配置对话线程
thread_id = str(uuid.uuid4())
config = {"configurable": {"thread_id": thread_id}}
# Ask for a SQL query
result = agent.invoke(
{
"messages": [
{
"role": "user",
"content": (
"编写 SQL 查询语句,查找上个月所有订单金额超过 1000 美元的客户。"
),
}
]
},
config
)
# Print the conversation
for message in result["messages"]:
if hasattr(message, 'pretty_print'):
message.pretty_print()
else:
print(f"{message.type}: {message.content}")