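RAG Question-Processing System Architecture: Engineering an Enterprise-Grade Q&A Pipeline with QuestionsProcessor.py

Full code included

Introduction

In enterprise knowledge-base and question-answering systems, question processing is the bridge between a user's query and knowledge retrieval. This article walks through the question-processing module of a system that won the RAG Challenge. The module supports single-company queries, multi-company comparisons, parallel processing, and per-question error handling, and it illustrates the engineering work behind a production-style RAG pipeline.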

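System Architecture Overview

The question-processing system uses a modular, layered architecture:

  1. QuestionsProcessor: the core question processor, which orchestrates the whole Q&A flow

  2. APIProcessor: a multi-provider API processor supporting OpenAI, IBM, Gemini, and DashScope

  3. Retrieval integration: plugs into either vector retrieval or hybrid retrieval

  4. Parallel processing: multithreaded, batched execution

  5. Error recovery: per-question exception handling, with progress saved as processing advances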

Core Components in Detail

1. The Core Class (QuestionsProcessor)

QuestionsProcessor is the system's central controller. It coordinates retrieval, reasoning, and answer generation.

class QuestionsProcessor:
    def __init__(
        self,
        vector_db_dir: Union[str, Path] = './vector_dbs',
        documents_dir: Union[str, Path] = './documents',
        questions_file_path: Optional[Union[str, Path]] = None,
        new_challenge_pipeline: bool = False,
        subset_path: Optional[Union[str, Path]] = None,
        parent_document_retrieval: bool = False,  # enable parent-document retrieval
        llm_reranking: bool = False,              # enable LLM reranking
        llm_reranking_sample_size: int = 20,
        top_n_retrieval: int = 10,
        parallel_requests: int = 10,
        api_provider: str = "dashscope", # openai
        answering_model: str = "qwen-turbo-latest", # gpt-4o-2024-08-06
        full_context: bool = False
    ):
        # Core configuration
        self.questions = self._load_questions(questions_file_path)
        self.documents_dir = Path(documents_dir)
        self.vector_db_dir = Path(vector_db_dir)
        self.subset_path = Path(subset_path) if subset_path else None
        self.new_challenge_pipeline = new_challenge_pipeline
        
        # Retrieval strategy
        self.return_parent_pages = parent_document_retrieval
        self.llm_reranking = llm_reranking
        self.llm_reranking_sample_size = llm_reranking_sample_size
        self.top_n_retrieval = top_n_retrieval
        self.full_context = full_context
        
        # API and concurrency
        self.api_provider = api_provider
        self.answering_model = answering_model
        self.parallel_requests = parallel_requests
        self.openai_processor = APIProcessor(provider=api_provider)
        
        # Thread safety and state
        self.answer_details = []
        self.detail_counter = 0
        self._lock = threading.Lock()

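Design highlights

  • Rich configuration parameters covering several retrieval strategies

  • Support for multiple API providers, so the backend can be swapped through configuration

  • A lock guarding shared state (answer_details), so questions can be processed concurrently

  • Switchable pipeline modes via new_challenge_pipeline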
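The lock-protected state can be illustrated in isolation. The sketch below is not the project's code: DetailStore is a hypothetical helper that mirrors the answer_details list and _lock from the constructor, and it shows several worker threads writing into pre-allocated slots without losing updates.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class DetailStore:
    """Hypothetical helper: one pre-allocated slot per question, written under a lock."""

    def __init__(self, size: int):
        self.details = [None] * size
        self.written = 0
        self._lock = threading.Lock()

    def record(self, index: int, detail: dict) -> None:
        # The lock makes the slot write and the counter update one atomic step.
        with self._lock:
            self.details[index] = detail
            self.written += 1


store = DetailStore(size=100)
with ThreadPoolExecutor(max_workers=8) as pool:
    # Each task writes to its own slot; list() forces all tasks to complete.
    list(pool.map(lambda i: store.record(i, {"self": f"#/answer_details/{i}"}), range(100)))

print(store.written)
```

Pre-allocating one slot per question index means threads never append to the list, so the final order does not depend on which thread finishes first.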

2. Core Flow for Single-Company Questions

Answering a question about a single company is the system's basic building block:

def get_answer_for_company(self, company_name: str, question: str, schema: str) -> dict:
    # Choose a retriever based on configuration
    if self.llm_reranking:
        retriever = HybridRetriever(
            vector_db_dir=self.vector_db_dir,
            documents_dir=self.documents_dir
        )
    else:
        retriever = VectorRetriever(
            vector_db_dir=self.vector_db_dir,
            documents_dir=self.documents_dir
        )

    # Run retrieval
    if self.full_context:
        retrieval_results = retriever.retrieve_all(company_name)
    else:           
        retrieval_results = retriever.retrieve_by_company_name(
            company_name=company_name,
            query=question,
            llm_reranking_sample_size=self.llm_reranking_sample_size,
            top_n=self.top_n_retrieval,
            return_parent_pages=self.return_parent_pages
        )
    
    if not retrieval_results:
        raise ValueError("No relevant context found")
    
    # Format the retrieval results as RAG context
    rag_context = self._format_retrieval_results(retrieval_results)
    
    # Call the LLM to generate the answer
    answer_dict = self.openai_processor.get_answer_from_rag_context(
        question=question,
        rag_context=rag_context,
        schema=schema,
        model=self.answering_model
    )
    
    # Post-processing: validate page numbers and build references
    if self.new_challenge_pipeline:
        pages = answer_dict.get("relevant_pages", [])
        validated_pages = self._validate_page_references(pages, retrieval_results)
        answer_dict["relevant_pages"] = validated_pages
        answer_dict["references"] = self._extract_references(validated_pages, company_name)
    
    return answer_dict

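Technical features

  • Configurable retriever: HybridRetriever when llm_reranking is enabled, VectorRetriever otherwise

  • Two context modes: full-document context (full_context) or targeted top-N retrieval

  • Page-number validation: cited pages are checked against the retrieved pages to filter out hallucinated references

  • Structured output: several answer schemas are supported (name, number, boolean, names)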

3. Formatting Retrieval Results

Retrieval results are converted into a context string the LLM can read:

def _format_retrieval_results(self, retrieval_results) -> str:
    """Format retrieval results as a RAG context string."""
    if not retrieval_results:
        return ""
    
    context_parts = []
    for result in retrieval_results:
        page_number = result['page']
        text = result['text']
        context_parts.append(f'Text retrieved from page {page_number}: \n"""\n{text}\n"""')
        
    return "\n\n---\n\n".join(context_parts)

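Formatting strategy

  • Each chunk is labeled with its page number, so the LLM can cite pages

  • A uniform separator (---) between chunks keeps the boundaries unambiguous

  • Text is wrapped in triple quotes, which sets retrieved content apart from the surrounding prompt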
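To make the output format concrete, here is a small standalone sketch that applies the same formatting to two retrieval results. The function is a free-function restatement of the method above, and the sample pages and text are invented for illustration.

```python
def format_retrieval_results(retrieval_results: list[dict]) -> str:
    """Free-function restatement of _format_retrieval_results for experimentation."""
    if not retrieval_results:
        return ""
    parts = [
        f'Text retrieved from page {r["page"]}: \n"""\n{r["text"]}\n"""'
        for r in retrieval_results
    ]
    return "\n\n---\n\n".join(parts)


# Made-up retrieval results, only to show the shape of the context string.
sample_results = [
    {"page": 12, "text": "Net income was 1.2 billion."},
    {"page": 47, "text": "R&D expenses rose 8 percent."},
]
context = format_retrieval_results(sample_results)
print(context)
```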

4. Page-Reference Validation

A validation step keeps the LLM from citing pages that were never retrieved:

def _validate_page_references(self, claimed_pages: list, retrieval_results: list, min_pages: int = 2, max_pages: int = 8) -> list:
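    """
    Check that the pages cited in the LLM answer actually appear in the retrieval results.
    If fewer than min_pages remain, top up with the highest-ranked retrieved pages.
    """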
    if claimed_pages is None:
        claimed_pages = []
    
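    # Pages that were actually retrieved
    retrieved_pages = [result['page'] for result in retrieval_results]
    
    # Keep only the claimed pages that exist in the retrieval results
    validated_pages = [page for page in claimed_pages if page in retrieved_pages]
    
    # Log any hallucinated references that were removed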
    if len(validated_pages) < len(claimed_pages):
        removed_pages = set(claimed_pages) - set(validated_pages)
        print(f"Warning: Removed {len(removed_pages)} hallucinated page references: {removed_pages}")
    
    # Top up from the retrieval results if below the minimum
    if len(validated_pages) < min_pages and retrieval_results:
        existing_pages = set(validated_pages)
        
        for result in retrieval_results:
            page = result['page']
            if page not in existing_pages:
                validated_pages.append(page)
                existing_pages.add(page)
                
                if len(validated_pages) >= min_pages:
                    break
    
    # Cap the number of cited pages
    if len(validated_pages) > max_pages:
        print(f"Trimming references from {len(validated_pages)} to {max_pages} pages")
        validated_pages = validated_pages[:max_pages]
    
    return validated_pages

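Benefits of validation

  • Hallucination detection: page numbers the LLM invented are identified and removed

  • Automatic top-up: when too few citations remain, the top-ranked retrieved pages are added

  • Count limits: citations are capped at max_pages

  • Transparent logging: removals and trimming are printed, which helps debugging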
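A short worked example makes the three rules (filter, top up, trim) easier to follow. The function below is a condensed, standalone restatement of the method above with the logging removed; the page numbers are made up.

```python
def validate_pages(claimed_pages, retrieval_results, min_pages=2, max_pages=8):
    """Condensed restatement of _validate_page_references, without the print calls."""
    claimed_pages = claimed_pages or []
    retrieved = [r["page"] for r in retrieval_results]

    # Rule 1: drop pages that were never retrieved.
    validated = [p for p in claimed_pages if p in retrieved]

    # Rule 2: top up from the retrieval ranking until min_pages is reached.
    if len(validated) < min_pages:
        for page in retrieved:
            if page not in validated:
                validated.append(page)
                if len(validated) >= min_pages:
                    break

    # Rule 3: cap the list at max_pages.
    return validated[:max_pages]


results = [{"page": 3}, {"page": 7}, {"page": 11}]
print(validate_pages([3, 99], results))  # page 99 was never retrieved
print(validate_pages(None, results))     # no citations at all
```

In the first call, page 99 is removed and page 7 is added to reach the minimum of two pages. In the second call, both pages come from the top of the retrieval ranking.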

5. Multi-Company Comparison Questions

A more advanced feature handles comparisons across several companies:

def process_comparative_question(self, question: str, companies: List[str], schema: str) -> dict:
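    """
    Handle a comparative question across multiple companies:
    1. Rewrite the comparison into one question per company
    2. Process each company in parallel
    3. Aggregate the results and generate the final comparative answer
    """
    # Step 1: rewrite the question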
    rephrased_questions = self.openai_processor.get_rephrased_questions(
        original_question=question,
        companies=companies
    )
    
    individual_answers = {}
    aggregated_references = []
    
    # Step 2: process each company's question in parallel
    def process_company_question(company: str) -> tuple[str, dict]:
        """Helper that processes the question for a single company."""
        sub_question = rephrased_questions.get(company)
        if not sub_question:
            raise ValueError(f"Could not generate sub-question for company: {company}")
        
        answer_dict = self.get_answer_for_company(
            company_name=company, 
            question=sub_question, 
            schema="number"
        )
        return company, answer_dict

    # Process in parallel with a thread pool
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_company = {
            executor.submit(process_company_question, company): company 
            for company in companies
        }
        
        for future in concurrent.futures.as_completed(future_to_company):
            try:
                company, answer_dict = future.result()
                individual_answers[company] = answer_dict
                
                # Aggregate references
                company_references = answer_dict.get("references", [])
                aggregated_references.extend(company_references)
            except Exception as e:
                company = future_to_company[future]
                print(f"Error processing company {company}: {str(e)}")
                raise
    
    # Deduplicate references
    unique_refs = {}
    for ref in aggregated_references:
        key = (ref.get("pdf_sha1"), ref.get("page_index"))
        unique_refs[key] = ref
    aggregated_references = list(unique_refs.values())
    
    # Step 3: generate the comparative answer
    comparative_answer = self.openai_processor.get_answer_from_rag_context(
        question=question,
        rag_context=individual_answers,
        schema="comparative",
        model=self.answering_model
    )
    
    comparative_answer["references"] = aggregated_references
    return comparative_answer

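Features of comparative Q&A

  • Question decomposition: the comparison is rewritten into one sub-question per company (each answered with the number schema)

  • Parallel processing: companies are handled concurrently in a thread pool

  • Result aggregation: per-company answers and references are merged

  • Deduplication: repeated references are dropped, keyed by (pdf_sha1, page_index)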
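The deduplication step relies on a plain dict keyed by (pdf_sha1, page_index). A minimal standalone sketch shows the effect; the SHA1 values are invented placeholders and dedupe_references is a name chosen here for illustration.

```python
def dedupe_references(references: list[dict]) -> list[dict]:
    """Keep one reference per (pdf_sha1, page_index) pair, in first-seen order."""
    unique = {}
    for ref in references:
        # A later duplicate overwrites the value but keeps the key's original position.
        unique[(ref.get("pdf_sha1"), ref.get("page_index"))] = ref
    return list(unique.values())


refs = [
    {"pdf_sha1": "aaa", "page_index": 5},
    {"pdf_sha1": "bbb", "page_index": 5},  # same page, different document: kept
    {"pdf_sha1": "aaa", "page_index": 5},  # exact duplicate: dropped
]
deduped = dedupe_references(refs)
print(deduped)
```

Keying on the pair rather than the page alone matters here, because two companies' reports can both cite the same page number.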

6. Batch Processing and Concurrency Control

Large question lists are processed in batches:

def process_questions_list(self, questions_list: List[dict], output_path: str = None, submission_file: bool = False, team_email: str = "", submission_name: str = "", pipeline_details: str = "") -> dict:
    # Process a list of questions, with optional parallelism and progress saving
    total_questions = len(questions_list)
    questions_with_index = [{**q, "_question_index": i} for i, q in enumerate(questions_list)]
    self.answer_details = [None] * total_questions  # Pre-allocate the answer-details list
    processed_questions = []
    parallel_threads = self.parallel_requests

    if parallel_threads <= 1:
        # Sequential, single-threaded processing
        for question_data in tqdm(questions_with_index, desc="Processing questions"):
            processed_question = self._process_single_question(question_data)
            processed_questions.append(processed_question)
            if output_path:
                self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
    else:
        # Multithreaded parallel processing
        with tqdm(total=total_questions, desc="Processing questions") as pbar:
            for i in range(0, total_questions, parallel_threads):
                batch = questions_with_index[i : i + parallel_threads]
                with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_threads) as executor:
                    # executor.map returns results in input order
                    batch_results = list(executor.map(self._process_single_question, batch))
                processed_questions.extend(batch_results)
                
                if output_path:
                    self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
                pbar.update(len(batch_results))
    
    statistics = self._calculate_statistics(processed_questions, print_stats = True)
    
    return {
        "questions": processed_questions,
        "answer_details": self.answer_details,
        "statistics": statistics
    }

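Benefits of this approach

  • Flexible concurrency: single-threaded or multithreaded, chosen by parallel_requests

  • Batching: questions are processed batch by batch, balancing throughput against resource use

  • Progress display: a tqdm progress bar tracks completed questions

  • Checkpointing: results are saved via _save_progress after every question (sequential mode) or every batch (parallel mode); the excerpt shows saving only, not resuming from a saved file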
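The batching pattern can be tried on its own. This sketch assumes a hypothetical process_in_batches helper and a toy worker; it reproduces the structure above (one thread pool per batch, executor.map for ordered results, a checkpoint hook after each batch) without any LLM calls.

```python
import concurrent.futures
import time


def process_in_batches(items, worker, batch_size, on_checkpoint=None):
    """Hypothetical helper: run worker over items batch by batch, keeping input order."""
    results = []
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
            # executor.map yields results in input order, even if tasks finish out of order.
            results.extend(executor.map(worker, batch))
        if on_checkpoint:
            on_checkpoint(len(results))  # stands in for _save_progress
    return results


def slow_square(n: int) -> int:
    # Later items in a batch finish first, to show that order is still preserved.
    time.sleep(0.01 * (5 - n % 5))
    return n * n


checkpoints = []
squares = process_in_batches(list(range(12)), slow_square, batch_size=5,
                             on_checkpoint=checkpoints.append)
print(squares)
print(checkpoints)
```

One trade-off of this design is that each batch waits for its slowest item before the next batch starts, so a single slow request holds up the remaining workers.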

7. Error Handling and Recovery

Exceptions are caught per question, so a single failure does not abort the whole run:

def _handle_processing_error(self, question_text: str, schema: str, err: Exception, question_index: int) -> dict:
    """
    Handle an exception raised while processing a question.
    Record the error details and return a dict describing the error.
    """
    import traceback
    error_message = str(err)
    tb = traceback.format_exc()
    error_ref = f"#/answer_details/{question_index}"
    error_detail = {
        "error_traceback": tb,
        "self": error_ref
    }
    
    # Record the error under the lock
    with self._lock:
        self.answer_details[question_index] = error_detail
    
    # Detailed error log
    print(f"Error encountered processing question: {question_text}")
    print(f"Error type: {type(err).__name__}")
    print(f"Error message: {error_message}")
    print(f"Full traceback:\n{tb}\n")
    
    # Return a standardized error response
    if self.new_challenge_pipeline:
        return {
            "question_text": question_text,
            "kind": schema,
            "value": None,
            "references": [],
            "error": f"{type(err).__name__}: {error_message}",
            "answer_details": {"$ref": error_ref}
        }
    else:
        return {
            "question": question_text,
            "schema": schema,
            "answer": None,
            "error": f"{type(err).__name__}: {error_message}",
            "answer_details": {"$ref": error_ref},
        }

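Error-handling features

  • Detailed records: the full traceback is stored alongside the error

  • Thread safety: error details are written under a lock

  • Standardized responses: every failure returns the same response shape

  • Debug-friendly: error type, message, and traceback are printed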
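The error record can be reproduced without the rest of the pipeline. The sketch below uses a hypothetical build_error_record function and formats the traceback from the exception object itself via traceback.format_exception, so it does not depend on being called inside the except block. The method above uses traceback.format_exc(), which only captures the traceback while the exception is still being handled.

```python
import traceback


def build_error_record(question_text: str, schema: str, err: Exception, question_index: int):
    """Hypothetical helper returning (response, detail) in the non-challenge format."""
    ref = f"#/answer_details/{question_index}"
    detail = {
        # Format from the exception object, independent of the current except context.
        "error_traceback": "".join(
            traceback.format_exception(type(err), err, err.__traceback__)
        ),
        "self": ref,
    }
    response = {
        "question": question_text,
        "schema": schema,
        "answer": None,
        "error": f"{type(err).__name__}: {err}",
        "answer_details": {"$ref": ref},
    }
    return response, detail


try:
    raise ValueError("No company name found in the question.")
except ValueError as exc:
    response, detail = build_error_record("Who is the CEO?", "name", exc, 0)

print(response["error"])
```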

8. Statistics and Monitoring

Summary statistics are computed once processing finishes:

def _calculate_statistics(self, processed_questions: List[dict], print_stats: bool = False) -> dict:
    """Summarize results: total, error, N/A, and success counts."""
    total_questions = len(processed_questions)
    error_count = sum(1 for q in processed_questions if "error" in q)
    na_count = sum(1 for q in processed_questions if (q.get("value") if "value" in q else q.get("answer")) == "N/A")
    success_count = total_questions - error_count - na_count
    
    if print_stats:
        print(f"\nFinal Processing Statistics:")
        print(f"Total questions: {total_questions}")
        print(f"Errors: {error_count} ({(error_count/total_questions)*100:.1f}%)")
        print(f"N/A answers: {na_count} ({(na_count/total_questions)*100:.1f}%)")
        print(f"Successfully answered: {success_count} ({(success_count/total_questions)*100:.1f}%)\n")
    
    return {
        "total_questions": total_questions,
        "error_count": error_count,
        "na_count": na_count,
        "success_count": success_count
    }
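One edge case is worth noting: with print_stats=True and an empty question list, the percentage calculations above divide by zero. The sketch below is a hypothetical variant (summarize is a name chosen here, not the project's method) that computes the same counts and guards the rate calculation.

```python
def summarize(processed_questions: list[dict]) -> dict:
    """Same counts as _calculate_statistics, plus zero-safe percentage rates."""
    total = len(processed_questions)
    errors = sum(1 for q in processed_questions if "error" in q)
    # "value" is used by the challenge format, "answer" by the default format.
    na = sum(1 for q in processed_questions if q.get("value", q.get("answer")) == "N/A")
    success = total - errors - na

    def rate(count: int) -> float:
        # Report 0.0 for an empty run instead of dividing by zero.
        return count / total * 100 if total else 0.0

    return {
        "total_questions": total,
        "error_count": errors,
        "na_count": na,
        "success_count": success,
        "success_rate": rate(success),
        "error_rate": rate(errors),
    }


sample = [
    {"answer": 42},
    {"answer": "N/A"},
    {"answer": None, "error": "ValueError: No company name found in the question."},
]
stats = summarize(sample)
print(stats)
print(summarize([]))
```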

API Processor Architecture

A Unified Multi-Provider Interface

class APIProcessor:
    def __init__(self, provider: Literal["openai", "ibm", "gemini", "dashscope"] ="dashscope"):
        self.provider = provider.lower()
        if self.provider == "openai":
            self.processor = BaseOpenaiProcessor()
        elif self.provider == "ibm":
            self.processor = BaseIBMAPIProcessor()
        elif self.provider == "gemini":
            self.processor = BaseGeminiProcessor()
        elif self.provider == "dashscope":
            self.processor = BaseDashscopeProcessor()

    def get_answer_from_rag_context(self, question, rag_context, schema, model):
        system_prompt, response_format, user_prompt = self._build_rag_context_prompts(schema)
        
        answer_dict = self.processor.send_message(
            model=model,
            system_content=system_prompt,
            human_content=user_prompt.format(context=rag_context, question=question),
            is_structured=True,
            response_format=response_format
        )
        
        # Fallback: make sure the full answer structure is returned
        if 'step_by_step_analysis' not in answer_dict:
            answer_dict = {
                "step_by_step_analysis": "",
                "reasoning_summary": "",
                "relevant_pages": [],
                "final_answer": answer_dict.get("final_answer", "N/A")
            }
        return answer_dict

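Benefits of the API processor

  • Unified interface: hides the API differences between providers

  • Per-provider adapters: each provider has its own Base*Processor class behind the same send_message call

  • Fallback handling: an incomplete response is padded to the full answer structure

  • Extensibility: a new provider needs only another processor class and one more branch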
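In the constructor above, an unrecognized provider name falls through every branch and leaves self.processor unset, so the failure only surfaces on the first request. One possible alternative, sketched here under assumptions, is a registry that fails fast. StubProcessor, PROCESSOR_FACTORIES, and build_processor are hypothetical names standing in for the real Base*Processor classes.

```python
from typing import Callable, Dict


class StubProcessor:
    """Hypothetical stand-in for a provider-specific processor class."""

    def __init__(self, name: str):
        self.name = name

    def send_message(self, **kwargs) -> dict:
        return {"final_answer": f"stub answer from {self.name}"}


# Registry of provider name -> factory; adding a provider means adding one entry.
PROCESSOR_FACTORIES: Dict[str, Callable[[], StubProcessor]] = {
    "openai": lambda: StubProcessor("openai"),
    "dashscope": lambda: StubProcessor("dashscope"),
}


def build_processor(provider: str) -> StubProcessor:
    """Look up the provider and raise immediately if it is not registered."""
    key = provider.lower()
    if key not in PROCESSOR_FACTORIES:
        supported = ", ".join(sorted(PROCESSOR_FACTORIES))
        raise ValueError(f"Unsupported provider {provider!r}; expected one of: {supported}")
    return PROCESSOR_FACTORIES[key]()


processor_stub = build_processor("OpenAI")
print(processor_stub.send_message(model="any"))
```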

Practical Use Cases

1. Corporate Financial Analysis

# Single-company financial query
processor = QuestionsProcessor(
    vector_db_dir="./financial_dbs",
    documents_dir="./financial_docs",
    llm_reranking=True,
    api_provider="openai",
    answering_model="gpt-4o-2024-08-06"
)

answer = processor.get_answer_for_company(
    company_name="Apple Inc.",
    question="What was the net income in Q4 2023?",
    schema="number"
)

2. Multi-Company Comparison

# Multi-company comparison query
comparative_answer = processor.process_comparative_question(
    question="Which company spent more on R&D in 2023, 'Apple Inc.' or 'Microsoft Corporation'?",
    companies=["Apple Inc.", "Microsoft Corporation"],
    schema="comparative"
)

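3. Batch Question Processing

# Large-scale batch processing.
# Company names are wrapped in double quotes so process_question can extract them.
questions_list = [
    {"question": 'Who is the CEO of "Apple Inc."?', "schema": "name"},
    {"question": 'What was the total revenue of "Apple Inc." in 2023?', "schema": "number"},
    {"question": 'Did "Apple Inc." buy back shares?', "schema": "boolean"}
]

# Concurrency is configured on the processor (the parallel_requests constructor
# argument); process_questions_list does not accept it as a parameter.
processor.parallel_requests = 5

results = processor.process_questions_list(
    questions_list=questions_list,
    output_path="./results.json",
    submission_file=True
)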

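Performance Optimization Strategies

The following are directions for tuning a system like this; not all of them appear in the code shown in this article.

1. Retrieval

  • Retriever selection: choose the retrieval strategy that suits the query type

  • Caching: cache retrieval results for frequent queries

  • Batched retrieval: merge similar queries to cut the number of retrieval calls

2. Concurrency

  • Dynamic thread pool: adjust concurrency to system load

  • Batching: balance parallelism against resource use

  • Load balancing: spread requests across several API providers

3. Memory Management

  • Streaming: process large datasets as a stream

  • Prompt release: free resources as soon as processing completes

  • Memory monitoring: track memory usage at runtime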
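As one illustration of the caching idea, the sketch below wraps a retrieval call in functools.lru_cache. It is an assumption about how caching could be added, not something the project is shown to do; expensive_retrieve is a hypothetical stand-in for a real retriever call, and CALLS only exists to count how often it runs.

```python
from functools import lru_cache

CALLS = {"count": 0}


def expensive_retrieve(company_name: str, query: str) -> list[dict]:
    """Hypothetical stand-in for a vector-database lookup."""
    CALLS["count"] += 1
    return [{"page": 1, "text": f"{company_name}: {query}"}]


@lru_cache(maxsize=256)
def cached_retrieve(company_name: str, query: str) -> tuple:
    # Arguments must be hashable; the list is returned as a tuple so callers
    # cannot append to or reorder the cached result.
    return tuple(expensive_retrieve(company_name, query))


first = cached_retrieve("Apple Inc.", "net income 2023")
second = cached_retrieve("Apple Inc.", "net income 2023")  # served from the cache
print(CALLS["count"])
print(cached_retrieve.cache_info())
```

The dicts inside the tuple are still mutable, so callers should treat cached results as read-only. A cache like this is only appropriate while the underlying index does not change.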

System Monitoring and Debugging

1. Processing Statistics

# Processing statistics
statistics = processor._calculate_statistics(processed_questions, print_stats=True)
print(f"Success rate: {(statistics['success_count']/statistics['total_questions'])*100:.1f}%")
print(f"Error rate: {(statistics['error_count']/statistics['total_questions'])*100:.1f}%")

2. Detailed Logging

# Enable detailed logging
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Custom log helper
def log_processing_details(question, answer, processing_time):
    logger.info(f"Question: {question}")
    logger.info(f"Answer: {answer.get('final_answer', 'N/A')}")
    logger.info(f"Processing time: {processing_time:.2f}s")

3. Error Analysis

# Count errors by type
def analyze_errors(processed_questions):
    errors = [q for q in processed_questions if "error" in q]
    error_types = {}
    for error in errors:
        error_type = error["error"].split(":")[0]
        error_types[error_type] = error_types.get(error_type, 0) + 1
    
    print("Error Analysis:")
    for error_type, count in error_types.items():
        print(f"  {error_type}: {count}")

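Best Practices

1. Configuration

# A suggested production configuration
production_config = {
    "llm_reranking": True,             # reranking improves answer quality
    "parent_document_retrieval": True, # enable parent-document retrieval
    "top_n_retrieval": 10,             # a moderate number of retrieved chunks
    "parallel_requests": 5,            # stay under API rate limits
    "api_provider": "openai",          # API provider
    "answering_model": "gpt-4o-2024-08-06"  # a high-quality model
}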

2. Error Handling

# Retry with exponential backoff
import time

def robust_question_processing(processor, question, max_retries=3):
    for attempt in range(max_retries):
        try:
            return processor.process_question(question["question"], question["schema"])
        except Exception as e:
            if attempt == max_retries - 1:
                return {"error": f"Failed after {max_retries} attempts: {str(e)}"}
            time.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, ...

3. Performance Monitoring

# Timing decorator
import time
from functools import wraps

def monitor_performance(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} took {end_time - start_time:.2f} seconds")
        return result
    return wrapper

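Summary

This question-processing system illustrates the engineering behind an enterprise-grade RAG system:

  1. Modular architecture: a clear layered design that is easy to maintain and extend

  2. Multiple answer schemas: several question types and answer formats are supported

  3. Concurrency: multithreaded, batched processing

  4. Error recovery: per-question exception handling and saved progress

  5. Monitoring and debugging: summary statistics and detailed error output

  6. API abstraction: one interface across several providers

  7. Page validation: a check that filters out page references the LLM hallucinated

For anyone building an enterprise question-answering system, this implementation is a useful reference architecture. With careful design and tuning, it can process large question sets efficiently and reliably without sacrificing answer quality.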

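This article is based on a source-code analysis of the question-processing module from the award-winning RAG-Challenge-2 project. Hopefully it is useful to developers building similar systems.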

Full Code

import json
from typing import Union, Dict, List, Optional
import re
from pathlib import Path
from src.retrieval import VectorRetriever, HybridRetriever
from src.api_requests import APIProcessor
from tqdm import tqdm
import pandas as pd
import threading
import concurrent.futures


class QuestionsProcessor:
    def __init__(
        self,
        vector_db_dir: Union[str, Path] = './vector_dbs',
        documents_dir: Union[str, Path] = './documents',
        questions_file_path: Optional[Union[str, Path]] = None,
        new_challenge_pipeline: bool = False,
        subset_path: Optional[Union[str, Path]] = None,
        parent_document_retrieval: bool = False,  # enable parent-document retrieval
        llm_reranking: bool = False,              # enable LLM reranking
        llm_reranking_sample_size: int = 20,
        top_n_retrieval: int = 10,
        parallel_requests: int = 10,
        api_provider: str = "dashscope", # openai
        answering_model: str = "qwen-turbo-latest", # gpt-4o-2024-08-06
        full_context: bool = False
    ):
        # Initialize the processor: retrieval, model, and concurrency settings
        self.questions = self._load_questions(questions_file_path)
        self.documents_dir = Path(documents_dir)
        self.vector_db_dir = Path(vector_db_dir)
        self.subset_path = Path(subset_path) if subset_path else None
        
        self.new_challenge_pipeline = new_challenge_pipeline
        self.return_parent_pages = parent_document_retrieval
        self.llm_reranking = llm_reranking
        self.llm_reranking_sample_size = llm_reranking_sample_size
        self.top_n_retrieval = top_n_retrieval
        self.answering_model = answering_model
        self.parallel_requests = parallel_requests
        self.api_provider = api_provider
        self.openai_processor = APIProcessor(provider=api_provider)
        self.full_context = full_context

        self.answer_details = []
        self.detail_counter = 0
        self._lock = threading.Lock()

    def _load_questions(self, questions_file_path: Optional[Union[str, Path]]) -> List[Dict[str, str]]:
        # Load the questions file and return the list of questions
        if questions_file_path is None:
            return []
        with open(questions_file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def _format_retrieval_results(self, retrieval_results) -> str:
        """Format retrieval results as a RAG context string."""
        if not retrieval_results:
            return ""
        
        context_parts = []
        for result in retrieval_results:
            page_number = result['page']
            text = result['text']
            context_parts.append(f'Text retrieved from page {page_number}: \n"""\n{text}\n"""')
            
        return "\n\n---\n\n".join(context_parts)

    def _extract_references(self, pages_list: list, company_name: str) -> list:
        # Build reference entries from the company name and the list of pages
        if self.subset_path is None:
            raise ValueError("subset_path is required for new challenge pipeline when processing references.")
        self.companies_df = pd.read_csv(self.subset_path)

        # Find the company's SHA1 from the subset CSV
        matching_rows = self.companies_df[self.companies_df['company_name'] == company_name]
        if matching_rows.empty:
            company_sha1 = ""
        else:
            company_sha1 = matching_rows.iloc[0]['sha1']

        refs = []
        for page in pages_list:
            refs.append({"pdf_sha1": company_sha1, "page_index": page})
        return refs

    def _validate_page_references(self, claimed_pages: list, retrieval_results: list, min_pages: int = 2, max_pages: int = 8) -> list:
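        """
        Check that the pages cited in the LLM answer actually appear in the retrieval results.
        If fewer than min_pages remain, top up with the highest-ranked retrieved pages.
        """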
        if claimed_pages is None:
            claimed_pages = []
        
        retrieved_pages = [result['page'] for result in retrieval_results]
        
        validated_pages = [page for page in claimed_pages if page in retrieved_pages]
        
        if len(validated_pages) < len(claimed_pages):
            removed_pages = set(claimed_pages) - set(validated_pages)
            print(f"Warning: Removed {len(removed_pages)} hallucinated page references: {removed_pages}")
        
        if len(validated_pages) < min_pages and retrieval_results:
            existing_pages = set(validated_pages)
            
            for result in retrieval_results:
                page = result['page']
                if page not in existing_pages:
                    validated_pages.append(page)
                    existing_pages.add(page)
                    
                    if len(validated_pages) >= min_pages:
                        break
        
        if len(validated_pages) > max_pages:
            print(f"Trimming references from {len(validated_pages)} to {max_pages} pages")
            validated_pages = validated_pages[:max_pages]
        
        return validated_pages

    def get_answer_for_company(self, company_name: str, question: str, schema: str) -> dict:
        # Retrieve context for one company and call the LLM to generate the answer
        if self.llm_reranking:
            retriever = HybridRetriever(
                vector_db_dir=self.vector_db_dir,
                documents_dir=self.documents_dir
            )
        else:
            retriever = VectorRetriever(
                vector_db_dir=self.vector_db_dir,
                documents_dir=self.documents_dir
            )

        if self.full_context:
            retrieval_results = retriever.retrieve_all(company_name)
        else:           
            retrieval_results = retriever.retrieve_by_company_name(
                company_name=company_name,
                query=question,
                llm_reranking_sample_size=self.llm_reranking_sample_size,
                top_n=self.top_n_retrieval,
                return_parent_pages=self.return_parent_pages
            )
        
        if not retrieval_results:
            raise ValueError("No relevant context found")
        
        rag_context = self._format_retrieval_results(retrieval_results)
        answer_dict = self.openai_processor.get_answer_from_rag_context(
            question=question,
            rag_context=rag_context,
            schema=schema,
            model=self.answering_model
        )
        self.response_data = self.openai_processor.response_data
        if self.new_challenge_pipeline:
            pages = answer_dict.get("relevant_pages", [])
            validated_pages = self._validate_page_references(pages, retrieval_results)
            answer_dict["relevant_pages"] = validated_pages
            answer_dict["references"] = self._extract_references(validated_pages, company_name)
        return answer_dict

    def _extract_companies_from_subset(self, question_text: str) -> list[str]:
        """Extract company names from the question text by matching against the subset file."""
        if not hasattr(self, 'companies_df'):
            if self.subset_path is None:
                raise ValueError("subset_path must be provided to use subset extraction")
            self.companies_df = pd.read_csv(self.subset_path)
        
        found_companies = []
        company_names = sorted(self.companies_df['company_name'].unique(), key=len, reverse=True)
        
        for company in company_names:
            escaped_company = re.escape(company)
            
            pattern = rf'{escaped_company}(?:\W|$)'
            
            if re.search(pattern, question_text, re.IGNORECASE):
                found_companies.append(company)
                question_text = re.sub(pattern, '', question_text, flags=re.IGNORECASE)
        
        return found_companies

    def process_question(self, question: str, schema: str):
        # Process one question; several company names trigger the comparative path
        if self.new_challenge_pipeline:
            extracted_companies = self._extract_companies_from_subset(question)
        else:
            extracted_companies = re.findall(r'"([^"]*)"', question)
        
        if len(extracted_companies) == 0:
            raise ValueError("No company name found in the question.")
        
        if len(extracted_companies) == 1:
            company_name = extracted_companies[0]
            answer_dict = self.get_answer_for_company(company_name=company_name, question=question, schema=schema)
            return answer_dict
        else:
            return self.process_comparative_question(question, extracted_companies, schema)
    
    def _create_answer_detail_ref(self, answer_dict: dict, question_index: int) -> str:
        """Create the reference ID for the answer details and store the details."""
        ref_id = f"#/answer_details/{question_index}"
        with self._lock:
            self.answer_details[question_index] = {
                "step_by_step_analysis": answer_dict['step_by_step_analysis'],
                "reasoning_summary": answer_dict['reasoning_summary'],
                "relevant_pages": answer_dict['relevant_pages'],
                "response_data": self.response_data,
                "self": ref_id
            }
        return ref_id

    def _calculate_statistics(self, processed_questions: List[dict], print_stats: bool = False) -> dict:
        """Summarize results: total, error, N/A, and success counts."""
        total_questions = len(processed_questions)
        error_count = sum(1 for q in processed_questions if "error" in q)
        na_count = sum(1 for q in processed_questions if (q.get("value") if "value" in q else q.get("answer")) == "N/A")
        success_count = total_questions - error_count - na_count
        if print_stats:
            print(f"\nFinal Processing Statistics:")
            print(f"Total questions: {total_questions}")
            print(f"Errors: {error_count} ({(error_count/total_questions)*100:.1f}%)")
            print(f"N/A answers: {na_count} ({(na_count/total_questions)*100:.1f}%)")
            print(f"Successfully answered: {success_count} ({(success_count/total_questions)*100:.1f}%)\n")
        
        return {
            "total_questions": total_questions,
            "error_count": error_count,
            "na_count": na_count,
            "success_count": success_count
        }

    def process_questions_list(self, questions_list: List[dict], output_path: str = None, submission_file: bool = False, team_email: str = "", submission_name: str = "", pipeline_details: str = "") -> dict:
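        # Process a list of questions with optional parallelism and progress saving; return results and statistics
        total_questions = len(questions_list)
        # Tag each question with its index so its answer details can be located later
        questions_with_index = [{**q, "_question_index": i} for i, q in enumerate(questions_list)]
        self.answer_details = [None] * total_questions  # Pre-allocate the answer-details list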
        processed_questions = []
        parallel_threads = self.parallel_requests

        if parallel_threads <= 1:
            # Sequential, single-threaded processing
            for question_data in tqdm(questions_with_index, desc="Processing questions"):
                processed_question = self._process_single_question(question_data)
                processed_questions.append(processed_question)
                if output_path:
                    self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
        else:
            # Multithreaded parallel processing
            with tqdm(total=total_questions, desc="Processing questions") as pbar:
                for i in range(0, total_questions, parallel_threads):
                    batch = questions_with_index[i : i + parallel_threads]
                    with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_threads) as executor:
                        # executor.map returns results in input order
                        batch_results = list(executor.map(self._process_single_question, batch))
                    processed_questions.extend(batch_results)
                    
                    if output_path:
                        self._save_progress(processed_questions, output_path, submission_file=submission_file, team_email=team_email, submission_name=submission_name, pipeline_details=pipeline_details)
                    pbar.update(len(batch_results))
        
        statistics = self._calculate_statistics(processed_questions, print_stats = True)
        
        return {
            "questions": processed_questions,
            "answer_details": self.answer_details,
            "statistics": statistics
        }

    def _process_single_question(self, question_data: dict) -> dict:
        question_index = question_data.get("_question_index", 0)
        
        if self.new_challenge_pipeline:
            question_text = question_data.get("text")
            schema = question_data.get("kind")
        else:
            question_text = question_data.get("question")
            schema = question_data.get("schema")
        try:
            answer_dict = self.process_question(question_text, schema)
            
            if "error" in answer_dict:
                detail_ref = self._create_answer_detail_ref({
                    "step_by_step_analysis": None,
                    "reasoning_summary": None,
                    "relevant_pages": None
                }, question_index)
                if self.new_challenge_pipeline:
                    return {
                        "question_text": question_text,
                        "kind": schema,
                        "value": None,
                        "references": [],
                        "error": answer_dict["error"],
                        "answer_details": {"$ref": detail_ref}
                    }
                else:
                    return {
                        "question": question_text,
                        "schema": schema,
                        "answer": None,
                        "error": answer_dict["error"],
                        "answer_details": {"$ref": detail_ref},
                    }
            detail_ref = self._create_answer_detail_ref(answer_dict, question_index)
            if self.new_challenge_pipeline:
                return {
                    "question_text": question_text,
                    "kind": schema,
                    "value": answer_dict.get("final_answer"),
                    "references": answer_dict.get("references", []),
                    "answer_details": {"$ref": detail_ref}
                }
            else:
                return {
                    "question": question_text,
                    "schema": schema,
                    "answer": answer_dict.get("final_answer"),
                    "answer_details": {"$ref": detail_ref},
                }
        except Exception as err:
            return self._handle_processing_error(question_text, schema, err, question_index)

    def _handle_processing_error(self, question_text: str, schema: str, err: Exception, question_index: int) -> dict:
        """
        Handle an exception raised while processing a question.
        Record the error details and return a dict describing the error.
        """
        import traceback
        error_message = str(err)
        tb = traceback.format_exc()
        error_ref = f"#/answer_details/{question_index}"
        error_detail = {
            "error_traceback": tb,
            "self": error_ref
        }
        
        with self._lock:
            self.answer_details[question_index] = error_detail
        
        print(f"Error encountered processing question: {question_text}")
        print(f"Error type: {type(err).__name__}")
        print(f"Error message: {error_message}")
        print(f"Full traceback:\n{tb}\n")
        
        if self.new_challenge_pipeline:
            return {
                "question_text": question_text,
                "kind": schema,
                "value": None,
                "references": [],
                "error": f"{type(err).__name__}: {error_message}",
                "answer_details": {"$ref": error_ref}
            }
        else:
            return {
                "question": question_text,
                "schema": schema,
                "answer": None,
                "error": f"{type(err).__name__}: {error_message}",
                "answer_details": {"$ref": error_ref},
            }

    def _post_process_submission_answers(self, processed_questions: List[dict]) -> List[dict]:
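        """
        Post-process answers into the submission format:
        1. Convert page indices from 1-based to 0-based
        2. Clear references for N/A answers
        3. Format each answer to the competition submission schema
        4. Include step_by_step_analysis (as reasoning_process) when available
        """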
        submission_answers = []
        
        for q in processed_questions:
            question_text = q.get("question_text") or q.get("question")
            kind = q.get("kind") or q.get("schema")
            value = "N/A" if "error" in q else (q.get("value") if "value" in q else q.get("answer"))
            references = q.get("references", [])
            
            answer_details_ref = q.get("answer_details", {}).get("$ref", "")
            step_by_step_analysis = None
            if answer_details_ref and answer_details_ref.startswith("#/answer_details/"):
                try:
                    index = int(answer_details_ref.split("/")[-1])
                    if 0 <= index < len(self.answer_details) and self.answer_details[index]:
                        step_by_step_analysis = self.answer_details[index].get("step_by_step_analysis")
                except (ValueError, IndexError):
                    pass
            
            # Clear references if value is N/A
            if value == "N/A":
                references = []
            else:
                # Convert page indices from one-based to zero-based (competition requires 0-based page indices, but for debugging it is easier to use 1-based)
                references = [
                    {
                        "pdf_sha1": ref["pdf_sha1"],
                        "page_index": ref["page_index"] - 1
                    }
                    for ref in references
                ]
            
            submission_answer = {
                "question_text": question_text,
                "kind": kind,
                "value": value,
                "references": references,
            }
            
            if step_by_step_analysis:
                submission_answer["reasoning_process"] = step_by_step_analysis
            
            submission_answers.append(submission_answer)
        
        return submission_answers

    def _save_progress(self, processed_questions: List[dict], output_path: Optional[str], submission_file: bool = False, team_email: str = "", submission_name: str = "", pipeline_details: str = ""):
        if output_path:
            statistics = self._calculate_statistics(processed_questions)
            
            # Prepare debug content
            result = {
                "questions": processed_questions,
                "answer_details": self.answer_details,
                "statistics": statistics
            }
            output_file = Path(output_path)
            debug_file = output_file.with_name(output_file.stem + "_debug" + output_file.suffix)
            with open(debug_file, 'w', encoding='utf-8') as file:
                json.dump(result, file, ensure_ascii=False, indent=2)
            
            if submission_file:
                # Post-process answers for submission
                submission_answers = self._post_process_submission_answers(processed_questions)
                submission = {
                    "answers": submission_answers,
                    "team_email": team_email,
                    "submission_name": submission_name,
                    "details": pipeline_details
                }
                with open(output_file, 'w', encoding='utf-8') as file:
                    json.dump(submission, file, ensure_ascii=False, indent=2)

    def process_all_questions(self, output_path: str = 'questions_with_answers.json', team_email: str = "79250515615@yandex.com", submission_name: str = "Ilia_Ris SO CoT + Parent Document Retrieval", submission_file: bool = False, pipeline_details: str = ""):
        result = self.process_questions_list(
            self.questions,
            output_path,
            submission_file=submission_file,
            team_email=team_email,
            submission_name=submission_name,
            pipeline_details=pipeline_details
        )
        return result

    def process_comparative_question(self, question: str, companies: List[str], schema: str) -> dict:
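        """
        Handle a question that compares multiple companies:
        1. Rewrite the comparative question as one question per company
        2. Process each company in parallel
        3. Combine the results and generate the final comparative answer
        """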
        # Step 1: Rephrase the comparative question
        rephrased_questions = self.openai_processor.get_rephrased_questions(
            original_question=question,
            companies=companies
        )
        
        individual_answers = {}
        aggregated_references = []
        
        # Step 2: Process each individual question in parallel
        def process_company_question(company: str) -> tuple[str, dict]:
            """Helper function to process one company's question and return (company, answer)"""
            sub_question = rephrased_questions.get(company)
            if not sub_question:
                raise ValueError(f"Could not generate sub-question for company: {company}")
            
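            # Sub-questions are always answered with the "number" schema;
            # the outer `schema` argument is not used in this method.
            answer_dict = self.get_answer_for_company(
                company_name=company,
                question=sub_question,
                schema="number"
            )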
            return company, answer_dict

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_company = {
                executor.submit(process_company_question, company): company 
                for company in companies
            }
            
            for future in concurrent.futures.as_completed(future_to_company):
                try:
                    company, answer_dict = future.result()
                    individual_answers[company] = answer_dict
                    
                    company_references = answer_dict.get("references", [])
                    aggregated_references.extend(company_references)
                except Exception as e:
                    company = future_to_company[future]
                    print(f"Error processing company {company}: {e}")
                    raise
        
        # Remove duplicate references
        unique_refs = {}
        for ref in aggregated_references:
            key = (ref.get("pdf_sha1"), ref.get("page_index"))
            unique_refs[key] = ref
        aggregated_references = list(unique_refs.values())
        
        # Step 3: Get the comparative answer using all individual answers
        comparative_answer = self.openai_processor.get_answer_from_rag_context(
            question=question,
            rag_context=individual_answers,
            schema="comparative",
            model=self.answering_model
        )
        self.response_data = self.openai_processor.response_data
        
        comparative_answer["references"] = aggregated_references
        return comparative_answer
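
The fan-out and reference de-duplication steps in `process_comparative_question` can be tried in isolation with the minimal sketch below. It is not the author's implementation: `answer_per_company`, `dedupe_references`, and `fake_answer` are hypothetical names introduced for illustration, and `fake_answer` stands in for `get_answer_for_company` so the example runs without a vector database or an LLM API.

```python
import concurrent.futures
from typing import Callable, Dict, List


def answer_per_company(
    sub_questions: Dict[str, str],
    answer_fn: Callable[[str, str], dict],
) -> Dict[str, dict]:
    """Call answer_fn(company, sub_question) for every company in a thread pool."""
    answers: Dict[str, dict] = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(answer_fn, company, question): company
            for company, question in sub_questions.items()
        }
        for future in concurrent.futures.as_completed(futures):
            # result() re-raises any exception from the worker thread, so one
            # failed company fails the whole comparison.
            answers[futures[future]] = future.result()
    return answers


def dedupe_references(references: List[dict]) -> List[dict]:
    """Keep a single reference per (pdf_sha1, page_index) pair."""
    unique: Dict[tuple, dict] = {}
    for ref in references:
        unique[(ref.get("pdf_sha1"), ref.get("page_index"))] = ref
    return list(unique.values())


def fake_answer(company: str, question: str) -> dict:
    # Hypothetical stand-in for get_answer_for_company: every company
    # cites the same page, which forces a duplicate reference.
    return {
        "final_answer": len(company),
        "references": [{"pdf_sha1": "abc", "page_index": 3}],
    }


answers = answer_per_company(
    {
        "Acme": "What was Acme's revenue?",
        "Globex": "What was Globex's revenue?",
    },
    fake_answer,
)
all_refs = [ref for answer in answers.values() for ref in answer["references"]]
print(sorted(answers))
print(dedupe_references(all_refs))
```

Keying the dictionary on `(pdf_sha1, page_index)` means two companies citing the same page contribute one reference to the final answer. Because `as_completed` yields futures in completion order, the order of `answers` may vary between runs, which is why the sketch sorts the keys before printing.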