The full source code is included at the end of this article.
Introduction
In modern RAG systems, integrating large language model APIs is one of the core concerns. This article dissects the multi-provider API processing system from a winning RAG Challenge entry. The system unifies the interfaces of several mainstream LLM providers, including OpenAI, IBM, Google Gemini, and Alibaba Cloud DashScope, behind a single abstraction layer, and uses intelligent retries, structured-output handling, and related techniques to deliver highly available, high-performance, enterprise-grade LLM integration.
System Architecture Overview
The API processing system follows a layered-abstraction design:
- Unified interface layer: APIProcessor - exposes a single calling interface to the rest of the system
- Provider adapter layer: the individual Base*Processor classes - absorb the API differences between LLM providers
- Feature enhancement layer: structured output, retry mechanisms, asynchronous processing, and other advanced capabilities
- Tooling layer: token counting, JSON repair, prompt management, and other supporting utilities
Core Components in Detail
1. Unified API Processor (APIProcessor)
APIProcessor is the system's facade class and provides a unified access interface across LLM providers:
class APIProcessor:
def __init__(self, provider: Literal["openai", "ibm", "gemini", "dashscope"] ="dashscope"):
self.provider = provider.lower()
if self.provider == "openai":
self.processor = BaseOpenaiProcessor()
elif self.provider == "ibm":
self.processor = BaseIBMAPIProcessor()
elif self.provider == "gemini":
self.processor = BaseGeminiProcessor()
elif self.provider == "dashscope":
self.processor = BaseDashscopeProcessor()
def send_message(
self,
model=None,
temperature=0.5,
seed=None,
system_content="You are a helpful assistant.",
human_content="Hello!",
is_structured=False,
response_format=None,
**kwargs
):
"""统一的消息发送接口,路由到对应的处理器"""
if model is None:
model = self.processor.default_model
return self.processor.send_message(
model=model,
temperature=temperature,
seed=seed,
system_content=system_content,
human_content=human_content,
is_structured=is_structured,
response_format=response_format,
**kwargs
)
Design highlights (a usage sketch follows the list):
- Unified interface: hides the API differences between providers
- Dynamic routing: automatically selects the matching processor based on configuration
- Parameter pass-through: provider-specific parameters are forwarded via **kwargs
- Type safety: the Literal type constrains provider names to valid values
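To make the routing concrete, here is a minimal usage sketch. It assumes the classes above live in the project's src.api_requests module (the module path is an assumption) and that the API key for the chosen provider is set in the environment.
# Minimal usage sketch; the module path src.api_requests is an assumption.
from src.api_requests import APIProcessor

processor = APIProcessor(provider="openai")  # or "ibm", "gemini", "dashscope"
answer = processor.send_message(
    system_content="You are a helpful assistant.",
    human_content="Explain what an API facade is in one sentence.",
    temperature=0.3,
)
print(answer)                              # plain-text completion
print(processor.processor.response_data)  # token usage recorded by the underlying processor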
2. OpenAI Processor (BaseOpenaiProcessor)
The OpenAI processor fully wraps the OpenAI API, with support for structured output and token accounting:
class BaseOpenaiProcessor:
def __init__(self):
self.llm = self.set_up_llm()
self.default_model = 'gpt-4o-2024-08-06'
def set_up_llm(self):
        # Load the OpenAI API key and initialize the client
load_dotenv()
llm = OpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
timeout=None,
max_retries=2
)
return llm
def send_message(
self,
model=None,
temperature=0.5,
seed=None,
system_content='You are a helpful assistant.',
human_content='Hello!',
is_structured=False,
response_format=None
):
if model is None:
model = self.default_model
params = {
"model": model,
"seed": seed,
"messages": [
{"role": "system", "content": system_content},
{"role": "user", "content": human_content}
]
}
        # Some models (e.g. o3-mini) do not support the temperature parameter
if "o3-mini" not in model:
params["temperature"] = temperature
if not is_structured:
            # Plain text output
completion = self.llm.chat.completions.create(**params)
content = completion.choices[0].message.content
else:
            # Structured output
params["response_format"] = response_format
completion = self.llm.beta.chat.completions.parse(**params)
response = completion.choices[0].message.parsed
content = response.dict()
        # Record usage statistics
self.response_data = {
"model": completion.model,
"input_tokens": completion.usage.prompt_tokens,
"output_tokens": completion.usage.completion_tokens
}
print(self.response_data)
return content
@staticmethod
def count_tokens(string, encoding_name="o200k_base"):
        # Count the number of tokens in the string
encoding = tiktoken.get_encoding(encoding_name)
tokens = encoding.encode(string)
token_count = len(tokens)
return token_count
Technical highlights (a structured-output example follows the list):
- Structured output: supports Pydantic-model structured responses via the parse endpoint
- Model compatibility: handles parameter differences between models, e.g. skipping temperature for o3-mini
- Token accounting: precise token usage statistics via tiktoken and the API usage fields
- Error handling: timeouts and retries delegated to the OpenAI client's built-in settings
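As a concrete illustration of the structured-output path, the following sketch defines a tiny Pydantic schema and passes it via response_format; the CityAnswer model is invented here purely for demonstration.
# Hypothetical structured-output example; CityAnswer is illustrative only.
from pydantic import BaseModel

class CityAnswer(BaseModel):
    city: str
    country: str

processor = BaseOpenaiProcessor()
result = processor.send_message(
    system_content="Answer with the requested fields only.",
    human_content="Which city hosted the 2008 Summer Olympics?",
    is_structured=True,
    response_format=CityAnswer,
)
print(result)  # a plain dict such as {"city": "Beijing", "country": "China"}
print(BaseOpenaiProcessor.count_tokens("Which city hosted the 2008 Summer Olympics?"))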
3. IBM API Processor (BaseIBMAPIProcessor)
The IBM processor integrates enterprise AI services such as IBM Watson, accessed here through a proxy endpoint:
class BaseIBMAPIProcessor:
def __init__(self):
load_dotenv()
self.api_token = os.getenv("IBM_API_KEY")
self.base_url = "https://rag.timetoact.at/ibm"
self.default_model = 'meta-llama/llama-3-3-70b-instruct'
def check_balance(self):
"""查询当前API余额"""
balance_url = f"{self.base_url}/balance"
headers = {"Authorization": f"Bearer {self.api_token}"}
try:
response = requests.get(balance_url, headers=headers)
response.raise_for_status()
return response.json()
except requests.HTTPError as err:
print(f"Error checking balance: {err}")
return None
def get_embeddings(self, texts, model_id="ibm/granite-embedding-278m-multilingual"):
"""获取文本的向量嵌入"""
embeddings_url = f"{self.base_url}/embeddings"
headers = {
"Authorization": f"Bearer {self.api_token}",
"Content-Type": "application/json"
}
payload = {
"inputs": texts,
"model_id": model_id
}
try:
response = requests.post(embeddings_url, headers=headers, json=payload)
response.raise_for_status()
return response.json()
except requests.HTTPError as err:
print(f"Error getting embeddings: {err}")
return None
def send_message(
self,
model=None,
temperature=0.5,
seed=None,
system_content='You are a helpful assistant.',
human_content='Hello!',
is_structured=False,
response_format=None,
max_new_tokens=5000,
min_new_tokens=1,
**kwargs
):
if model is None:
model = self.default_model
text_generation_url = f"{self.base_url}/text_generation"
headers = {
"Authorization": f"Bearer {self.api_token}",
"Content-Type": "application/json"
}
        # Prepare the input messages
input_messages = [
{"role": "system", "content": system_content},
{"role": "user", "content": human_content}
]
        # Prepare the generation parameters
parameters = {
"temperature": temperature,
"random_seed": seed,
"max_new_tokens": max_new_tokens,
"min_new_tokens": min_new_tokens,
**kwargs
}
payload = {
"input": input_messages,
"model_id": model,
"parameters": parameters
}
try:
response = requests.post(text_generation_url, headers=headers, json=payload)
response.raise_for_status()
completion = response.json()
content = completion.get("results")[0].get("generated_text")
self.response_data = {
"model": completion.get("model_id"),
"input_tokens": completion.get("results")[0].get("input_token_count"),
"output_tokens": completion.get("results")[0].get("generated_token_count")
}
            # Structured output handling
if is_structured and response_format is not None:
try:
repaired_json = repair_json(content)
parsed_dict = json.loads(repaired_json)
validated_data = response_format.model_validate(parsed_dict)
content = validated_data.model_dump()
except Exception as err:
print("Error processing structured response, attempting to reparse...")
content = self._reparse_response(content, system_content)
return content
except requests.HTTPError as err:
print(f"Error generating text: {err}")
return None
Enterprise-grade capabilities (a usage sketch follows the list):
- Balance checking: monitor the remaining API quota in real time
- Embedding service: multilingual text embeddings
- Flexible parameters: a rich set of generation parameters can be passed through
- Error recovery: JSON repair plus LLM-based reparsing of malformed responses
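The auxiliary endpoints can be exercised independently of text generation. A rough sketch follows; the exact response payloads depend on the proxy service behind the configured base URL, so the fields are not assumed here.
# Sketch of the auxiliary endpoints; response fields are service-dependent.
ibm = BaseIBMAPIProcessor()

balance = ibm.check_balance()
print("Remaining balance:", balance)

embeddings = ibm.get_embeddings(["annual revenue", "net income"])
if embeddings is not None:
    print("Embeddings payload keys:", list(embeddings.keys()))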
4. Google Gemini Processor (BaseGeminiProcessor)
The Gemini processor integrates Google's Gemini model service:
class BaseGeminiProcessor:
def __init__(self):
self.llm = self._set_up_llm()
self.default_model = 'gemini-2.0-flash-001'
def _set_up_llm(self):
load_dotenv()
api_key = os.getenv("GEMINI_API_KEY")
genai.configure(api_key=api_key)
return genai
@retry(
wait=wait_fixed(20),
stop=stop_after_attempt(3),
before_sleep=lambda retry_state: print(f"\nAPI Error: {retry_state.outcome.exception()}\nWaiting 20 seconds...\n"),
)
def _generate_with_retry(self, model, human_content, generation_config):
"""带重试机制的内容生成"""
try:
return model.generate_content(
human_content,
generation_config=generation_config
)
except Exception as e:
if getattr(e, '_attempt_number', 0) == 3:
print(f"\nRetry failed. Error: {str(e)}\n")
raise
def send_message(
self,
model=None,
temperature: float = 0.5,
seed=12345,
system_content: str = "You are a helpful assistant.",
human_content: str = "Hello!",
is_structured: bool = False,
response_format: Optional[Type[BaseModel]] = None,
) -> Union[str, Dict, None]:
if model is None:
model = self.default_model
generation_config = {"temperature": temperature}
        # Gemini takes a single combined prompt
prompt = f"{system_content}\n\n---\n\n{human_content}"
model_instance = self.llm.GenerativeModel(
model_name=model,
generation_config=generation_config
)
try:
response = self._generate_with_retry(model_instance, prompt, generation_config)
self.response_data = {
"model": response.model_version,
"input_tokens": response.usage_metadata.prompt_token_count,
"output_tokens": response.usage_metadata.candidates_token_count
}
if is_structured and response_format is not None:
return self._parse_structured_response(response.text, response_format)
return response.text
except Exception as e:
raise Exception(f"API request failed after retries: {str(e)}")
def _parse_structured_response(self, response_text, response_format):
"""解析结构化响应"""
try:
repaired_json = repair_json(response_text)
parsed_dict = json.loads(repaired_json)
validated_data = response_format.model_validate(parsed_dict)
return validated_data.model_dump()
except Exception as err:
print(f"Error parsing structured response: {err}")
return self._reparse_response(response_text, response_format)
Technical highlights (a usage sketch follows the list):
- Automatic retry: fixed-interval retries (20 seconds, up to 3 attempts) implemented with the tenacity library
- Prompt adaptation: system and user content are merged into Gemini's single-prompt format
- Usage statistics: detailed token usage reporting from the response metadata
- Structured parsing: JSON repair and validation against the Pydantic schema
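A short usage sketch for the Gemini path; the YesNo schema is a stand-in for the project's real answer schemas, and GEMINI_API_KEY is assumed to be set in the environment.
# Hypothetical usage; YesNo is illustrative only.
from pydantic import BaseModel

class YesNo(BaseModel):
    final_answer: str

gemini = BaseGeminiProcessor()
result = gemini.send_message(
    system_content="Reply strictly as JSON with a single key named final_answer.",
    human_content="Is Python dynamically typed? Answer yes or no.",
    is_structured=True,
    response_format=YesNo,
)
print(result)  # dict validated against YesNo, or the raw text if reparsing fails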
5. Alibaba Cloud DashScope Processor (BaseDashscopeProcessor)
The DashScope processor integrates Alibaba Cloud's Tongyi Qianwen (Qwen) models:
class BaseDashscopeProcessor:
def __init__(self):
        # Read the API key from the environment
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
self.default_model = 'qwen-turbo-latest'
def send_message(
self,
model="qwen-turbo-latest",
temperature=0.1,
seed=None,
system_content='You are a helpful assistant.',
human_content='Hello!',
is_structured=False,
response_format=None,
**kwargs
):
"""发送消息到DashScope Qwen大模型"""
if model is None:
model = self.default_model
        # Assemble the messages list
messages = []
if system_content:
messages.append({"role": "system", "content": system_content})
if human_content:
messages.append({"role": "user", "content": human_content})
        # Call dashscope Generation.call
response = dashscope.Generation.call(
model=model,
messages=messages,
temperature=temperature,
result_format='message'
)
        # Normalize to the unified interface format
if hasattr(response, 'output') and hasattr(response.output, 'choices'):
content = response.output.choices[0].message.content
else:
content = str(response)
        # Keep the interface consistent with the other processors
self.response_data = {"model": model, "input_tokens": None, "output_tokens": None}
        # Always return a dict in a unified format
return {"final_answer": content}
Highlights for the domestic ecosystem (a usage sketch follows the list):
- Local optimization: the Qwen (Tongyi Qianwen) models are tuned for Chinese-language scenarios
- Simple interface: a concise API invocation style
- Compatible design: keeps the interface consistent with the other providers
- Cost advantage: comparatively low usage cost
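Usage is correspondingly simple; a minimal sketch, assuming DASHSCOPE_API_KEY is set in the environment:
# Minimal sketch; requires DASHSCOPE_API_KEY in the environment.
qwen = BaseDashscopeProcessor()
reply = qwen.send_message(
    system_content="You are a helpful assistant.",
    human_content="Briefly introduce the Qwen model family.",
)
print(reply["final_answer"])  # the processor always wraps the generated text in a dict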
6. RAG Context Processing
The system provides dedicated handling for answering questions from RAG context:
def get_answer_from_rag_context(self, question, rag_context, schema, model):
"""从RAG上下文生成答案"""
system_prompt, response_format, user_prompt = self._build_rag_context_prompts(schema)
answer_dict = self.processor.send_message(
model=model,
system_content=system_prompt,
human_content=user_prompt.format(context=rag_context, question=question),
is_structured=True,
response_format=response_format
)
self.response_data = self.processor.response_data
    # Fallback: ensure the returned answer structure is complete
if 'step_by_step_analysis' not in answer_dict:
answer_dict = {
"step_by_step_analysis": "",
"reasoning_summary": "",
"relevant_pages": [],
"final_answer": answer_dict.get("final_answer", "N/A")
}
return answer_dict
def _build_rag_context_prompts(self, schema):
"""根据答案类型构建提示词"""
use_schema_prompt = True if self.provider == "ibm" or self.provider == "gemini" else False
if schema == "name":
system_prompt = (prompts.AnswerWithRAGContextNamePrompt.system_prompt_with_schema
if use_schema_prompt else prompts.AnswerWithRAGContextNamePrompt.system_prompt)
response_format = prompts.AnswerWithRAGContextNamePrompt.AnswerSchema
user_prompt = prompts.AnswerWithRAGContextNamePrompt.user_prompt
elif schema == "number":
system_prompt = (prompts.AnswerWithRAGContextNumberPrompt.system_prompt_with_schema
if use_schema_prompt else prompts.AnswerWithRAGContextNumberPrompt.system_prompt)
response_format = prompts.AnswerWithRAGContextNumberPrompt.AnswerSchema
user_prompt = prompts.AnswerWithRAGContextNumberPrompt.user_prompt
    # ... other schema types (boolean, names, comparative) handled analogously
else:
raise ValueError(f"Unsupported schema: {schema}")
return system_prompt, response_format, user_prompt
RAG-specific features (an example call follows the list):
- Multiple answer types: name, number, boolean, names, comparative, and more
- Prompt selection: schema-embedded system prompts are used for providers (IBM, Gemini) without native structured output
- Structured output: answers include the step-by-step reasoning and page references
- Fallback mechanism: incomplete responses are padded into the full answer structure
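A hedged example of how the retrieval stage would drive this method; the rag_context string is invented here, whereas in the real pipeline it comes from the retriever.
# Illustrative call; the rag_context below is a made-up snippet.
processor = APIProcessor(provider="openai")
answer = processor.get_answer_from_rag_context(
    question="What was the company's total revenue in 2022?",
    rag_context="Page 12: Total revenue for FY2022 was $4.2 billion ...",
    schema="number",
    model="gpt-4o-2024-08-06",
)
print(answer["final_answer"], answer["relevant_pages"])
print(processor.response_data)  # usage statistics of the underlying call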
7. Asynchronous Batch Processing
The system also offers high-throughput asynchronous batch processing:
class AsyncOpenaiProcessor:
async def process_structured_ouputs_requests(
self,
model="gpt-4o-mini-2024-07-18",
temperature=0.5,
seed=None,
system_content="You are a helpful assistant.",
queries=None,
response_format=None,
requests_filepath='./temp_async_llm_requests.jsonl',
save_filepath='./temp_async_llm_results.jsonl',
max_requests_per_minute=3_500,
max_tokens_per_minute=3_500_000,
progress_callback=None
):
        # Build the batch of requests
jsonl_requests = []
for idx, query in enumerate(queries):
request = {
"model": model,
"temperature": temperature,
"seed": seed,
"messages": [
{"role": "system", "content": system_content},
{"role": "user", "content": query},
],
'response_format': type_to_response_format_param(response_format),
'metadata': {'original_index': idx}
}
jsonl_requests.append(request)
        # Write the requests to a JSONL file
with open(requests_filepath, "w") as f:
for request in jsonl_requests:
json_string = json.dumps(request)
f.write(json_string + "\n")
        # Asynchronous processing with progress monitoring
async def monitor_progress():
last_count = 0
while True:
try:
with open(save_filepath, 'r') as f:
current_count = sum(1 for _ in f)
if current_count > last_count:
if progress_callback:
for _ in range(current_count - last_count):
progress_callback()
last_count = current_count
if current_count >= len(jsonl_requests):
break
except FileNotFoundError:
pass
await asyncio.sleep(0.1)
        # Run request processing and progress monitoring concurrently
await asyncio.gather(
process_api_requests_from_file(
requests_filepath=requests_filepath,
save_filepath=save_filepath,
request_url="https://api.openai.com/v1/chat/completions",
api_key=os.getenv("OPENAI_API_KEY"),
max_requests_per_minute=max_requests_per_minute,
max_tokens_per_minute=max_tokens_per_minute,
max_attempts=5
),
monitor_progress()
)
        # Parse the results and sort them
with open(save_filepath, "r") as f:
results = []
for line_number, line in enumerate(f, start=1):
try:
result = json.loads(line.strip())
answer_content = result[1]['choices'][0]['message']['content']
answer_parsed = json.loads(answer_content)
answer = response_format(**answer_parsed).model_dump()
results.append({
'index': result[2]['original_index'],
'question': result[0]['messages'],
'answer': answer
})
except Exception as e:
print(f"[ERROR] Line {line_number}: Failed to parse. Error: {e}")
        # Restore the original input order
validated_data_list = [
{'question': r['question'], 'answer': r['answer']}
for r in sorted(results, key=lambda x: x['index'])
]
return validated_data_list
Advantages of the asynchronous path (a driving example follows the list):
- High throughput: thousands of requests per minute, bounded by the configured rate limits
- Progress monitoring: processing progress is reported in real time
- Result ordering: output order matches the input order
- Error handling: per-line exception handling and logging
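Driving the batch processor from synchronous code looks roughly like the sketch below; AnswerSchema is a stand-in Pydantic model, and note that the method name keeps the original spelling "ouputs" from the source.
# Sketch of a batch run; AnswerSchema is illustrative only.
import asyncio
from pydantic import BaseModel

class AnswerSchema(BaseModel):
    final_answer: str

async def main():
    processor = AsyncOpenaiProcessor()
    results = await processor.process_structured_ouputs_requests(
        queries=["What is RAG?", "What is a vector database?"],
        response_format=AnswerSchema,
        max_requests_per_minute=500,
        progress_callback=lambda: print(".", end="", flush=True),
    )
    for item in results:
        print(item["answer"]["final_answer"])

asyncio.run(main())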
Intelligent Error Handling and Repair
JSON Repair Mechanism
The system implements an intelligent JSON repair mechanism:
def _reparse_response(self, response, system_content, response_format):
    """Use the LLM to re-parse an invalid JSON response and validate it against the expected schema"""
user_prompt = prompts.AnswerSchemaFixPrompt.user_prompt.format(
system_prompt=system_content,
response=response
)
reparsed_response = self.send_message(
system_content=prompts.AnswerSchemaFixPrompt.system_prompt,
human_content=user_prompt,
is_structured=False
)
try:
repaired_json = repair_json(reparsed_response)
reparsed_dict = json.loads(repaired_json)
validated_data = response_format.model_validate(reparsed_dict)
print("Reparsing successful!")
return validated_data.model_dump()
except Exception as reparse_err:
print(f"Reparse failed with error: {reparse_err}")
return response
Repair strategy (a small standalone demonstration follows the list):
- Automatic repair: the json_repair library fixes common JSON errors
- LLM reparsing: when automatic repair fails, the LLM is asked to reformat the response
- Layered fallback: several levels of error handling stacked on top of each other
- Logging: the repair process and its outcome are recorded
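For a sense of what the first repair layer does on its own, here is a small standalone demonstration of json_repair on a typically malformed LLM output; the exact repaired form may vary with library version.
# Standalone demonstration of the json_repair layer.
import json
from json_repair import repair_json

broken = '{"final_answer": "42", "relevant_pages": [3, 5,'  # truncated, trailing comma
fixed = repair_json(broken)
print(json.loads(fixed))  # typically yields {'final_answer': '42', 'relevant_pages': [3, 5]}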
Practical Application Scenarios
1. Multi-Cloud Deployment
# Configure multiple providers, with one acting as a backup
primary_processor = APIProcessor(provider="openai")
backup_processor = APIProcessor(provider="dashscope")
def robust_api_call(question, context, schema):
try:
return primary_processor.get_answer_from_rag_context(
question=question,
rag_context=context,
schema=schema,
model="gpt-4o-2024-08-06"
)
except Exception as e:
print(f"Primary API failed: {e}, switching to backup...")
return backup_processor.get_answer_from_rag_context(
question=question,
rag_context=context,
schema=schema,
model="qwen-turbo-latest"
)
2. Cost Optimization Strategy
# Choose a model based on question complexity
def cost_optimized_processing(question, context, schema):
    # Use a cheaper model for simple questions
if len(question) < 50 and schema in ["boolean", "name"]:
processor = APIProcessor(provider="dashscope")
model = "qwen-turbo-latest"
else:
        # Use a stronger model for complex questions
processor = APIProcessor(provider="openai")
model = "gpt-4o-2024-08-06"
return processor.get_answer_from_rag_context(
question=question,
rag_context=context,
schema=schema,
model=model
)
3. Batch Processing Optimization
# Large-scale batch processing
async def batch_process_questions(questions_list):
processor = AsyncOpenaiProcessor()
queries = [q["question"] for q in questions_list]
results = await processor.process_structured_ouputs_requests(
model="gpt-4o-mini-2024-07-18",
system_content="You are a helpful RAG assistant.",
queries=queries,
response_format=AnswerSchema,
max_requests_per_minute=1000,
progress_callback=lambda: print(".", end="")
)
return results
Performance Optimization Strategies
1. Connection Pool Management
# Tune HTTP connection handling with retries and pooling
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
def setup_session():
session = requests.Session()
retry_strategy = Retry(
total=3,
backoff_factor=1,
status_forcelist=[429, 500, 502, 503, 504],
)
adapter = HTTPAdapter(max_retries=retry_strategy, pool_connections=20, pool_maxsize=20)
session.mount("http://", adapter)
session.mount("https://", adapter)
return session
2. Caching
from functools import lru_cache
import hashlib
class CachedAPIProcessor(APIProcessor):
@lru_cache(maxsize=1000)
def cached_send_message(self, content_hash, **kwargs):
return super().send_message(**kwargs)
def send_message(self, **kwargs):
        # Hash the prompt content to form a cache key
content = f"{kwargs.get('system_content', '')}{kwargs.get('human_content', '')}"
content_hash = hashlib.md5(content.encode()).hexdigest()
return self.cached_send_message(content_hash, **kwargs)
3. Load Balancing
import random
from typing import List
class LoadBalancedAPIProcessor:
def __init__(self, providers: List[str]):
self.processors = [APIProcessor(provider=p) for p in providers]
        self.weights = [1.0] * len(self.processors)  # weights can be tuned based on observed performance
def send_message(self, **kwargs):
        # Pick a processor at random according to the weights
processor = random.choices(self.processors, weights=self.weights)[0]
try:
return processor.send_message(**kwargs)
except Exception as e:
            # Reduce the weight of the failing processor
idx = self.processors.index(processor)
self.weights[idx] *= 0.8
            # Fall back to the other processors
for other_processor in self.processors:
if other_processor != processor:
try:
return other_processor.send_message(**kwargs)
except:
continue
raise e
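Usage of the balancer mirrors the plain processor; a brief sketch:
# Usage sketch for the weighted balancer defined above.
balancer = LoadBalancedAPIProcessor(providers=["openai", "dashscope"])
reply = balancer.send_message(
    system_content="You are a helpful assistant.",
    human_content="Name one benefit of client-side load balancing.",
)
print(reply)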
Monitoring and Debugging
1. Detailed Logging
import logging
from datetime import datetime
class LoggedAPIProcessor(APIProcessor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.logger = logging.getLogger(f"APIProcessor-{self.provider}")
def send_message(self, **kwargs):
start_time = datetime.now()
try:
result = super().send_message(**kwargs)
end_time = datetime.now()
self.logger.info(f"API call successful - Provider: {self.provider}, "
f"Duration: {(end_time - start_time).total_seconds():.2f}s, "
f"Tokens: {self.response_data}")
return result
except Exception as e:
end_time = datetime.now()
self.logger.error(f"API call failed - Provider: {self.provider}, "
f"Duration: {(end_time - start_time).total_seconds():.2f}s, "
f"Error: {str(e)}")
raise
2. Performance Monitoring
import time
from collections import defaultdict
class PerformanceMonitor:
def __init__(self):
self.stats = defaultdict(list)
def record_api_call(self, provider, duration, tokens_used, success):
self.stats[provider].append({
'duration': duration,
'tokens': tokens_used,
'success': success,
'timestamp': time.time()
})
def get_stats(self, provider=None):
if provider:
calls = self.stats[provider]
else:
calls = []
for provider_calls in self.stats.values():
calls.extend(provider_calls)
if not calls:
return {}
successful_calls = [c for c in calls if c['success']]
return {
'total_calls': len(calls),
'success_rate': len(successful_calls) / len(calls),
'avg_duration': sum(c['duration'] for c in successful_calls) / len(successful_calls) if successful_calls else 0,
'total_tokens': sum(c['tokens'].get('input_tokens', 0) + c['tokens'].get('output_tokens', 0) for c in successful_calls if c['tokens'])
}
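Feeding the monitor from an instrumented call might look like the following; the timing wrapper is illustrative and not part of the original module.
# Illustrative instrumentation around a single call.
import time

monitor = PerformanceMonitor()
processor = APIProcessor(provider="openai")

start = time.time()
try:
    processor.send_message(human_content="Ping")
    success = True
except Exception:
    success = False
monitor.record_api_call(
    provider="openai",
    duration=time.time() - start,
    tokens_used=getattr(processor.processor, "response_data", {}),
    success=success,
)
print(monitor.get_stats("openai"))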
Best-Practice Recommendations
1. Provider Selection Strategy
# Pick the most suitable provider for the scenario
def choose_optimal_provider(task_type, budget_level, latency_requirement):
    if budget_level == "low":
        return "dashscope"  # cost advantage
    elif latency_requirement == "ultra_low":
        return "openai"  # fast responses
    elif task_type == "multilingual":
        return "gemini"  # strong multilingual support
    elif task_type == "enterprise":
        return "ibm"  # enterprise-grade features
    else:
        return "openai"  # default choice
2. Error-Handling Best Practices
import time
from tenacity import retry, stop_after_attempt, wait_exponential
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10)
)
def robust_api_call(processor, **kwargs):
try:
return processor.send_message(**kwargs)
except Exception as e:
if "rate_limit" in str(e).lower():
            time.sleep(60)  # wait for the rate-limit window to reset
raise
3. Configuration Management
import yaml
class APIConfig:
def __init__(self, config_file="api_config.yaml"):
with open(config_file, 'r') as f:
self.config = yaml.safe_load(f)
def get_provider_config(self, provider):
return self.config.get('providers', {}).get(provider, {})
def get_model_for_task(self, task_type):
return self.config.get('task_models', {}).get(task_type, "default")
# Example api_config.yaml
"""
providers:
openai:
default_model: "gpt-4o-2024-08-06"
max_retries: 3
timeout: 30
dashscope:
default_model: "qwen-turbo-latest"
max_retries: 2
timeout: 20
task_models:
simple_qa: "gpt-4o-mini-2024-07-18"
complex_analysis: "gpt-4o-2024-08-06"
multilingual: "gemini-2.0-flash-001"
"""
Summary
This multi-provider API processing system demonstrates best practices for enterprise-grade LLM integration:
- Unified abstraction: an elegant, single interface across providers
- Intelligent adaptation: optimizations tailored to each provider's characteristics
- Error recovery: solid retry, repair, and fallback mechanisms
- Performance optimization: asynchronous processing, connection pooling, caching, and similar strategies
- Monitoring and debugging: detailed logging and performance metrics
- Extensibility: new providers and features are easy to add
- Enterprise features: balance checking, embedding services, and other enterprise-grade functionality
For anyone building an enterprise-grade RAG system, this API processing architecture offers a complete reference implementation. With sensible design and tuning, it delivers highly available, high-performance, and cost-effective LLM integration while maintaining service quality.
References
This article is based on a source-code analysis of the API processing module from the RAG-Challenge-2 winning project and walks through a production-grade, multi-provider LLM integration together with its optimization strategies. Hopefully it is useful to developers building similar systems.
Appendix: Full Source Code
import os
import json
from dotenv import load_dotenv
from typing import Union, List, Dict, Type, Optional, Literal
from openai import OpenAI
import asyncio
from src.api_request_parallel_processor import process_api_requests_from_file
from openai.lib._parsing import type_to_response_format_param
import tiktoken
import src.prompts as prompts
import requests
from json_repair import repair_json
from pydantic import BaseModel
import google.generativeai as genai
from copy import deepcopy
from tenacity import retry, stop_after_attempt, wait_fixed
import dashscope
# Base OpenAI processor: wraps message sending, structured output, and usage accounting
class BaseOpenaiProcessor:
def __init__(self):
self.llm = self.set_up_llm()
self.default_model = 'gpt-4o-2024-08-06'
# self.default_model = 'gpt-4o-mini-2024-07-18',
def set_up_llm(self):
        # Load the OpenAI API key and initialize the client
load_dotenv()
llm = OpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
timeout=None,
max_retries=2
)
return llm
def send_message(
self,
model=None,
temperature=0.5,
        seed=None,  # For deterministic outputs
system_content='You are a helpful assistant.',
human_content='Hello!',
is_structured=False,
response_format=None
):
        # Send a message to OpenAI; supports structured and unstructured output
if model is None:
model = self.default_model
params = {
"model": model,
"seed": seed,
"messages": [
{"role": "system", "content": system_content},
{"role": "user", "content": human_content}
]
}
        # Some models do not support temperature
if "o3-mini" not in model:
params["temperature"] = temperature
if not is_structured:
completion = self.llm.chat.completions.create(**params)
content = completion.choices[0].message.content
elif is_structured:
params["response_format"] = response_format
completion = self.llm.beta.chat.completions.parse(**params)
response = completion.choices[0].message.parsed
content = response.dict()
self.response_data = {"model": completion.model, "input_tokens": completion.usage.prompt_tokens, "output_tokens": completion.usage.completion_tokens}
print(self.response_data)
return content
@staticmethod
def count_tokens(string, encoding_name="o200k_base"):
        # Count the number of tokens in the string
encoding = tiktoken.get_encoding(encoding_name)
# Encode the string and count the tokens
tokens = encoding.encode(string)
token_count = len(tokens)
return token_count
# Base IBM API processor: balance check, model listing, embeddings, and message sending
class BaseIBMAPIProcessor:
def __init__(self):
load_dotenv()
self.api_token = os.getenv("IBM_API_KEY")
self.base_url = "https://rag.timetoact.at/ibm"
self.default_model = 'meta-llama/llama-3-3-70b-instruct'
def check_balance(self):
"""查询当前API余额"""
balance_url = f"{self.base_url}/balance"
headers = {"Authorization": f"Bearer {self.api_token}"}
try:
response = requests.get(balance_url, headers=headers)
response.raise_for_status()
return response.json()
except requests.HTTPError as err:
print(f"Error checking balance: {err}")
return None
def get_available_models(self):
"""获取可用基础模型列表"""
models_url = f"{self.base_url}/foundation_model_specs"
try:
response = requests.get(models_url)
response.raise_for_status()
return response.json()
except requests.HTTPError as err:
print(f"Error getting available models: {err}")
return None
def get_embeddings(self, texts, model_id="ibm/granite-embedding-278m-multilingual"):
"""获取文本的向量嵌入"""
embeddings_url = f"{self.base_url}/embeddings"
headers = {
"Authorization": f"Bearer {self.api_token}",
"Content-Type": "application/json"
}
payload = {
"inputs": texts,
"model_id": model_id
}
try:
response = requests.post(embeddings_url, headers=headers, json=payload)
response.raise_for_status()
return response.json()
except requests.HTTPError as err:
print(f"Error getting embeddings: {err}")
return None
def send_message(
self,
# model='meta-llama/llama-3-1-8b-instruct',
model=None,
temperature=0.5,
seed=None, # For deterministic outputs
system_content='You are a helpful assistant.',
human_content='Hello!',
is_structured=False,
response_format=None,
max_new_tokens=5000,
min_new_tokens=1,
**kwargs
):
        # Send a message to the IBM API; supports structured and unstructured output
if model is None:
model = self.default_model
text_generation_url = f"{self.base_url}/text_generation"
headers = {
"Authorization": f"Bearer {self.api_token}",
"Content-Type": "application/json"
}
# Prepare the input messages
input_messages = [
{"role": "system", "content": system_content},
{"role": "user", "content": human_content}
]
# Prepare parameters with defaults and any additional parameters
parameters = {
"temperature": temperature,
"random_seed": seed,
"max_new_tokens": max_new_tokens,
"min_new_tokens": min_new_tokens,
**kwargs
}
payload = {
"input": input_messages,
"model_id": model,
"parameters": parameters
}
try:
response = requests.post(text_generation_url, headers=headers, json=payload)
response.raise_for_status()
completion = response.json()
content = completion.get("results")[0].get("generated_text")
self.response_data = {"model": completion.get("model_id"), "input_tokens": completion.get("results")[0].get("input_token_count"), "output_tokens": completion.get("results")[0].get("generated_token_count")}
print(self.response_data)
if is_structured and response_format is not None:
try:
repaired_json = repair_json(content)
parsed_dict = json.loads(repaired_json)
validated_data = response_format.model_validate(parsed_dict)
content = validated_data.model_dump()
return content
except Exception as err:
print("Error processing structured response, attempting to reparse the response...")
reparsed = self._reparse_response(content, system_content)
try:
repaired_json = repair_json(reparsed)
reparsed_dict = json.loads(repaired_json)
try:
validated_data = response_format.model_validate(reparsed_dict)
print("Reparsing successful!")
content = validated_data.model_dump()
return content
except Exception:
return reparsed_dict
except Exception as reparse_err:
print(f"Reparse failed with error: {reparse_err}")
print(f"Reparsed response: {reparsed}")
return content
return content
except requests.HTTPError as err:
print(f"Error generating text: {err}")
return None
def _reparse_response(self, response, system_content):
user_prompt = prompts.AnswerSchemaFixPrompt.user_prompt.format(
system_prompt=system_content,
response=response
)
reparsed_response = self.send_message(
system_content=prompts.AnswerSchemaFixPrompt.system_prompt,
human_content=user_prompt,
is_structured=False
)
return reparsed_response
class BaseGeminiProcessor:
def __init__(self):
self.llm = self._set_up_llm()
self.default_model = 'gemini-2.0-flash-001'
# self.default_model = "gemini-2.0-flash-thinking-exp-01-21",
def _set_up_llm(self):
load_dotenv()
api_key = os.getenv("GEMINI_API_KEY")
genai.configure(api_key=api_key)
return genai
def list_available_models(self) -> None:
"""
Prints available Gemini models that support text generation.
"""
print("Available models for text generation:")
for model in self.llm.list_models():
if "generateContent" in model.supported_generation_methods:
print(f"- {model.name}")
print(f" Input token limit: {model.input_token_limit}")
print(f" Output token limit: {model.output_token_limit}")
print()
def _log_retry_attempt(retry_state):
"""Print information about the retry attempt"""
exception = retry_state.outcome.exception()
print(f"\nAPI Error encountered: {str(exception)}")
print("Waiting 20 seconds before retry...\n")
@retry(
wait=wait_fixed(20),
stop=stop_after_attempt(3),
before_sleep=_log_retry_attempt,
)
def _generate_with_retry(self, model, human_content, generation_config):
"""Wrapper for generate_content with retry logic"""
try:
return model.generate_content(
human_content,
generation_config=generation_config
)
except Exception as e:
if getattr(e, '_attempt_number', 0) == 3:
print(f"\nRetry failed. Error: {str(e)}\n")
raise
def _parse_structured_response(self, response_text, response_format):
try:
repaired_json = repair_json(response_text)
parsed_dict = json.loads(repaired_json)
validated_data = response_format.model_validate(parsed_dict)
return validated_data.model_dump()
except Exception as err:
print(f"Error parsing structured response: {err}")
print("Attempting to reparse the response...")
reparsed = self._reparse_response(response_text, response_format)
return reparsed
def _reparse_response(self, response, response_format):
"""Reparse invalid JSON responses using the model itself."""
user_prompt = prompts.AnswerSchemaFixPrompt.user_prompt.format(
system_prompt=prompts.AnswerSchemaFixPrompt.system_prompt,
response=response
)
try:
reparsed_response = self.send_message(
model="gemini-2.0-flash-001",
system_content=prompts.AnswerSchemaFixPrompt.system_prompt,
human_content=user_prompt,
is_structured=False
)
try:
repaired_json = repair_json(reparsed_response)
reparsed_dict = json.loads(repaired_json)
try:
validated_data = response_format.model_validate(reparsed_dict)
print("Reparsing successful!")
return validated_data.model_dump()
except Exception:
return reparsed_dict
except Exception as reparse_err:
print(f"Reparse failed with error: {reparse_err}")
print(f"Reparsed response: {reparsed_response}")
return response
except Exception as e:
print(f"Reparse attempt failed: {e}")
return response
def send_message(
self,
model=None,
temperature: float = 0.5,
seed=12345, # For back compatibility
system_content: str = "You are a helpful assistant.",
human_content: str = "Hello!",
is_structured: bool = False,
response_format: Optional[Type[BaseModel]] = None,
) -> Union[str, Dict, None]:
if model is None:
model = self.default_model
generation_config = {"temperature": temperature}
prompt = f"{system_content}\n\n---\n\n{human_content}"
model_instance = self.llm.GenerativeModel(
model_name=model,
generation_config=generation_config
)
try:
response = self._generate_with_retry(model_instance, prompt, generation_config)
self.response_data = {
"model": response.model_version,
"input_tokens": response.usage_metadata.prompt_token_count,
"output_tokens": response.usage_metadata.candidates_token_count
}
print(self.response_data)
if is_structured and response_format is not None:
return self._parse_structured_response(response.text, response_format)
return response.text
except Exception as e:
raise Exception(f"API request failed after retries: {str(e)}")
class APIProcessor:
def __init__(self, provider: Literal["openai", "ibm", "gemini", "dashscope"] ="dashscope"):
self.provider = provider.lower()
if self.provider == "openai":
self.processor = BaseOpenaiProcessor()
elif self.provider == "ibm":
self.processor = BaseIBMAPIProcessor()
elif self.provider == "gemini":
self.processor = BaseGeminiProcessor()
elif self.provider == "dashscope":
self.processor = BaseDashscopeProcessor()
def send_message(
self,
model=None,
temperature=0.5,
seed=None,
system_content="You are a helpful assistant.",
human_content="Hello!",
is_structured=False,
response_format=None,
**kwargs
):
"""
Routes the send_message call to the appropriate processor.
The underlying processor's send_message method is responsible for handling the parameters.
"""
if model is None:
model = self.processor.default_model
return self.processor.send_message(
model=model,
temperature=temperature,
seed=seed,
system_content=system_content,
human_content=human_content,
is_structured=is_structured,
response_format=response_format,
**kwargs
)
def get_answer_from_rag_context(self, question, rag_context, schema, model):
system_prompt, response_format, user_prompt = self._build_rag_context_prompts(schema)
answer_dict = self.processor.send_message(
model=model,
system_content=system_prompt,
human_content=user_prompt.format(context=rag_context, question=question),
is_structured=True,
response_format=response_format
)
self.response_data = self.processor.response_data
        # If answer_dict only contains final_answer, fall back to the full answer structure
if 'step_by_step_analysis' not in answer_dict:
answer_dict = {
"step_by_step_analysis": "",
"reasoning_summary": "",
"relevant_pages": [],
"final_answer": answer_dict.get("final_answer", "N/A")
}
return answer_dict
def _build_rag_context_prompts(self, schema):
"""Return prompts tuple for the given schema."""
use_schema_prompt = True if self.provider == "ibm" or self.provider == "gemini" else False
if schema == "name":
system_prompt = (prompts.AnswerWithRAGContextNamePrompt.system_prompt_with_schema
if use_schema_prompt else prompts.AnswerWithRAGContextNamePrompt.system_prompt)
response_format = prompts.AnswerWithRAGContextNamePrompt.AnswerSchema
user_prompt = prompts.AnswerWithRAGContextNamePrompt.user_prompt
elif schema == "number":
system_prompt = (prompts.AnswerWithRAGContextNumberPrompt.system_prompt_with_schema
if use_schema_prompt else prompts.AnswerWithRAGContextNumberPrompt.system_prompt)
response_format = prompts.AnswerWithRAGContextNumberPrompt.AnswerSchema
user_prompt = prompts.AnswerWithRAGContextNumberPrompt.user_prompt
elif schema == "boolean":
system_prompt = (prompts.AnswerWithRAGContextBooleanPrompt.system_prompt_with_schema
if use_schema_prompt else prompts.AnswerWithRAGContextBooleanPrompt.system_prompt)
response_format = prompts.AnswerWithRAGContextBooleanPrompt.AnswerSchema
user_prompt = prompts.AnswerWithRAGContextBooleanPrompt.user_prompt
elif schema == "names":
system_prompt = (prompts.AnswerWithRAGContextNamesPrompt.system_prompt_with_schema
if use_schema_prompt else prompts.AnswerWithRAGContextNamesPrompt.system_prompt)
response_format = prompts.AnswerWithRAGContextNamesPrompt.AnswerSchema
user_prompt = prompts.AnswerWithRAGContextNamesPrompt.user_prompt
elif schema == "comparative":
system_prompt = (prompts.ComparativeAnswerPrompt.system_prompt_with_schema
if use_schema_prompt else prompts.ComparativeAnswerPrompt.system_prompt)
response_format = prompts.ComparativeAnswerPrompt.AnswerSchema
user_prompt = prompts.ComparativeAnswerPrompt.user_prompt
else:
raise ValueError(f"Unsupported schema: {schema}")
return system_prompt, response_format, user_prompt
def get_rephrased_questions(self, original_question: str, companies: List[str]) -> Dict[str, str]:
"""Use LLM to break down a comparative question into individual questions."""
answer_dict = self.processor.send_message(
system_content=prompts.RephrasedQuestionsPrompt.system_prompt,
human_content=prompts.RephrasedQuestionsPrompt.user_prompt.format(
question=original_question,
companies=", ".join([f'"{company}"' for company in companies])
),
is_structured=True,
response_format=prompts.RephrasedQuestionsPrompt.RephrasedQuestions
)
# Convert the answer_dict to the desired format
questions_dict = {item["company_name"]: item["question"] for item in answer_dict["questions"]}
return questions_dict
class AsyncOpenaiProcessor:
def _get_unique_filepath(self, base_filepath):
"""Helper method to get unique filepath"""
if not os.path.exists(base_filepath):
return base_filepath
base, ext = os.path.splitext(base_filepath)
counter = 1
while os.path.exists(f"{base}_{counter}{ext}"):
counter += 1
return f"{base}_{counter}{ext}"
async def process_structured_ouputs_requests(
self,
model="gpt-4o-mini-2024-07-18",
temperature=0.5,
seed=None,
system_content="You are a helpful assistant.",
queries=None,
response_format=None,
requests_filepath='./temp_async_llm_requests.jsonl',
save_filepath='./temp_async_llm_results.jsonl',
preserve_requests=False,
preserve_results=True,
request_url="https://api.openai.com/v1/chat/completions",
max_requests_per_minute=3_500,
max_tokens_per_minute=3_500_000,
token_encoding_name="o200k_base",
max_attempts=5,
logging_level=20,
progress_callback=None
):
# Create requests for jsonl
jsonl_requests = []
for idx, query in enumerate(queries):
request = {
"model": model,
"temperature": temperature,
"seed": seed,
"messages": [
{"role": "system", "content": system_content},
{"role": "user", "content": query},
],
'response_format': type_to_response_format_param(response_format),
'metadata': {'original_index': idx}
}
jsonl_requests.append(request)
# Get unique filepaths if files already exist
requests_filepath = self._get_unique_filepath(requests_filepath)
save_filepath = self._get_unique_filepath(save_filepath)
# Write requests to JSONL file
with open(requests_filepath, "w") as f:
for request in jsonl_requests:
json_string = json.dumps(request)
f.write(json_string + "\n")
# Process API requests
total_requests = len(jsonl_requests)
async def monitor_progress():
last_count = 0
while True:
try:
with open(save_filepath, 'r') as f:
current_count = sum(1 for _ in f)
if current_count > last_count:
if progress_callback:
for _ in range(current_count - last_count):
progress_callback()
last_count = current_count
if current_count >= total_requests:
break
except FileNotFoundError:
pass
await asyncio.sleep(0.1)
async def process_with_progress():
await asyncio.gather(
process_api_requests_from_file(
requests_filepath=requests_filepath,
save_filepath=save_filepath,
request_url=request_url,
api_key=os.getenv("OPENAI_API_KEY"),
max_requests_per_minute=max_requests_per_minute,
max_tokens_per_minute=max_tokens_per_minute,
token_encoding_name=token_encoding_name,
max_attempts=max_attempts,
logging_level=logging_level
),
monitor_progress()
)
await process_with_progress()
with open(save_filepath, "r") as f:
validated_data_list = []
results = []
for line_number, line in enumerate(f, start=1):
raw_line = line.strip()
try:
result = json.loads(raw_line)
except json.JSONDecodeError as e:
print(f"[ERROR] Line {line_number}: Failed to load JSON from line: {raw_line}")
continue
# Check finish_reason in the API response
finish_reason = result[1]['choices'][0].get('finish_reason', '')
if finish_reason != "stop":
print(f"[WARNING] Line {line_number}: finish_reason is '{finish_reason}' (expected 'stop').")
# Safely parse answer; if it fails, leave answer empty and report the error.
try:
answer_content = result[1]['choices'][0]['message']['content']
answer_parsed = json.loads(answer_content)
answer = response_format(**answer_parsed).model_dump()
except Exception as e:
print(f"[ERROR] Line {line_number}: Failed to parse answer JSON. Error: {e}.")
answer = ""
results.append({
'index': result[2],
'question': result[0]['messages'],
'answer': answer
})
# Sort by original index and build final list
validated_data_list = [
{'question': r['question'], 'answer': r['answer']}
for r in sorted(results, key=lambda x: x['index']['original_index'])
]
if not preserve_requests:
os.remove(requests_filepath)
if not preserve_results:
os.remove(save_filepath)
else: # Fix requests order
with open(save_filepath, "r") as f:
results = [json.loads(line) for line in f]
sorted_results = sorted(results, key=lambda x: x[2]['original_index'])
with open(save_filepath, "w") as f:
for result in sorted_results:
json_string = json.dumps(result)
f.write(json_string + "\n")
return validated_data_list
# Base DashScope processor: chat with the Qwen models
class BaseDashscopeProcessor:
def __init__(self):
        # Read the API key from the environment
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
self.default_model = 'qwen-turbo-latest'
def send_message(
self,
model="qwen-turbo-latest",
temperature=0.1,
        seed=None,  # compatibility parameter; currently unused
system_content='You are a helpful assistant.',
human_content='Hello!',
is_structured=False,
response_format=None,
**kwargs
):
"""
        Send a message to the DashScope Qwen model; system_content and human_content are assembled into the messages list.
        Structured output is not supported yet.
"""
if model is None:
model = self.default_model
        # Assemble the messages list
messages = []
if system_content:
messages.append({"role": "system", "content": system_content})
if human_content:
messages.append({"role": "user", "content": human_content})
#print('system_content=', system_content)
#print('='*30)
#print('human_content=', human_content)
#print('='*30)
#print('messages=', messages)
#print('='*30)
        # Call dashscope Generation.call
response = dashscope.Generation.call(
model=model,
messages=messages,
temperature=temperature,
result_format='message'
)
        # Debug output (avoid printing the API key itself)
        print('model=', model)
        print('response=', response)
        # Normalize to the openai/gemini return style; always return a dict
if hasattr(response, 'output') and hasattr(response.output, 'choices'):
content = response.output.choices[0].message.content
else:
content = str(response)
        # Set response_data to keep the interface consistent
self.response_data = {"model": model, "input_tokens": None, "output_tokens": None}
print('content=', content)
        # Always return a dict to avoid AttributeError downstream
return {"final_answer": content}