Preface
In enterprise knowledge bases and intelligent question-answering systems, question processing is the bridge between user queries and knowledge retrieval. This article dissects the question-processing implementation of a system that won the RAG Challenge. It supports single-company queries, multi-company comparison, parallel processing, error recovery, and other enterprise-grade features, and it illustrates the end-to-end engineering practice of a modern RAG system.
System Architecture Overview
The question-processing system uses a modular, layered architecture:
- QuestionsProcessor: the core question processor, orchestrating the entire QA flow
- APIProcessor: a multi-provider API processor supporting OpenAI, IBM, Gemini, and DashScope
- Retrieval integration: seamless integration of vector and hybrid retrieval
- Parallel processing: multi-threaded concurrency and batch processing
- Error recovery: robust exception handling and resumable, checkpointed runs
Core Components in Detail
1. The Core Processor Class (QuestionsProcessor)
QuestionsProcessor is the system's central controller, coordinating retrieval, reasoning, and answer generation.
class QuestionsProcessor:
    def __init__(
        self,
        vector_db_dir: Union[str, Path] = './vector_dbs',
        documents_dir: Union[str, Path] = './documents',
        questions_file_path: Optional[Union[str, Path]] = None,
        new_challenge_pipeline: bool = False,
        subset_path: Optional[Union[str, Path]] = None,
        parent_document_retrieval: bool = False,  # enable parent-document retrieval
        llm_reranking: bool = False,  # enable LLM reranking
        llm_reranking_sample_size: int = 20,
        top_n_retrieval: int = 10,
        parallel_requests: int = 10,
        api_provider: str = "dashscope",  # or "openai"
        answering_model: str = "qwen-turbo-latest",  # e.g. "gpt-4o-2024-08-06"
        full_context: bool = False
    ):
        # Basic configuration
        self.questions = self._load_questions(questions_file_path)
        self.documents_dir = Path(documents_dir)
        self.vector_db_dir = Path(vector_db_dir)
        self.subset_path = Path(subset_path) if subset_path else None
        self.new_challenge_pipeline = new_challenge_pipeline
        # Retrieval strategy configuration
        self.return_parent_pages = parent_document_retrieval
        self.llm_reranking = llm_reranking
        self.llm_reranking_sample_size = llm_reranking_sample_size
        self.top_n_retrieval = top_n_retrieval
        self.full_context = full_context
        # API and concurrency configuration
        self.api_provider = api_provider
        self.answering_model = answering_model
        self.parallel_requests = parallel_requests
        self.openai_processor = APIProcessor(provider=api_provider)
        # Thread safety and state management
        self.answer_details = []
        self.detail_counter = 0
        self._lock = threading.Lock()
Design highlights:
- Rich configuration parameters supporting several retrieval strategies
- Multiple API providers, which improves availability
- Thread-safe design for concurrent processing
- Flexible switching between pipeline modes (an instantiation sketch follows)
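The constructor alone decides the pipeline mode. A minimal sketch of instantiating the challenge-pipeline variant; the paths are placeholders, and the subset CSV is assumed to contain company_name and sha1 columns, as the reference-extraction code further down expects:

# Sketch: challenge-pipeline mode reads questions with "text"/"kind" keys
# and needs a subset CSV for company-name -> pdf_sha1 lookups.
processor = QuestionsProcessor(
    vector_db_dir="./vector_dbs",
    documents_dir="./documents",
    new_challenge_pipeline=True,
    subset_path="./subset.csv",        # placeholder path
    parent_document_retrieval=True,
    llm_reranking=True,
    parallel_requests=5,
)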
2. Single-Company QA: The Core Flow
Answering questions about a single company is the system's foundational capability:
def get_answer_for_company(self, company_name: str, question: str, schema: str) -> dict:
    # Choose a retriever based on configuration
    if self.llm_reranking:
        retriever = HybridRetriever(
            vector_db_dir=self.vector_db_dir,
            documents_dir=self.documents_dir
        )
    else:
        retriever = VectorRetriever(
            vector_db_dir=self.vector_db_dir,
            documents_dir=self.documents_dir
        )
    # Run retrieval
    if self.full_context:
        retrieval_results = retriever.retrieve_all(company_name)
    else:
        retrieval_results = retriever.retrieve_by_company_name(
            company_name=company_name,
            query=question,
            llm_reranking_sample_size=self.llm_reranking_sample_size,
            top_n=self.top_n_retrieval,
            return_parent_pages=self.return_parent_pages
        )
    if not retrieval_results:
        raise ValueError("No relevant context found")
    # Format the retrieval results into a RAG context
    rag_context = self._format_retrieval_results(retrieval_results)
    # Call the LLM to generate the answer
    answer_dict = self.openai_processor.get_answer_from_rag_context(
        question=question,
        rag_context=rag_context,
        schema=schema,
        model=self.answering_model
    )
    # Post-processing: page validation and reference extraction
    if self.new_challenge_pipeline:
        pages = answer_dict.get("relevant_pages", [])
        validated_pages = self._validate_page_references(pages, retrieval_results)
        answer_dict["relevant_pages"] = validated_pages
        answer_dict["references"] = self._extract_references(validated_pages, company_name)
    return answer_dict
Technical highlights:
- Retriever selection: the configuration decides between vector-only and hybrid retrieval
- Flexible context modes: full-document context or precise top-n retrieval
- Page-reference validation: guards against LLM hallucination so citations stay accurate
- Structured output: supports several answer types (name, number, boolean, names); see the call sketch below
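A sketch of calling the method with different schemas, reusing the processor instance from above. The keys of the returned dict mirror the fallback structure in get_answer_from_rag_context further down; the sample answer value is illustrative:

answer = processor.get_answer_for_company(
    company_name="Apple Inc.",
    question="Who is the company's CEO?",
    schema="name",                           # or "number", "boolean", "names"
)
print(answer["final_answer"])                # e.g. "Tim Cook" (illustrative)
print(answer.get("relevant_pages", []))      # validated pages backing the answer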
3. Formatting Retrieval Results
Retrieved chunks are converted into a context format the LLM can consume:
def _format_retrieval_results(self, retrieval_results) -> str:
    """Format retrieval results into a RAG context string."""
    if not retrieval_results:
        return ""
    context_parts = []
    for result in retrieval_results:
        page_number = result['page']
        text = result['text']
        context_parts.append(f'Text retrieved from page {page_number}: \n"""\n{text}\n"""')
    return "\n\n---\n\n".join(context_parts)
Formatting strategy:
- Explicit page labels make it easy for the LLM to understand and cite sources
- A consistent separator improves parsing accuracy
- A structured text layout helps the LLM use the context effectively (example below)
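For two retrieved chunks, the f-string above produces a context string of this shape (page numbers and chunk text are illustrative and abbreviated):

Text retrieved from page 12:
"""
Net income for the fourth quarter was ...
"""

---

Text retrieved from page 47:
"""
Research and development expenses were ...
"""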
4. Page-Reference Validation
A validation layer that keeps the LLM from fabricating citations:
def _validate_page_references(self, claimed_pages: list, retrieval_results: list, min_pages: int = 2, max_pages: int = 8) -> list:
    """
    Check that the pages cited in the LLM answer actually appear in the retrieval results.
    If fewer than min_pages remain, pad with the top retrieved pages.
    """
    if claimed_pages is None:
        claimed_pages = []
    # Pages that were actually retrieved
    retrieved_pages = [result['page'] for result in retrieval_results]
    # Keep only the claimed pages that really exist
    validated_pages = [page for page in claimed_pages if page in retrieved_pages]
    # Log any hallucinated references that were removed
    if len(validated_pages) < len(claimed_pages):
        removed_pages = set(claimed_pages) - set(validated_pages)
        print(f"Warning: Removed {len(removed_pages)} hallucinated page references: {removed_pages}")
    # If too few valid pages remain, pad from the retrieval results
    if len(validated_pages) < min_pages and retrieval_results:
        existing_pages = set(validated_pages)
        for result in retrieval_results:
            page = result['page']
            if page not in existing_pages:
                validated_pages.append(page)
                existing_pages.add(page)
            if len(validated_pages) >= min_pages:
                break
    # Cap the number of references
    if len(validated_pages) > max_pages:
        print(f"Trimming references from {len(validated_pages)} to {max_pages} pages")
        validated_pages = validated_pages[:max_pages]
    return validated_pages
Strengths of the validation mechanism:
- Hallucination detection: automatically finds and removes page numbers the LLM invented
- Smart padding: tops references up from high-ranked retrieved pages when too few remain
- Quantity control: caps the reference count so it does not dilute answer quality
- Transparent logging: records each adjustment, which simplifies debugging (worked example below)
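A worked example of those rules: suppose the LLM cites pages [7, 99] while retrieval returned pages [7, 3, 15]. Page 99 is dropped as a hallucination, and because only one valid page remains against the default min_pages=2, the top remaining retrieved page is appended:

results = [{'page': 7, 'text': '...'}, {'page': 3, 'text': '...'}, {'page': 15, 'text': '...'}]
processor._validate_page_references([7, 99], results)
# prints: Warning: Removed 1 hallucinated page references: {99}
# returns: [7, 3]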
5. Multi-Company Comparative QA
An advanced capability for complex cross-company comparisons:
def process_comparative_question(self, question: str, companies: List[str], schema: str) -> dict:
    """
    Handle a multi-company comparative question:
    1. Rephrase the comparative question into one sub-question per company
    2. Process each company in parallel
    3. Aggregate the results and generate the final comparative answer
    """
    # Step 1: rephrase the question
    rephrased_questions = self.openai_processor.get_rephrased_questions(
        original_question=question,
        companies=companies
    )
    individual_answers = {}
    aggregated_references = []

    # Step 2: process the per-company questions in parallel
    def process_company_question(company: str) -> tuple[str, dict]:
        """Process one company's sub-question and return (company, answer)."""
        sub_question = rephrased_questions.get(company)
        if not sub_question:
            raise ValueError(f"Could not generate sub-question for company: {company}")
        answer_dict = self.get_answer_for_company(
            company_name=company,
            question=sub_question,
            schema="number"
        )
        return company, answer_dict

    # Fan out over a thread pool
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_company = {
            executor.submit(process_company_question, company): company
            for company in companies
        }
        for future in concurrent.futures.as_completed(future_to_company):
            try:
                company, answer_dict = future.result()
                individual_answers[company] = answer_dict
                # Aggregate reference information
                company_references = answer_dict.get("references", [])
                aggregated_references.extend(company_references)
            except Exception as e:
                company = future_to_company[future]
                print(f"Error processing company {company}: {str(e)}")
                raise

    # Deduplicate references
    unique_refs = {}
    for ref in aggregated_references:
        key = (ref.get("pdf_sha1"), ref.get("page_index"))
        unique_refs[key] = ref
    aggregated_references = list(unique_refs.values())

    # Step 3: generate the comparative answer
    comparative_answer = self.openai_processor.get_answer_from_rag_context(
        question=question,
        rag_context=individual_answers,
        schema="comparative",
        model=self.answering_model
    )
    comparative_answer["references"] = aggregated_references
    return comparative_answer
What makes comparative QA work:
- Question decomposition: the comparative question is automatically rewritten into single-company questions
- Parallel execution: all companies are processed concurrently for throughput
- Result aggregation: per-company answers and references are merged into one answer
- Deduplication: duplicate references are collapsed by their (pdf_sha1, page_index) key; the data shapes are sketched below
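A sketch of the intermediate data, with all values illustrative rather than real model output. get_rephrased_questions returns one sub-question per company, and the dict of per-company answers is then passed as rag_context for the final schema="comparative" call:

# Illustrative shapes only
rephrased_questions = {
    "Apple Inc.": "What was the R&D spending of Apple Inc. in 2023?",
    "Microsoft Corporation": "What was the R&D spending of Microsoft Corporation in 2023?",
}
individual_answers = {
    "Apple Inc.": {"final_answer": 123, "references": [...]},
    "Microsoft Corporation": {"final_answer": 456, "references": [...]},
}
# individual_answers becomes the rag_context of the comparative call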
6. Batch Processing and Concurrency Control
An efficient path for processing questions at scale:
def process_questions_list(self, questions_list: List[dict], output_path: Optional[str] = None, submission_file: bool = False, team_email: str = "", submission_name: str = "", pipeline_details: str = "") -> dict:
    # Batch-process a list of questions, with optional parallelism and checkpoint saving
    total_questions = len(questions_list)
    questions_with_index = [{**q, "_question_index": i} for i, q in enumerate(questions_list)]
    self.answer_details = [None] * total_questions  # pre-allocate the answer-details list
    processed_questions = []
    parallel_threads = self.parallel_requests

    if parallel_threads <= 1:
        # Sequential, single-threaded processing
        for question_data in tqdm(questions_with_index, desc="Processing questions"):
            processed_question = self._process_single_question(question_data)
            processed_questions.append(processed_question)
            if output_path:
                self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
    else:
        # Parallel processing, batch by batch
        with tqdm(total=total_questions, desc="Processing questions") as pbar:
            for i in range(0, total_questions, parallel_threads):
                batch = questions_with_index[i : i + parallel_threads]
                with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_threads) as executor:
                    # executor.map keeps results in the same order as the input
                    batch_results = list(executor.map(self._process_single_question, batch))
                processed_questions.extend(batch_results)
                if output_path:
                    self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
                pbar.update(len(batch_results))

    statistics = self._calculate_statistics(processed_questions, print_stats=True)
    return {
        "questions": processed_questions,
        "answer_details": self.answer_details,
        "statistics": statistics
    }
Advantages of this concurrency scheme:
- Flexible concurrency control: one code path covers both single-threaded and multi-threaded modes
- Batched processing: works through the list in windows, balancing throughput against resource use
- Progress visualization: a live tqdm progress bar
- Checkpointing: progress is saved after every batch, so an interrupted run can be resumed; the batching pattern is sketched in isolation below
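The loop slices the question list into windows of parallel_requests items; within each window, executor.map returns results in input order, so the checkpoint files stay aligned with the original question order. The same pattern stripped down to its essentials:

import concurrent.futures

items = list(range(10))
batch_size = 4
results = []
for i in range(0, len(items), batch_size):
    batch = items[i : i + batch_size]
    with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
        # unlike as_completed(), map() yields results in the order of `batch`
        results.extend(executor.map(lambda x: x * x, batch))
assert results == [x * x for x in items]  # order preserved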
7. Error Handling and Recovery
A thorough exception-handling mechanism keeps the system stable:
def _handle_processing_error(self, question_text: str, schema: str, err: Exception, question_index: int) -> dict:
    """
    Handle an exception raised while processing a question.
    Records the error details and returns a dict describing the error.
    """
    import traceback
    error_message = str(err)
    tb = traceback.format_exc()
    error_ref = f"#/answer_details/{question_index}"
    error_detail = {
        "error_traceback": tb,
        "self": error_ref
    }
    # Record the error under the lock (thread-safe)
    with self._lock:
        self.answer_details[question_index] = error_detail
    # Verbose error logging
    print(f"Error encountered processing question: {question_text}")
    print(f"Error type: {type(err).__name__}")
    print(f"Error message: {error_message}")
    print(f"Full traceback:\n{tb}\n")
    # Return a standardized error response
    if self.new_challenge_pipeline:
        return {
            "question_text": question_text,
            "kind": schema,
            "value": None,
            "references": [],
            "error": f"{type(err).__name__}: {error_message}",
            "answer_details": {"$ref": error_ref}
        }
    else:
        return {
            "question": question_text,
            "schema": schema,
            "answer": None,
            "error": f"{type(err).__name__}: {error_message}",
            "answer_details": {"$ref": error_ref},
        }
Error-handling highlights:
- Detailed records: the full stack trace is kept alongside the error context
- Thread safety: errors are recorded under a lock, safe for concurrent workers
- Standardized responses: a uniform error shape in both pipeline modes
- Debug-friendly: rich diagnostic output; resolving the "$ref" pointers is sketched below
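The "$ref" strings follow a JSON-Pointer-like convention, "#/answer_details/<index>", which the submission post-processing in the full code resolves back into the answer_details list. A minimal sketch of that lookup:

def resolve_answer_detail(ref: str, answer_details: list):
    """Resolve a '#/answer_details/<i>' reference to the stored detail dict."""
    if ref.startswith("#/answer_details/"):
        index = int(ref.split("/")[-1])
        if 0 <= index < len(answer_details):
            return answer_details[index]
    return None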
8. Statistics and Monitoring
Live statistics over the processing run:
def _calculate_statistics(self, processed_questions: List[dict], print_stats: bool = False) -> dict:
    """Summarize results: totals, errors, N/A answers, and successes."""
    total_questions = len(processed_questions)
    error_count = sum(1 for q in processed_questions if "error" in q)
    na_count = sum(1 for q in processed_questions if (q.get("value") if "value" in q else q.get("answer")) == "N/A")
    success_count = total_questions - error_count - na_count
    if print_stats:
        print("\nFinal Processing Statistics:")
        print(f"Total questions: {total_questions}")
        print(f"Errors: {error_count} ({(error_count/total_questions)*100:.1f}%)")
        print(f"N/A answers: {na_count} ({(na_count/total_questions)*100:.1f}%)")
        print(f"Successfully answered: {success_count} ({(success_count/total_questions)*100:.1f}%)\n")
    return {
        "total_questions": total_questions,
        "error_count": error_count,
        "na_count": na_count,
        "success_count": success_count
    }
APIProcessor Architecture
A unified interface over multiple providers:
from typing import Literal

class APIProcessor:
    def __init__(self, provider: Literal["openai", "ibm", "gemini", "dashscope"] = "dashscope"):
        self.provider = provider.lower()
        if self.provider == "openai":
            self.processor = BaseOpenaiProcessor()
        elif self.provider == "ibm":
            self.processor = BaseIBMAPIProcessor()
        elif self.provider == "gemini":
            self.processor = BaseGeminiProcessor()
        elif self.provider == "dashscope":
            self.processor = BaseDashscopeProcessor()

    def get_answer_from_rag_context(self, question, rag_context, schema, model):
        system_prompt, response_format, user_prompt = self._build_rag_context_prompts(schema)
        answer_dict = self.processor.send_message(
            model=model,
            system_content=system_prompt,
            human_content=user_prompt.format(context=rag_context, question=question),
            is_structured=True,
            response_format=response_format
        )
        # Fallback: guarantee a complete answer structure
        if 'step_by_step_analysis' not in answer_dict:
            answer_dict = {
                "step_by_step_analysis": "",
                "reasoning_summary": "",
                "relevant_pages": [],
                "final_answer": answer_dict.get("final_answer", "N/A")
            }
        return answer_dict
APIProcessor advantages:
- Unified interface: hides the differences between provider APIs
- Smart adaptation: adjusts parameters to each provider's characteristics
- Fault tolerance: fallback and retry logic for incomplete responses
- Extensibility: new providers are easy to add (sketched below)
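Adding a provider amounts to one more branch plus a base-processor class exposing the same send_message interface the branches above rely on. A hedged sketch; BaseAnthropicProcessor is hypothetical and not part of the project:

class BaseAnthropicProcessor:  # hypothetical, mirrors the other Base*Processor classes
    def send_message(self, model, system_content, human_content,
                     is_structured=False, response_format=None) -> dict:
        """Call the provider's API and return an answer dict like the others."""
        raise NotImplementedError

# ...and in APIProcessor.__init__, one extra branch:
#     elif self.provider == "anthropic":
#         self.processor = BaseAnthropicProcessor()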
Practical Application Scenarios
1. Corporate Financial Analysis
# Single-company financial query
processor = QuestionsProcessor(
    vector_db_dir="./financial_dbs",
    documents_dir="./financial_docs",
    llm_reranking=True,
    api_provider="openai",
    answering_model="gpt-4o-2024-08-06"
)
answer = processor.get_answer_for_company(
    company_name="Apple Inc.",
    question="What was the net income in Q4 2023?",
    schema="number"
)
2. Multi-Company Comparison
# Multi-company comparative query
comparative_answer = processor.process_comparative_question(
    question="Which company had higher R&D spending in 2023, 'Apple Inc.' or 'Microsoft Corporation'?",
    companies=["Apple Inc.", "Microsoft Corporation"],
    schema="comparative"
)
3. Batch Question Processing
# Large-scale batch processing. Company names are quoted so the default
# pipeline's regex-based extraction in process_question can find them.
questions_list = [
    {"question": 'Who is the CEO of "Apple Inc."?', "schema": "name"},
    {"question": 'What was the total revenue of "Apple Inc." in 2023?', "schema": "number"},
    {"question": 'Did "Apple Inc." repurchase shares?', "schema": "boolean"}
]
results = processor.process_questions_list(
    questions_list=questions_list,
    output_path="./results.json",
    submission_file=True
)
# Note: the degree of parallelism is set by the constructor's
# parallel_requests argument; process_questions_list takes no such parameter.
Performance Optimization Strategies
1. Retrieval Optimization
- Retriever selection: pick the retrieval strategy best suited to the query type
- Caching: cache retrieval results for frequent queries (see the sketch after this list)
- Batched retrieval: merge similar queries to reduce retrieval calls
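The project itself performs retrieval uncached; a caching layer is an optimization you would add yourself. A minimal sketch keyed on (company, query), assuming retriever is a VectorRetriever or HybridRetriever instance from the project:

from functools import lru_cache

@lru_cache(maxsize=1024)
def cached_retrieve(company_name: str, query: str):
    # Memoizes repeated (company, query) pairs; both arguments are hashable strings
    return retriever.retrieve_by_company_name(company_name=company_name, query=query)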
2. Concurrency Optimization
- Dynamic thread pools: adjust concurrency to the current system load
- Batched processing: balance the degree of concurrency against resource consumption
- Load balancing: distribute requests across multiple API providers
3. Memory Management
- Streaming: process large datasets as streams rather than all at once
- Prompt release: free resources as soon as processing finishes
- Memory monitoring: track memory usage in real time
System Monitoring and Debugging
1. Live Monitoring
# Monitor processing statistics
statistics = processor._calculate_statistics(processed_questions, print_stats=True)
print(f"Success rate: {(statistics['success_count']/statistics['total_questions'])*100:.1f}%")
print(f"Error rate: {(statistics['error_count']/statistics['total_questions'])*100:.1f}%")
2. Detailed Logging
# Enable verbose logging
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Custom log records
def log_processing_details(question, answer, processing_time):
    logger.info(f"Question: {question}")
    logger.info(f"Answer: {answer.get('final_answer', 'N/A')}")
    logger.info(f"Processing time: {processing_time:.2f}s")
3. Error Analysis
# Aggregate errors by exception type
def analyze_errors(processed_questions):
    errors = [q for q in processed_questions if "error" in q]
    error_types = {}
    for error in errors:
        error_type = error["error"].split(":")[0]
        error_types[error_type] = error_types.get(error_type, 0) + 1
    print("Error Analysis:")
    for error_type, count in error_types.items():
        print(f"  {error_type}: {count}")
Best-Practice Recommendations
1. Configuration Tuning
# A recommended production configuration
production_config = {
    "llm_reranking": True,                  # reranking improves answer quality
    "parent_document_retrieval": True,      # parent pages give fuller context
    "top_n_retrieval": 10,                  # a moderate number of retrieved chunks
    "parallel_requests": 5,                 # stay under API rate limits
    "api_provider": "openai",               # a stable API provider
    "answering_model": "gpt-4o-2024-08-06"  # a high-quality model
}
2. Error Handling
# A retry wrapper with exponential backoff
import time

def robust_question_processing(processor, question, max_retries=3):
    for attempt in range(max_retries):
        try:
            return processor.process_question(question["question"], question["schema"])
        except Exception as e:
            if attempt == max_retries - 1:
                return {"error": f"Failed after {max_retries} attempts: {str(e)}"}
            time.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, 4s, ...
3. Performance Monitoring
# A timing decorator
import time
from functools import wraps

def monitor_performance(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} took {end_time - start_time:.2f} seconds")
        return result
    return wrapper
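Applied, for example, to a single-question call (the printed timing is illustrative):

@monitor_performance
def answer_one(processor, question, schema):
    return processor.process_question(question, schema)

# prints e.g. "answer_one took 3.27 seconds" after each call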
Summary
This question-processing system demonstrates the complete engineering practice behind an enterprise-grade RAG system:
- Modular architecture: a clean, layered design that is easy to maintain and extend
- Multiple answer schemas: several question kinds and answer formats are supported
- Concurrent processing: an efficient multi-threaded parallel pipeline
- Error recovery: thorough exception handling with resumable batch runs
- Monitoring and debugging: rich statistics and debugging tools
- API abstraction: a unified interface across multiple providers
- Reference validation: a page-number check that counters LLM hallucination
For anyone building an enterprise question-answering system, this implementation offers a complete reference architecture and a set of best practices. With sound design and tuning, it delivers efficient, stable, large-scale question processing without sacrificing answer quality.
References
This article is based on a source-code analysis of the questions-processing module from the RAG-Challenge-2 winning project, covering the full implementation and the optimization strategies of an industrial-grade question-answering system. We hope it is useful to developers building similar systems.
Full Code
import json
from typing import Union, Dict, List, Optional
import re
from pathlib import Path
from src.retrieval import VectorRetriever, HybridRetriever
from src.api_requests import APIProcessor
from tqdm import tqdm
import pandas as pd
import threading
import concurrent.futures


class QuestionsProcessor:
    def __init__(
        self,
        vector_db_dir: Union[str, Path] = './vector_dbs',
        documents_dir: Union[str, Path] = './documents',
        questions_file_path: Optional[Union[str, Path]] = None,
        new_challenge_pipeline: bool = False,
        subset_path: Optional[Union[str, Path]] = None,
        parent_document_retrieval: bool = False,  # enable parent-document retrieval
        llm_reranking: bool = False,  # enable LLM reranking
        llm_reranking_sample_size: int = 20,
        top_n_retrieval: int = 10,
        parallel_requests: int = 10,
        api_provider: str = "dashscope",  # or "openai"
        answering_model: str = "qwen-turbo-latest",  # e.g. "gpt-4o-2024-08-06"
        full_context: bool = False
    ):
        # Configure retrieval, model, and concurrency settings
        self.questions = self._load_questions(questions_file_path)
        self.documents_dir = Path(documents_dir)
        self.vector_db_dir = Path(vector_db_dir)
        self.subset_path = Path(subset_path) if subset_path else None
        self.new_challenge_pipeline = new_challenge_pipeline
        self.return_parent_pages = parent_document_retrieval
        self.llm_reranking = llm_reranking
        self.llm_reranking_sample_size = llm_reranking_sample_size
        self.top_n_retrieval = top_n_retrieval
        self.answering_model = answering_model
        self.parallel_requests = parallel_requests
        self.api_provider = api_provider
        self.openai_processor = APIProcessor(provider=api_provider)
        self.full_context = full_context

        self.answer_details = []
        self.detail_counter = 0
        self._lock = threading.Lock()

    def _load_questions(self, questions_file_path: Optional[Union[str, Path]]) -> List[Dict[str, str]]:
        # Load the questions file and return the list of questions
        if questions_file_path is None:
            return []
        with open(questions_file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def _format_retrieval_results(self, retrieval_results) -> str:
        """Format retrieval results into a RAG context string."""
        if not retrieval_results:
            return ""
        context_parts = []
        for result in retrieval_results:
            page_number = result['page']
            text = result['text']
            context_parts.append(f'Text retrieved from page {page_number}: \n"""\n{text}\n"""')
        return "\n\n---\n\n".join(context_parts)

    def _extract_references(self, pages_list: list, company_name: str) -> list:
        # Build reference entries from the company name and page list
        if self.subset_path is None:
            raise ValueError("subset_path is required for new challenge pipeline when processing references.")
        self.companies_df = pd.read_csv(self.subset_path)
        # Find the company's SHA1 from the subset CSV
        matching_rows = self.companies_df[self.companies_df['company_name'] == company_name]
        if matching_rows.empty:
            company_sha1 = ""
        else:
            company_sha1 = matching_rows.iloc[0]['sha1']
        refs = []
        for page in pages_list:
            refs.append({"pdf_sha1": company_sha1, "page_index": page})
        return refs

    def _validate_page_references(self, claimed_pages: list, retrieval_results: list, min_pages: int = 2, max_pages: int = 8) -> list:
        """
        Check that the pages cited in the LLM answer actually appear in the retrieval results.
        If fewer than min_pages remain, pad with the top retrieved pages.
        """
        if claimed_pages is None:
            claimed_pages = []
        retrieved_pages = [result['page'] for result in retrieval_results]
        validated_pages = [page for page in claimed_pages if page in retrieved_pages]
        if len(validated_pages) < len(claimed_pages):
            removed_pages = set(claimed_pages) - set(validated_pages)
            print(f"Warning: Removed {len(removed_pages)} hallucinated page references: {removed_pages}")
        if len(validated_pages) < min_pages and retrieval_results:
            existing_pages = set(validated_pages)
            for result in retrieval_results:
                page = result['page']
                if page not in existing_pages:
                    validated_pages.append(page)
                    existing_pages.add(page)
                if len(validated_pages) >= min_pages:
                    break
        if len(validated_pages) > max_pages:
            print(f"Trimming references from {len(validated_pages)} to {max_pages} pages")
            validated_pages = validated_pages[:max_pages]
        return validated_pages

    def get_answer_for_company(self, company_name: str, question: str, schema: str) -> dict:
        # Retrieve context for a single company and ask the LLM for an answer
        if self.llm_reranking:
            retriever = HybridRetriever(
                vector_db_dir=self.vector_db_dir,
                documents_dir=self.documents_dir
            )
        else:
            retriever = VectorRetriever(
                vector_db_dir=self.vector_db_dir,
                documents_dir=self.documents_dir
            )
        if self.full_context:
            retrieval_results = retriever.retrieve_all(company_name)
        else:
            retrieval_results = retriever.retrieve_by_company_name(
                company_name=company_name,
                query=question,
                llm_reranking_sample_size=self.llm_reranking_sample_size,
                top_n=self.top_n_retrieval,
                return_parent_pages=self.return_parent_pages
            )
        if not retrieval_results:
            raise ValueError("No relevant context found")
        rag_context = self._format_retrieval_results(retrieval_results)
        answer_dict = self.openai_processor.get_answer_from_rag_context(
            question=question,
            rag_context=rag_context,
            schema=schema,
            model=self.answering_model
        )
        self.response_data = self.openai_processor.response_data
        if self.new_challenge_pipeline:
            pages = answer_dict.get("relevant_pages", [])
            validated_pages = self._validate_page_references(pages, retrieval_results)
            answer_dict["relevant_pages"] = validated_pages
            answer_dict["references"] = self._extract_references(validated_pages, company_name)
        return answer_dict

    def _extract_companies_from_subset(self, question_text: str) -> list[str]:
        """Extract company names from the question text by matching against the subset file."""
        if not hasattr(self, 'companies_df'):
            if self.subset_path is None:
                raise ValueError("subset_path must be provided to use subset extraction")
            self.companies_df = pd.read_csv(self.subset_path)
        found_companies = []
        company_names = sorted(self.companies_df['company_name'].unique(), key=len, reverse=True)
        for company in company_names:
            escaped_company = re.escape(company)
            pattern = rf'{escaped_company}(?:\W|$)'
            if re.search(pattern, question_text, re.IGNORECASE):
                found_companies.append(company)
                question_text = re.sub(pattern, '', question_text, flags=re.IGNORECASE)
        return found_companies

    def process_question(self, question: str, schema: str):
        # Process a single question; comparative questions are routed separately
        if self.new_challenge_pipeline:
            extracted_companies = self._extract_companies_from_subset(question)
        else:
            extracted_companies = re.findall(r'"([^"]*)"', question)
        if len(extracted_companies) == 0:
            raise ValueError("No company name found in the question.")
        if len(extracted_companies) == 1:
            company_name = extracted_companies[0]
            answer_dict = self.get_answer_for_company(company_name=company_name, question=question, schema=schema)
            return answer_dict
        else:
            return self.process_comparative_question(question, extracted_companies, schema)

    def _create_answer_detail_ref(self, answer_dict: dict, question_index: int) -> str:
        """Store the answer detail and return a reference ID pointing to it."""
        ref_id = f"#/answer_details/{question_index}"
        with self._lock:
            self.answer_details[question_index] = {
                "step_by_step_analysis": answer_dict['step_by_step_analysis'],
                "reasoning_summary": answer_dict['reasoning_summary'],
                "relevant_pages": answer_dict['relevant_pages'],
                "response_data": self.response_data,
                "self": ref_id
            }
        return ref_id

    def _calculate_statistics(self, processed_questions: List[dict], print_stats: bool = False) -> dict:
        """Summarize results: totals, errors, N/A answers, and successes."""
        total_questions = len(processed_questions)
        error_count = sum(1 for q in processed_questions if "error" in q)
        na_count = sum(1 for q in processed_questions if (q.get("value") if "value" in q else q.get("answer")) == "N/A")
        success_count = total_questions - error_count - na_count
        if print_stats:
            print("\nFinal Processing Statistics:")
            print(f"Total questions: {total_questions}")
            print(f"Errors: {error_count} ({(error_count/total_questions)*100:.1f}%)")
            print(f"N/A answers: {na_count} ({(na_count/total_questions)*100:.1f}%)")
            print(f"Successfully answered: {success_count} ({(success_count/total_questions)*100:.1f}%)\n")
        return {
            "total_questions": total_questions,
            "error_count": error_count,
            "na_count": na_count,
            "success_count": success_count
        }

    def process_questions_list(self, questions_list: List[dict], output_path: Optional[str] = None, submission_file: bool = False, team_email: str = "", submission_name: str = "", pipeline_details: str = "") -> dict:
        # Batch-process a question list with optional parallelism and checkpoint saving
        total_questions = len(questions_list)
        # Attach an index to each question so answer details can be located later
        questions_with_index = [{**q, "_question_index": i} for i, q in enumerate(questions_list)]
        self.answer_details = [None] * total_questions  # pre-allocate the details list
        processed_questions = []
        parallel_threads = self.parallel_requests

        if parallel_threads <= 1:
            # Sequential, single-threaded processing
            for question_data in tqdm(questions_with_index, desc="Processing questions"):
                processed_question = self._process_single_question(question_data)
                processed_questions.append(processed_question)
                if output_path:
                    self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
        else:
            # Parallel processing, batch by batch
            with tqdm(total=total_questions, desc="Processing questions") as pbar:
                for i in range(0, total_questions, parallel_threads):
                    batch = questions_with_index[i : i + parallel_threads]
                    with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_threads) as executor:
                        # executor.map keeps results in the same order as the input
                        batch_results = list(executor.map(self._process_single_question, batch))
                    processed_questions.extend(batch_results)
                    if output_path:
                        self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
                    pbar.update(len(batch_results))

        statistics = self._calculate_statistics(processed_questions, print_stats=True)
        return {
            "questions": processed_questions,
            "answer_details": self.answer_details,
            "statistics": statistics
        }

    def _process_single_question(self, question_data: dict) -> dict:
        question_index = question_data.get("_question_index", 0)
        if self.new_challenge_pipeline:
            question_text = question_data.get("text")
            schema = question_data.get("kind")
        else:
            question_text = question_data.get("question")
            schema = question_data.get("schema")
        try:
            answer_dict = self.process_question(question_text, schema)
            if "error" in answer_dict:
                detail_ref = self._create_answer_detail_ref({
                    "step_by_step_analysis": None,
                    "reasoning_summary": None,
                    "relevant_pages": None
                }, question_index)
                if self.new_challenge_pipeline:
                    return {
                        "question_text": question_text,
                        "kind": schema,
                        "value": None,
                        "references": [],
                        "error": answer_dict["error"],
                        "answer_details": {"$ref": detail_ref}
                    }
                else:
                    return {
                        "question": question_text,
                        "schema": schema,
                        "answer": None,
                        "error": answer_dict["error"],
                        "answer_details": {"$ref": detail_ref},
                    }

            detail_ref = self._create_answer_detail_ref(answer_dict, question_index)
            if self.new_challenge_pipeline:
                return {
                    "question_text": question_text,
                    "kind": schema,
                    "value": answer_dict.get("final_answer"),
                    "references": answer_dict.get("references", []),
                    "answer_details": {"$ref": detail_ref}
                }
            else:
                return {
                    "question": question_text,
                    "schema": schema,
                    "answer": answer_dict.get("final_answer"),
                    "answer_details": {"$ref": detail_ref},
                }
        except Exception as err:
            return self._handle_processing_error(question_text, schema, err, question_index)

    def _handle_processing_error(self, question_text: str, schema: str, err: Exception, question_index: int) -> dict:
        """
        Handle an exception raised while processing a question.
        Records the error details and returns a dict describing the error.
        """
        import traceback
        error_message = str(err)
        tb = traceback.format_exc()
        error_ref = f"#/answer_details/{question_index}"
        error_detail = {
            "error_traceback": tb,
            "self": error_ref
        }
        with self._lock:
            self.answer_details[question_index] = error_detail
        print(f"Error encountered processing question: {question_text}")
        print(f"Error type: {type(err).__name__}")
        print(f"Error message: {error_message}")
        print(f"Full traceback:\n{tb}\n")
        if self.new_challenge_pipeline:
            return {
                "question_text": question_text,
                "kind": schema,
                "value": None,
                "references": [],
                "error": f"{type(err).__name__}: {error_message}",
                "answer_details": {"$ref": error_ref}
            }
        else:
            return {
                "question": question_text,
                "schema": schema,
                "answer": None,
                "error": f"{type(err).__name__}: {error_message}",
                "answer_details": {"$ref": error_ref},
            }

    def _post_process_submission_answers(self, processed_questions: List[dict]) -> List[dict]:
        """
        Post-process answers for submission:
        1. Convert page indices from 1-based to 0-based
        2. Clear references for N/A answers
        3. Format answers according to the competition submission schema
        4. Include step_by_step_analysis
        """
        submission_answers = []
        for q in processed_questions:
            question_text = q.get("question_text") or q.get("question")
            kind = q.get("kind") or q.get("schema")
            value = "N/A" if "error" in q else (q.get("value") if "value" in q else q.get("answer"))
            references = q.get("references", [])
            answer_details_ref = q.get("answer_details", {}).get("$ref", "")

            step_by_step_analysis = None
            if answer_details_ref and answer_details_ref.startswith("#/answer_details/"):
                try:
                    index = int(answer_details_ref.split("/")[-1])
                    if 0 <= index < len(self.answer_details) and self.answer_details[index]:
                        step_by_step_analysis = self.answer_details[index].get("step_by_step_analysis")
                except (ValueError, IndexError):
                    pass

            # Clear references if value is N/A
            if value == "N/A":
                references = []
            else:
                # Convert page indices from one-based to zero-based (the competition requires
                # 0-based page indices, but 1-based is easier for debugging)
                references = [
                    {
                        "pdf_sha1": ref["pdf_sha1"],
                        "page_index": ref["page_index"] - 1
                    }
                    for ref in references
                ]

            submission_answer = {
                "question_text": question_text,
                "kind": kind,
                "value": value,
                "references": references,
            }
            if step_by_step_analysis:
                submission_answer["reasoning_process"] = step_by_step_analysis
            submission_answers.append(submission_answer)
        return submission_answers

    def _save_progress(self, processed_questions: List[dict], output_path: Optional[str], submission_file: bool = False, team_email: str = "", submission_name: str = "", pipeline_details: str = ""):
        if output_path:
            statistics = self._calculate_statistics(processed_questions)
            # Prepare debug content
            result = {
                "questions": processed_questions,
                "answer_details": self.answer_details,
                "statistics": statistics
            }
            output_file = Path(output_path)
            debug_file = output_file.with_name(output_file.stem + "_debug" + output_file.suffix)
            with open(debug_file, 'w', encoding='utf-8') as file:
                json.dump(result, file, ensure_ascii=False, indent=2)
            if submission_file:
                # Post-process answers for submission
                submission_answers = self._post_process_submission_answers(processed_questions)
                submission = {
                    "answers": submission_answers,
                    "team_email": team_email,
                    "submission_name": submission_name,
                    "details": pipeline_details
                }
                with open(output_file, 'w', encoding='utf-8') as file:
                    json.dump(submission, file, ensure_ascii=False, indent=2)

    def process_all_questions(self, output_path: str = 'questions_with_answers.json', team_email: str = "79250515615@yandex.com", submission_name: str = "Ilia_Ris SO CoT + Parent Document Retrieval", submission_file: bool = False, pipeline_details: str = ""):
        result = self.process_questions_list(
            self.questions,
            output_path,
            submission_file=submission_file,
            team_email=team_email,
            submission_name=submission_name,
            pipeline_details=pipeline_details
        )
        return result

    def process_comparative_question(self, question: str, companies: List[str], schema: str) -> dict:
        """
        Handle a multi-company comparative question:
        1. Rephrase the comparative question into one sub-question per company
        2. Process each company in parallel
        3. Aggregate the results and generate the final comparative answer
        """
        # Step 1: Rephrase the comparative question
        rephrased_questions = self.openai_processor.get_rephrased_questions(
            original_question=question,
            companies=companies
        )
        individual_answers = {}
        aggregated_references = []

        # Step 2: Process each individual question in parallel
        def process_company_question(company: str) -> tuple[str, dict]:
            """Helper function to process one company's question and return (company, answer)"""
            sub_question = rephrased_questions.get(company)
            if not sub_question:
                raise ValueError(f"Could not generate sub-question for company: {company}")
            answer_dict = self.get_answer_for_company(
                company_name=company,
                question=sub_question,
                schema="number"
            )
            return company, answer_dict

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_company = {
                executor.submit(process_company_question, company): company
                for company in companies
            }
            for future in concurrent.futures.as_completed(future_to_company):
                try:
                    company, answer_dict = future.result()
                    individual_answers[company] = answer_dict
                    company_references = answer_dict.get("references", [])
                    aggregated_references.extend(company_references)
                except Exception as e:
                    company = future_to_company[future]
                    print(f"Error processing company {company}: {str(e)}")
                    raise

        # Remove duplicate references
        unique_refs = {}
        for ref in aggregated_references:
            key = (ref.get("pdf_sha1"), ref.get("page_index"))
            unique_refs[key] = ref
        aggregated_references = list(unique_refs.values())

        # Step 3: Get the comparative answer using all individual answers
        comparative_answer = self.openai_processor.get_answer_from_rag_context(
            question=question,
            rag_context=individual_answers,
            schema="comparative",
            model=self.answering_model
        )
        self.response_data = self.openai_processor.response_data
        comparative_answer["references"] = aggregated_references
        return comparative_answer