一个完整的 RAG 系统涉及多个技术组件,每个组件的选型都直接影响系统的性能、成本和可维护性。本文档详细分析本项目各项技术选型的考量因素及竞品对比。
一、向量数据库:Milvus
为什么选择 Milvus?
本项目选择 Milvus 作为向量数据库,核心考量如下:
1. 混合检索原生支持
Milvus 是少数原生支持稠密向量 + 稀疏向量混合检索的向量数据库之一。这意味着:
-
无需额外对接 Elasticsearch 等组件
-
一次查询即可完成语义检索和关键词检索的融合
-
减少系统复杂度和维护成本
2. 索引类型丰富
Milvus 支持 IVF_FLAT、HNSW、DiskANN 等多种索引类型,可根据场景灵活选择:
-
本项目使用 IVF_FLAT 稠密索引 + SPARSE_INVERTED_INDEX 稀疏索引
-
在精度和性能之间取得平衡
3. 云原生架构
Milvus 采用存算分离的云原生架构,支持水平扩展,为后续大规模部署预留空间。
4. 中文社区活跃
Milvus 由 Zilliz 开源维护,中文文档完善,社区活跃,问题响应快。
竞品对比
| 维度 | Milvus(本项�选择) | Qdrant | Faiss | 腾讯云向量数据库 |
|---|---|---|---|---|
| 架构 | 分布式云原生 | Rust 高性能 | 单机/轻量集群 | 全托管 Serverless |
| 混合检索 | ✅ 原生支持 | ✅ 标量过滤 | ❌ 无 | ✅ 支持 |
| 稀疏向量 | ✅ SPARSE_FLOAT_VECTOR | ❌ | ❌ | ✅ IVF_RaBitQ |
| 运维复杂度 | 需 K8s 集群管理 | 需 Rust 环境 | 无分布式功能 | 自动扩缩容 |
| 成本 | 硬件采购成本高 | 社区版免费 | 无网络开销 | 按需付费 |
| 适用场景 | 电商/安防大规模检索 | 实时反欺诈 | 学术研究 | 企业级版权管理 |
选型决策逻辑:
-
放弃 Qdrant:虽然 Rust 实现性能优异,但对混合检索的支持较弱,稀疏向量能力不足
-
放弃 Faiss:纯学术工具,无分布式能力,不适合生产环境
-
放弃云托管:项目需要本地部署,数据不出域,且希望掌握核心技术栈
二、Embedding 模型:BGE-M3
为什么选择 BGE-M3?
本项目选择 BAAI 开源的 BGE-M3 作为嵌入模型,这是整个系统最核心的选型决策之一。
1. 混合向量能力(关键优势)
BGE-M3 最突出的特性是能同时生成三种向量表示:
| 向量类型 | 作用 | 在本项目中的使用 |
|---|---|---|
| 稠密向量 | 捕捉深层语义关联,处理同义词、paraphrase | Milvus 稠密向量检索 |
| 稀疏向量 | 生成"相关术语权重",语义推理关键词 | Milvus 稀疏向量检索 |
| 多向量 | 每个 token 的嵌入向量,用于交互计算 | 辅助精排 |
这意味着:
-
一套模型,两种检索:同一个模型输出的稠密和稀疏向量可直接用于混合检索
-
语义推理能力:即使文本中未出现"机器学习",模型也能通过语义推理在稀疏向量中赋予该术语权重
-
统一向量空间:稠密和稀疏向量在语义上对齐,融合效果更好
2. 中文优化
BGE-M3 在中英文多语言场景表现均衡,词汇量覆盖 100+ 语言,对中文专业术语的处理能力远超 OpenAI 的通用模型。
3. 上下文窗口
支持 8192 tokens 的上下文窗口,可覆盖完整课程章节,无需截断。
4. 开源可私有化部署
BGE-M3 采用 MIT 许可证,可自由修改和商用,满足项目私有化部署要求。
竞品对比
| 维度 | BGE-M3(本项目选择) | BERT | OpenAI text-embedding-ada-002 | Word2Vec |
|---|---|---|---|---|
| 向量类型 | 稠密+稀疏+多向量 | 仅稠密 | 仅稠密 | 静态词向量 |
| 语义理解 | 深度语义推理 | 双向上下文 | 深度语义 | 无 |
| 中文支持 | ✅ 100+ 语言 | ✅ 有中文版 | 一般 | 依赖语料 |
| 上下文窗口 | 8192 tokens | 512 tokens | 8192 tokens | 词级别 |
| 开源 | ✅ MIT | ✅ Apache 2.0 | ❌ 商业 | ✅ |
| MTEB 中文榜 | 前列 | 基准水平 | 中等 | 不适用 |
选型决策逻辑:
-
放弃 BERT:BERT 只能生成稠密向量,无法支撑混合检索;上下文窗口仅 512 tokens
-
放弃 OpenAI:商业 API,数据需上传,不适合教育数据私有化场景;无法生成稀疏向量
-
放弃 Word2Vec:静态词向量无法处理一词多义,语义理解能力严重不足
核心洞察:BGE-M3 的"一套模型、两种向量"能力是本项目混合检索策略的技术基石。若使用纯稠密模型,关键词匹配能力不足;若使用纯稀疏模型,语义理解能力不足。
三、重排序模型:BGE-Reranker-Large
为什么选择 BGE-Reranker-Large?
1. Cross-Encoder 架构
与 Bi-Encoder(如 BGE-M3)不同,Cross-Encoder 将 query 和 document 拼接后一起输入模型,通过交叉注意力机制深度计算相关性,精度远高于向量相似度计算。
2. 与 BGE-M3 同源
BGE-Reranker 与 BGE-M3 同属 BAAI 系列,向量空间对齐,配合使用效果更佳。
3. 中英文支持良好
在中文 MTEB 榜单上表现优异,适合教育领域的中文问答场景。
竞品对比
| 维度 | BGE-Reranker(本项目选择) | Cohere Rerank | 纯向量相似度 |
|---|---|---|---|
| 精度 | 高(Cross-Encoder) | 最高 | 低(Bi-Encoder) |
| 速度 | 中等(需逐对计算) | 中等 | 快 |
| 中文支持 | ✅ 优秀 | 一般 | - |
| 私有化部署 | ✅ 开源 | ❌ API 调用 | ✅ |
| 成本 | 本地 GPU | 按量付费 | 免费 |
四、Web 框架:FastAPI
为什么选择 FastAPI?
1. 异步高性能
FastAPI 基于 Starlette(ASGI),支持原生异步,性能远超 Flask 和 Django:
-
FastAPI:每秒 3 万次请求
-
Flask:每秒 9 千次请求
-
Django:每秒 5 千次请求
对于需要流式输出 SSE 的场景,异步能力至关重要。
2. 自动生成 API 文档
FastAPI 基于 Python type hints 自动生成 OpenAPI(Swagger)文档,开发调试效率高。
3. WebSocket 原生支持
本项目需要 WebSocket 支持流式输出,FastAPI 提供开箱即用的 WebSocket 支持。
竞品对比
| 维度 | FastAPI(本项目选择) | Flask | Django |
|---|---|---|---|
| 性能 | 极高(异步) | 低(同步) | 低(同步) |
| 并发能力 | 强 | 弱 | 弱 |
| WebSocket | ✅ 原生 | ❌ 需扩展 | ❌ 需扩展 |
| 自动文档 | ✅ OpenAPI | ❌ | ❌ |
| 学习曲线 | 平缓 | 平缓 | 陡峭 |
| 适用场景 | API 服务、AI 应用 | 小型网站、原型 | 完整 Web 应用 |
选型决策逻辑:
-
放弃 Flask:同步框架无法高效处理流式响应和 WebSocket
-
放弃 Django:过于重量级,且异步支持不成熟
五、本地 LLM 部署:Ollama
为什么选择 Ollama?
1. 极致简单的部署体验
Ollama 被开发者称为"LLM 版 Docker",一条命令即可完成模型下载和运行:
python
ollama pull qwen2.5:7b
ollama run qwen2.5:7b
2. OpenAI 兼容 API
Ollama 原生提供 OpenAI 兼容的 API 接口(http://localhost:11434/v1),LangChain 等框架开箱即用,代码无需修改即可切换后端。
3. 跨平台 + Apple Silicon 加速
支持 Windows、Linux、macOS,对 M 系列芯片有深度优化。
4. 内置模型管理
内置模型市场,支持 200+ 预量化模型,管理方便。
竞品对比
| 维度 | Ollama(本项目选择) | llama.cpp | vLLM | LM Studio |
|---|---|---|---|---|
| 定位 | 开发者快速集成 | 极致性能调优 | 高并发 API 服务 | 桌面可视化 |
| 上手难度 | ⭐ 极低 | ⭐⭐⭐ 中等 | ⭐⭐⭐ 中等 | ⭐ 极低 |
| 生产级性能 | 一般 | 较高 | 最高 | 一般 |
| 多 GPU | 基础支持 | 有限 | ✅ 完善 | ❌ |
| 适用场景 | 原型开发、本地集成 | 边缘设备、资源受限 | 大规模生产服务 | 模型探索、日常使用 |
选型决策逻辑:
-
放弃 llama.cpp:本项目需要 API 服务,llama.cpp 偏向命令行工具,集成成本高
-
放弃 vLLM:vLLM 主要面向高并发生产环境,对硬件要求高,项目初期过度设计
-
放弃 LM Studio:GUI 应用不适合自动化集成和脚本化部署
六、缓存:Redis
为什么选择 Redis?
1. 丰富的数据结构
与 Memcached 仅支持简单 key-value 不同,Redis 支持 String、Hash、List、Set、Sorted Set 等多种数据结构,本项目需要存储:
-
答案缓存(String)
-
BM25 分词索引(List/Set)
-
会话状态(Hash)
2. 持久化支持
Redis 提供 RDB 和 AOF 两种持久化机制,重启后缓存不丢失;Memcached 纯内存,重启即失。
3. 主从复制与高可用
Redis 原生支持主从复制和 Sentinel 高可用架构,Memcached 不支持。
4. 广泛的生态支持
几乎所有的编程语言都有成熟的 Redis 客户端,与 Python 集成无缝。
竞品对比
| 维度 | Redis(本项目选择) | Memcached |
|---|---|---|
| 数据结构 | 丰富(String/Hash/List/Set/ZSet) | 仅 key-value |
| 持久化 | ✅ RDB + AOF | ❌ 纯内存 |
| 主从复制 | ✅ | ❌ |
| 高可用 | ✅ Sentinel/Cluster | ❌ |
| 性能 | 单核,小数据更优 | 多核,大数据更优 |
| 适用场景 | 复杂缓存、会话存储、消息队列 | 简单缓存、静态数据 |
七、关系型数据库:MySQL
为什么选择 MySQL?
1. 生态成熟,文档丰富
MySQL 拥有庞大的社区和完善的文档,问题解决方案丰富,开发效率高。
2. 轻量易用
相比 PostgreSQL,MySQL 配置简单、上手快,适合中小规模项目快速迭代。
3. 满足需求
本项目的数据库需求相对简单:
-
BM25 知识库(jpkb 表)
-
对话历史(conversations 表)
不需要 PostgreSQL 的高级特性(如 JSONB、PostGIS、递归 CTE 等)。
4. 广泛的工具支持
与 pandas、SQLAlchemy 等 Python 生态工具集成良好。
竞品对比
| 维度 | MySQL(本项目选择) | PostgreSQL |
|---|---|---|
| 上手难度 | 低 | 中高 |
| JSON 支持 | ✅ | ✅✅(JSONB 更强) |
| 并发能力 | 高 | 极高 |
| 高级功能 | 有限 | 丰富(全文搜索、GIS、数组等) |
| 适用场景 | Web 应用、中小型系统 | 复杂查询、分析系统、GIS |
选型决策逻辑:
-
放弃 PostgreSQL:项目不需要其高级功能,MySQL 更轻量、更易维护
-
若后续需要复杂分析查询,可考虑迁移
八、中文分词:Jieba
为什么选择 Jieba?
1. 轻量易用
Jieba 是最流行的 Python 中文分词库,API 简洁,开箱即用:
python
import jieba
seg_list = jieba.cut("中文自然语言处理库")
2. 支持自定义词典
可通过 jieba.load_userdict() 加载领域词典,解决教育术语分词问题。
3. 性能足够
本项目 BM25 索引在系统启动时一次性构建,对分词速度要求不高,Jieba 完全满足。
4. 社区活跃
GitHub 活跃,文档完善,问题解答资源丰富。
竞品对比
| 维度 | Jieba(本项目选择) | HanLP | LTP(哈工大) |
|---|---|---|---|
| 定位 | 轻量分词 | 功能全面 NLP 工具包 | 专业语言技术平台 |
| 功能 | 基础分词 | 词性标注、命名实体、依存句法 | 10+ 种功能 |
| 词典支持 | ✅ 自定义 | ✅ 强大 | ✅ |
| 模型规模 | 小 | 大 | 大 |
| 资源占用 | 低 | 中高 | 高 |
| 适用场景 | 轻量项目、快速集成 | 专业 NLP 任务 | 学术研究、复杂分析 |
选型决策逻辑:
-
放弃 HanLP/LTP:项目只需要基础分词用于 BM25,不需要词性标注、命名实体等高级功能
-
Jieba 在"够用"和"轻量"之间取得最佳平衡
九、文档处理:PyMuPDF
为什么选择 PyMuPDF?
1. 功能全面
PyMuPDF(fitz)不仅支持 PDF 文本提取,还支持:
-
图片提取与处理
-
页面旋转校正
-
PDF 修复与转换
这些能力对处理扫描版 PDF 至关重要。
2. 高性能
C 语言实现,处理速度快,适合批量文档处理。
3. 与 OCR 良好配合
本项目需要识别 PDF 中的图片并 OCR,PyMuPDF 的图片提取接口非常便捷。
竞品对比
| 维度 | PyMuPDF(本项目选择) | pdfplumber | PyPDF2 | OCRmyPDF |
|---|---|---|---|---|
| 速度 | 快 | 中等 | 慢 | 慢 |
| 图片提取 | ✅ 完善 | ✅ | ❌ | ✅ |
| 旋转校正 | ✅ | ❌ | ❌ | ✅ |
| OCR 配合 | 便捷 | 一般 | 不支持 | 专用 OCR |
| 表格提取 | 一般 | ✅ 优秀 | ❌ | ❌ |
| 适用场景 | 通用 PDF 处理 | 表格密集型 | 简单操作 | 扫描版 PDF |
十、选型决策总结
选型矩阵
| 组件 | 选择 | 核心考量 | 竞品 |
|---|---|---|---|
| 向量数据库 | Milvus | 混合检索原生支持 | Qdrant、Faiss |
| Embedding 模型 | BGE-M3 | 同时生成稠密+稀疏向量 | BERT、OpenAI |
| 重排序模型 | BGE-Reranker | Cross-Encoder 高精度 | Cohere |
| Web 框架 | FastAPI | 异步高性能 + WebSocket | Flask、Django |
| 本地 LLM | Ollama | 部署简单 + OpenAI 兼容 | llama.cpp、vLLM |
| 缓存 | Redis | 数据结构丰富 + 持久化 | Memcached |
| 数据库 | MySQL | 轻量易用 + 生态成熟 | PostgreSQL |
| 分词 | Jieba | 轻量够用 + 自定义词典 | HanLP、LTP |
| PDF 处理 | PyMuPDF | 功能全面 + 性能高 | pdfplumber、PyPDF2 |
选型原则
-
够用优先:不引入不需要的复杂度,Jieba 能完成分词就不上 HanLP
-
私有化部署:所有组件均支持本地部署,数据不出域
-
中文友好:优先选择中文支持好的工具(BGE-M3、Jieba)
-
开源优先:避免商业 API 锁定,降低长期成本
-
生态整合:优先选择与 Python/LangChain 生态集成好的工具
成本与性能权衡
| 环节 | 决策 | 理由 |
|---|---|---|
| 检索 | 三级降级(Redis → BM25 → Milvus) | 用低成本检索拦截大部分请求 |
| Embedding | BGE-M3(本地) | 避免 API 调用成本,一次性显卡投入 |
| LLM | Ollama + Qwen2.5-7B(本地) | 避免按 token 付费,隐私保护 |
| 部署 | 单机可运行 | 降低硬件门槛,便于演示和测试 |
闲话不再多说,上代码!!!
项目目录结构
python
"""
d_multi_layer_rag/
│
├── base/ # 基础模块
│ ├── __init__.py
│ ├── config.py # 全局配置管理
│ └── logger.py # 日志系统
│
├── cache/ # 缓存层
│ └── redis_client.py # Redis 客户端封装
│
├── db/ # 数据库层
│ └── mysql_client.py # MySQL 客户端封装
│
├── logs/ # 日志目录
│ └── app.log # 应用日志文件
│
├── mysql_qa/ # MySQL 问答模块
│ ├── __init__.py
│ ├── cache/
│ │ └── redis_client.py # Redis 缓存
│ ├── db/
│ │ └── mysql_client.py # MySQL 操作
│ └── utils/
│ └── preprocess.py # 文本预处理(分词)
│
├── rag_qa/ # RAG 问答核心模块
│ ├── __init__.py
│ │
│ ├── core/ # 核心逻辑
│ │ ├── __init__.py
│ │ ├── document_processor.py # 文档处理 + 父子切分
│ │ ├── new_rag_system.py # RAG 系统(流式+历史)
│ │ ├── prompts.py # Prompt 模板管理
│ │ ├── query_classifier.py # BERT 查询分类器
│ │ ├── rag_system.py # RAG 系统(基础版)
│ │ ├── strategy_selector.py # LLM 策略选择器
│ │ └── vector_store.py # Milvus 向量存储
│ │
│ ├── edu_document_loaders/ # 文档加载器
│ │ ├── __init__.py
│ │ ├── edu_docloader.py # Word 文档加载器
│ │ ├── edu_imgloader.py # 图片加载器
│ │ ├── edu_ocr.py # OCR 识别封装
│ │ ├── edu_pdfloader.py # PDF 加载器
│ │ ├── edu_pptloader.py # PPT 加载器
│ │ └── review.py # 代码审查
│ │
│ └── edu_text_spliter/ # 文本切分器
│ ├── __init__.py
│ ├── edu_chinese_recursive_text_splitter.py # 中文递归切分
│ ├── edu_model_text_spliter.py # 模型语义切分
│ └── review.py
│
├── retrieval/ # 检索模块
│ └── bm25_search.py # BM25 检索实现
│
├── utils/ # 工具模块
│ ├── __init__.py
│ └── preprocess.py # 通用预处理
│
├── data/ # 数据目录
│ ├── ai_data/ # AI 学科数据
│ ├── java_data/ # Java 学科数据
│ ├── test_data/ # 测试学科数据
│ ├── ops_data/ # Ops 学科数据
│ ├── bigdata_data/ # 大数据学科数据
│ └── ocr_samples/ # OCR 测试样本
│
├── models/ # 本地模型目录
│ ├── bert-base-chinese/ # BERT 基础模型
│ ├── bert_query_classifier/ # 微调后的分类模型
│ ├── bge-m3/ # BGE-M3 嵌入模型
│ └── bge-reranker-large/ # BGE 重排序模型
│
├── static/ # 前端静态文件
│ └── index.html # Web 演示页面
│
├── app.py # FastAPI 应用(WebSocket + HTTP)
├── new_main.py # 新版主入口(集成问答系统)
├── old_main.py # 旧版主入口
├── use_api.py # API 调用示例
├── config.ini # 配置文件
├── requirements.txt # 依赖列表
└── review.py # 代码审查
"""
核心文件说明
| 文件 | 作用 | 重要性 |
|---|---|---|
| base/config.py | 统一配置管理(MySQL/Redis/Milvus/LLM) | ⭐⭐⭐⭐⭐ |
| rag_qa/core/vector_store.py | Milvus 向量存储 + 混合检索 + 重排序 | ⭐⭐⭐⭐⭐ |
| rag_qa/core/document_processor.py | 文档加载 + 父子切分 | ⭐⭐⭐⭐⭐ |
| rag_qa/core/new_rag_system.py | RAG 核心流程(分类→策略→检索→生成) | ⭐⭐⭐⭐⭐ |
| mysql_qa/bm25_search.py | BM25 关键词检索 + 降级逻辑 | ⭐⭐⭐⭐ |
| rag_qa/core/query_classifier.py | BERT 查询分类器 | ⭐⭐⭐⭐ |
| rag_qa/core/strategy_selector.py | LLM 策略选择器 | ⭐⭐⭐⭐ |
| app.py | FastAPI 服务(WebSocket 流式) | ⭐⭐⭐⭐ |
| new_main.py | 集成问答系统入口 | ⭐⭐⭐⭐ |
数据流简图
python
"""
data/*.pdf/.docx/.pptx
│
▼
edu_document_loaders/ → Document 对象
│
▼
document_processor.py → 父子切分 → Child Chunks
│
▼
vector_store.py → Milvus 向量库
│
▼
new_main.py / app.py → 用户查询入口
│
├── BM25Search (MySQL + Redis) → 快速命中
│
└── RAGSystem (Milvus + LLM) → 深度检索
│
├── QueryClassifier (BERT) → 分类
├── StrategySelector (LLM) → 选策略
├── VectorStore (Milvus) → 混合检索
└── call_llm_stream (Ollama) → 流式输出
"""
代码展示
config.ini
python
# MySQL 配置
[mysql]
host = localhost
port = 33060
user = root
password = root123
database = multi_layer_rag_mysql_db
# Redis 配置
[redis]
host = localhost
port = 6379
password = 1234
db = 0
# Milvus 配置
[milvus]
host = localhost
port = 19530
database_name = multi_layer_rag_milvus_db
collection_name = multi_layer_rag_vectors
# LLM 配置
[llm]
model = deepseek-ai/DeepSeek-V3
dashscope_api_key = your_api_key_here
dashscope_base_url = https://api.siliconflow.cn/v1
# 检索参数配置
[retrieval]
parent_chunk_size = 1200
child_chunk_size = 300
chunk_overlap = 50
retrieval_k = 5
candidate_m = 2
# 日志配置
[logger]
log_file = logs/app.log
# 应用配置
[app]
valid_sources = ["ai", "java", "test", "ops", "bigdata"]
customer_service_phone = 12345678
base/config.py
python
# -*- coding:utf-8 -*-
# 导入配置ini文件的解析库
import configparser
# 导入路径操作
import os
# 获取当前文件的绝对路径
current_file_path = os.path.abspath(__file__)
# 获取当前文件所在目录的绝对路径
current_dir_path = os.path.dirname(current_file_path)
# 获取项目根目录的绝对路径
project_root = os.path.dirname(current_dir_path)
config_file_path = os.path.join(project_root, 'config.ini')
class Config():
def __init__(self, config_file=config_file_path):
# config_file代表配置文件ini的路径
# 1.创建配置文件解析器
self.config = configparser.ConfigParser()
# 2. 读取配置文件
self.config.read(config_file, encoding='utf-8')
# 3. 获取相关的配置
# 3.1 获取Mysql数据库的配置
self.MYSQL_HOST = self.config.get('mysql', 'host', fallback='localhost')
# MySQL 用户名
self.MYSQL_USER = self.config.get('mysql', 'user', fallback='root')
# MySQL 密码
self.MYSQL_PASSWORD = self.config.get('mysql', 'password', fallback='123456')
# MySQL 数据库名
self.MYSQL_DATABASE = self.config.get('mysql', 'database', fallback='subjects_kg')
# Redis 配置
# Redis 主机地址
self.REDIS_HOST = self.config.get('redis', 'host', fallback='localhost')
# Redis 端口
self.REDIS_PORT = self.config.getint('redis', 'port', fallback=6379)
# Redis 密码
self.REDIS_PASSWORD = self.config.get('redis', 'password', fallback='1234')
# Redis 数据库编号
self.REDIS_DB = self.config.getint('redis', 'db', fallback=0)
# Milvus 配置
# Milvus 主机地址
self.MILVUS_HOST = self.config.get('milvus', 'host', fallback='localhost')
# Milvus 端口
self.MILVUS_PORT = self.config.get('milvus', 'port', fallback='19530')
# Milvus 数据库名
self.MILVUS_DATABASE_NAME = self.config.get('milvus', 'database_name', fallback='itcast')
# Milvus 集合名
self.MILVUS_COLLECTION_NAME = self.config.get('milvus', 'collection_name', fallback='edurag_final')
# LLM 配置
# LLM 模型名
self.LLM_MODEL = self.config.get('llm', 'model')
# DashScope API 密钥
self.DASHSCOPE_API_KEY = self.config.get('llm', 'dashscope_api_key')
# DashScope API 地址
self.DASHSCOPE_BASE_URL = self.config.get('llm', 'dashscope_base_url')
# 检索参数
# 父块大小
self.PARENT_CHUNK_SIZE = self.config.getint('retrieval', 'parent_chunk_size', fallback=1200)
# 子块大小
self.CHILD_CHUNK_SIZE = self.config.getint('retrieval', 'child_chunk_size', fallback=300)
# 块重叠大小
self.CHUNK_OVERLAP = self.config.getint('retrieval', 'chunk_overlap', fallback=50)
# 检索返回数量
self.RETRIEVAL_K = self.config.getint('retrieval', 'retrieval_k', fallback=5)
# 最终候选数量
self.CANDIDATE_M = self.config.getint('retrieval', 'candidate_m', fallback=2)
# 应用配置
self.CUSTOMER_SERVICE_PHONE = self.config.get('app', 'customer_service_phone')
self.VALID_SOURCES = eval(self.config.get('app', 'valid_sources', fallback=["ai", "java", "test", "ops", "bigdata"]))
# 日志文件路径
self.LOG_FILE = self.config.get('logger', 'log_file', fallback='logs/app.log')
if __name__ == '__main__':
conf = Config()
print(conf.CHUNK_OVERLAP)
print(conf.VALID_SOURCES)
print(type(conf.VALID_SOURCES))
base/logger.py
python
# -*- coding:utf-8 -*-
# 导入日志库
import logging
# 导入路径操作库
import os
# 导入配置类
from d_multi_layer_rag.base.config import Config
# 获取当前文件的绝对路径
current_file_path = os.path.abspath(__file__)
# 获取当前文件所在目录的绝对路径
current_dir_path = os.path.dirname(current_file_path)
# 获取项目根目录的绝对路径
project_root = os.path.dirname(current_dir_path)
log_file_path = os.path.join(project_root, Config().LOG_FILE)
def setup_logging(log_file=log_file_path):
# 创建日志目录
os.makedirs(os.path.dirname(log_file), exist_ok=True)
# 获取日志器
logger = logging.getLogger("EduRAG")
# 设置日志级别
logger.setLevel(logging.INFO)
# 避免重复添加处理器
if not logger.handlers:
# 创建文件处理器
file_handler = logging.FileHandler(log_file, encoding='utf-8')
# 设置文件处理器级别
file_handler.setLevel(logging.INFO)
# 创建控制台处理器
console_handler = logging.StreamHandler()
# 设置控制台处理器级别
console_handler.setLevel(logging.INFO)
# 设置日志格式
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# 为文件处理器设置格式
file_handler.setFormatter(formatter)
# 为控制台处理器设置格式
console_handler.setFormatter(formatter)
# 添加文件处理器
logger.addHandler(file_handler)
# 添加控制台处理器
logger.addHandler(console_handler)
# 返回日志器
return logger
# 初始化日志器
logger = setup_logging()
cache/redis_client.py
python
# cache/redis_client.py
# 导入 Redis 客户端
import redis
# 导入 JSON 处理
import json
import os
import sys
# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
module_dir = os.path.dirname(current_dir)
project_root = os.path.dirname(module_dir)
sys.path.insert(0, project_root)
# 导入配置和日志
from d_multi_layer_rag.base import Config, logger
class RedisClient:
def __init__(self):
# 初始化日志
self.logger = logger
try:
# 连接 Redis
self.client = redis.StrictRedis(
host=Config().REDIS_HOST,
port=Config().REDIS_PORT,
password=Config().REDIS_PASSWORD,
db=Config().REDIS_DB,
decode_responses=True
)
# 记录连接成功
self.logger.info("Redis 连接成功")
except redis.RedisError as e:
# 记录连接失败
self.logger.error(f"Redis 连接失败: {e}")
raise
def set_data(self, key, value):
# 存储数据到 Redis
try:
# 存储 JSON 数据
self.client.set(key, json.dumps(value, ensure_ascii=False))
# 记录存储成功
self.logger.info(f"存储数据到 Redis: {key}")
except redis.RedisError as e:
# 记录存储失败
self.logger.error(f"Redis 存储失败: {e}")
def get_data(self, key):
# 从 Redis 获取数据
try:
# 获取数据
data = self.client.get(key)
# 返回解析后的 JSON 数据或 None
return json.loads(data) if data else None
except redis.RedisError as e:
# 记录获取失败
self.logger.error(f"Redis 获取失败: {e}")
# 返回 None
return None
def get_answer(self, query):
# 获取查询的缓存答案
try:
# 从 Redis 获取答案
answer = self.client.get(f"answer:{query}")
if answer:
# 记录获取成功
self.logger.info(f"从 Redis 获取答案: {query}")
# 返回答案
return answer
# 返回 None
return None
except redis.RedisError as e:
# 记录查询失败
self.logger.error(f"Redis 查询失败: {e}")
# 返回 None
return None
if __name__ == '__main__':
redcli = RedisClient()
print(redcli)
redcli.client.delete("qa_tokenized_questions", "qa_original_questions", "answer:VMware安装VMware时显示灰色如何解决")
print(redcli.client.keys("*"))
db/mysql_client.py
python
# -*- coding:utf-8 -*-
import pymysql
# 导入pandas
import pandas as pd
# 导入配置和日志
import sys
import os
# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
module_dir = os.path.dirname(current_dir)
project_root = os.path.dirname(module_dir)
sys.path.insert(0, project_root)
from d_multi_layer_rag.base import Config, logger
class MySQLClient:
def __init__(self):
# 初始化日志
self.logger = logger
try:
# 连接 MySQL 数据库
self.connection = pymysql.connect(
host=Config().MYSQL_HOST,
user=Config().MYSQL_USER,
port=Config().MYSQL_PORT,
password=Config().MYSQL_PASSWORD,
database=Config().MYSQL_DATABASE
)
# 创建游标
self.cursor = self.connection.cursor()
# 记录连接成功
self.logger.info("MySQL 连接成功")
except pymysql.MySQLError as e:
# 记录连接失败
self.logger.error(f"MySQL 连接失败: {e}")
raise
def create_table(self):
create_table_query = '''
CREATE TABLE IF NOT EXISTS jpkb (
id INT AUTO_INCREMENT PRIMARY KEY,
subject_name VARCHAR(20),
question VARCHAR(1000),
answer VARCHAR(1000))
'''
try:
self.cursor.execute(create_table_query)
self.connection.commit()
self.logger.info("表创建成功")
except pymysql.MySQLError as e:
self.logger.error(f"表创建失败: {e}")
raise
def insert_data(self, csv_path):
try:
data = pd.read_csv(csv_path)
print(data.head())
for _, row in data.iterrows():
insert_query = "INSERT INTO jpkb (subject_name, question, answer) VALUES (%s, %s, %s)"
self.cursor.execute(insert_query, (row["学科名称"], row["问题"], row["答案"]))
self.connection.commit()
self.logger.info("Mysql数据插入成功")
except Exception as e:
self.logger.error(f'Mysql数据插入失败:{e}')
self.connection.rollback()
raise
def fetch_questions(self):
# 获取所有问题
try:
# 执行查询
self.cursor.execute("SELECT question FROM jpkb")
# 获取结果
results = self.cursor.fetchall()
# 记录获取成功
self.logger.info("成功获取问题")
# 返回结果
return results
except pymysql.MySQLError as e:
# 记录查询失败
self.logger.error(f"查询失败: {e}")
# 返回空列表
return []
def fetch_answer(self, question):
# 获取指定问题的答案
try:
# 执行查询
self.cursor.execute("SELECT answer FROM jpkb WHERE question=%s", (question,))
# 获取结果
result = self.cursor.fetchone()
# 返回答案或 None
return result[0] if result else None
except pymysql.MySQLError as e:
# 记录答案获取失败
self.logger.error(f"答案获取失败: {e}")
# 返回 None
return None
def close(self):
# 关闭数据库连接
try:
# 关闭连接
self.connection.close()
# 记录关闭成功
self.logger.info("MySQL 连接已关闭")
except pymysql.MySQLError as e:
# 记录关闭失败
self.logger.error(f"关闭连接失败: {e}")
if __name__ == '__main__':
mysql_client = MySQLClient()
a = mysql_client.fetch_answer(question="在磁盘中无法新建文本文档")
print(f'a--》{a}')
mysql_client.close()
mysql_qa/utils/preprocess.py
python
# utils/preprocess.py
# 导入分词库
import jieba
# 导入日志
import os
import sys
# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
module_dir = os.path.dirname(current_dir)
project_root = os.path.dirname(module_dir)
sys.path.insert(0, project_root)
from d_multi_layer_rag.base import logger
def preprocess_text(text):
# 预处理文本
logger.info("开始预处理文本")
try:
# 分词并转换为小写
return jieba.lcut(text.lower())
except AttributeError as e:
# 记录预处理失败
logger.error(f"文本预处理失败: {e}")
# 返回空列表
return []
if __name__ == '__main__':
print(preprocess_text(text="AI程序员"))
retrieval/bm25_search.py
python
# 导入 BM25 算法
from rank_bm25 import BM25Okapi
# 导入数值计算库
import numpy as np
# 导入文本预处理
import sys
import os
# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
module_dir = os.path.dirname(current_dir)
sys.path.insert(0, module_dir)
project_root = os.path.dirname(module_dir)
sys.path.insert(0, project_root)
from d_multi_layer_rag.mysql_qa.utils.preprocess import preprocess_text
from d_multi_layer_rag.mysql_qa.db.mysql_client import MySQLClient
from d_multi_layer_rag.mysql_qa.cache.redis_client import RedisClient
# 导入日志
from d_multi_layer_rag.base import logger
class BM25Search:
def __init__(self, redis_client, mysql_client):
# 初始化日志
self.logger = logger
# 初始化 Redis 客户端
self.redis_client = redis_client
# 初始化 MySQL 客户端
self.mysql_client = mysql_client
# 初始化 BM25 模型
self.bm25 = None
# 初始化问题列表
self.questions = None
# 初始化原始问题
self.original_questions = None
# 加载数据
self._load_data()
def _load_data(self):
# 加载数据
original_key = "qa_original_questions"
tokenized_key = "qa_tokenized_questions"
# 从redis中获取原始问题(快)
self.original_questions = self.redis_client.get_data(original_key)
# 从redis中获取分词后的问题(快)
tokenized_questions = self.redis_client.get_data(tokenized_key)
# 如果 Redis 中没有数据,从 MySQL 加载
if not self.original_questions or not tokenized_questions:
# 从Mysql中获取问题
self.original_questions = self.mysql_client.fetch_questions()
# 如果mysql中未获得问题,那么给出警告
if not self.original_questions:
self.logger.warning("未加载问题")
return
# 对问题进行分词
tokenized_questions = [preprocess_text(q[0]) for q in self.original_questions]
# 把原始的问题存储到redis
self.redis_client.set_data(original_key, [(q[0]) for q in self.original_questions])
# 把分词之后的问题存储到redis
self.redis_client.set_data(tokenized_key, tokenized_questions)
# 设置问题列表
self.questions = tokenized_questions
# 初始化 BM25 模型
self.bm25 = BM25Okapi(self.questions)
# 记录 BM25 初始化成功
self.logger.info("BM25 模型初始化完成")
def _softmax(self, scores):
# 计算softmax分数:但是我们对每个score都减去一个最大值,为了防止数据过大,内存爆炸
exp_scores = np.exp(scores - np.max(scores))
# 返回归一化分数
return exp_scores / exp_scores.sum()
def search(self, query, threshold=0.85):
# 搜索查询
if not query or not isinstance(query, str):
# 记录无效查询
self.logger.error("无效查询")
return None, False
# 检查Redis缓存
cached_answer = self.redis_client.get_answer(query)
if cached_answer:
return cached_answer, False
try:
# 分词
query_tokens = preprocess_text(query)
# 计算BM25的分数
scores = self.bm25.get_scores(query_tokens)
# 进行分数的归一化
softmax_score = self._softmax(scores)
# 获取最高分对应的索引
best_idx = softmax_score.argmax()
# 根据上述的索引获取最高分值
best_score = softmax_score[best_idx]
# 检查分数是否超过阈值
if best_score >= threshold:
# 获取原始的问题
original_question = self.original_questions[best_idx][0]
# 查数据库获得问题对应的答案
answer = self.mysql_client.fetch_answer(original_question)
if answer:
# 缓存qa
self.redis_client.set_data(f'answer:{query}', answer)
# 记录搜索成功
self.logger.info(f'搜索成功,Softamx相似度:{best_score:.3f}')
return answer, False
# 记录无可靠答案
self.logger.info(f"未找到可靠答案,最高 Softmax 相似度: {best_score:.3f}")
# 返回 None 和 True
return None, True
except Exception as e:
self.logger.error(f'搜索查询失败:{e}')
return None, True
if __name__ == "__main__":
redis_client = RedisClient()
mysql_client = MySQLClient()
bm25_search = BM25Search(redis_client, mysql_client)
bm25_search.search(query="VMware安装VMware时显示灰色如何解决")
rag_qa/edu_document_loaders/edu_ocr.py
python
from typing import TYPE_CHECKING
def get_ocr(use_cuda: bool = True):
try:
from rapidocr_paddle import RapidOCR
ocr = RapidOCR(det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda)
except ImportError:
from rapidocr_onnxruntime import RapidOCR
ocr = RapidOCR()
return ocr
rag_qa/edu_document_loaders/edu_pdfloader.py
python
import cv2
import fitz
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Iterator
from d_multi_layer_rag.rag_qa.edu_document_loaders.edu_ocr import get_ocr
from langchain_core.documents import Document
from langchain_core.document_loaders import BaseLoader
PDF_OCR_THRESHOLD = (0.6, 0.6)
class OCRPDFLoader(BaseLoader):
def __init__(self, file_path: str) -> None:
self.file_path = file_path
def lazy_load(self) -> Iterator[Document]:
line = self.pdf2text()
yield Document(page_content=line, metadata={"source": self.file_path})
def pdf2text(self):
ocr = get_ocr()
doc = fitz.open(self.file_path)
resp = ""
b_unit = tqdm(total=doc.page_count, desc="OCRPDFLoader context page index: 0")
for i, page in enumerate(doc):
b_unit.set_description("OCRPDFLoader context page index: {}".format(i))
b_unit.refresh()
text = page.get_text("text")
resp += text + "\n"
img_list = page.get_image_info(xrefs=True)
for img in img_list:
if xref := img.get("xref"):
bbox = img["bbox"]
if ((bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[0]
or (bbox[3] - bbox[1]) / (page.rect.height) < PDF_OCR_THRESHOLD[1]):
continue
pix = fitz.Pixmap(doc, xref)
if int(page.rotation) != 0:
img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
tmp_img = Image.fromarray(img_array)
ori_img = cv2.cvtColor(np.array(tmp_img), cv2.COLOR_RGB2BGR)
rot_img = self.rotate_img(img=ori_img, angle=360 - page.rotation)
img_array = cv2.cvtColor(rot_img, cv2.COLOR_RGB2BGR)
else:
img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
result, _ = ocr(img_array)
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
b_unit.update(1)
return resp
def rotate_img(self, img, angle):
h, w = img.shape[:2]
rotate_center = (w / 2, h / 2)
M = cv2.getRotationMatrix2D(rotate_center, angle, 1.0)
new_w = int(h * np.abs(M[0, 1]) + w * np.abs(M[0, 0]))
new_h = int(h * np.abs(M[0, 0]) + w * np.abs(M[0, 1]))
M[0, 2] += (new_w - w) / 2
M[1, 2] += (new_h - h) / 2
rotated_img = cv2.warpAffine(img, M, (new_w, new_h))
return rotated_img
if __name__ == '__main__':
pdf_loader = OCRPDFLoader(file_path="../../../data/ocr_samples/ocr_03.pdf")
doc = pdf_loader.load()
print(type(doc))
print(doc)
rag_qa/edu_document_loaders/edu_docloader.py
python
from typing import Iterator
from d_multi_layer_rag.rag_qa.edu_document_loaders.edu_ocr import get_ocr
from tqdm import tqdm
from docx.table import _Cell, Table
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P
from docx.text.paragraph import Paragraph
from docx import Document as Docu1
from docx.document import Document as Docu2
from docx import ImagePart
from PIL import Image
from io import BytesIO
import numpy as np
from langchain_core.documents import Document
from langchain_core.document_loaders import BaseLoader
class OCRDOCLoader(BaseLoader):
def __init__(self, filepath: str) -> None:
self.filepath = filepath
def lazy_load(self) -> Iterator[Document]:
line = self.doc2text(self.filepath)
yield Document(page_content=line, metadata={"source": self.filepath})
def doc2text(self, filepath):
ocr = get_ocr()
doc = Docu1(filepath)
resp = ""
def iter_block_items(parent):
if isinstance(parent, Docu2):
parent_elm = parent.element.body
elif isinstance(parent, _Cell):
parent_elm = parent._tc
else:
raise ValueError("OCRDOCLoader parse fail")
for child in parent_elm.iterchildren():
if isinstance(child, CT_P):
yield Paragraph(child, parent)
elif isinstance(child, CT_Tbl):
yield Table(child, parent)
b_unit = tqdm(total=len(doc.paragraphs) + len(doc.tables),
desc="OCRDOCLoader block index: 0")
for i, block in enumerate(iter_block_items(doc)):
b_unit.set_description("OCRDOCLoader block index: {}".format(i))
b_unit.refresh()
if isinstance(block, Paragraph):
resp += block.text.strip() + "\n"
images = block._element.xpath('.//pic:pic')
for image in images:
for img_id in image.xpath('.//a:blip/@r:embed'):
part = doc.part.related_parts[img_id]
if isinstance(part, ImagePart):
image = Image.open(BytesIO(part._blob))
result, _ = ocr(np.array(image))
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
elif isinstance(block, Table):
for row in block.rows:
for cell in row.cells:
for paragraph in cell.paragraphs:
resp += paragraph.text.strip() + "\n"
b_unit.update(1)
return resp
if __name__ == '__main__':
docx_loader = OCRDOCLoader(filepath='../../../data/ocr_samples/ocr_02.docx')
doc = docx_loader.load()
print(doc)
rag_qa/edu_document_loaders/edu_pptloader.py
python
from typing import Iterator
from d_multi_layer_rag.rag_qa.edu_document_loaders.edu_ocr import get_ocr
from langchain_core.documents import Document
from langchain_core.document_loaders import BaseLoader
from pptx import Presentation
from PIL import Image
import numpy as np
from io import BytesIO
from tqdm import tqdm
class OCRPPTLoader(BaseLoader):
def __init__(self, filepath: str) -> None:
self.filepath = filepath
def lazy_load(self) -> Iterator[Document]:
line = self.ppt2text(self.filepath)
yield Document(page_content=line, metadata={"source": self.filepath})
def ppt2text(self, filepath):
prs = Presentation(filepath)
ocr = get_ocr()
resp = ""
def extract_text(shape):
nonlocal resp
if shape.has_text_frame:
resp += shape.text.strip() + "\n"
if shape.has_table:
for row in shape.table.rows:
for cell in row.cells:
for paragraph in cell.text_frame.paragraphs:
resp += paragraph.text.strip() + "\n"
if shape.shape_type == 13:
image = Image.open(BytesIO(shape.image.blob))
result, _ = ocr(np.array(image))
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
elif shape.shape_type == 6:
for child_shape in shape.shapes:
extract_text(child_shape)
b_unit = tqdm(total=len(prs.slides), desc="OCRPPTLoader slide index: 1")
for slide_number, slide in enumerate(prs.slides, start=1):
b_unit.set_description("OCRPPTLoader slide index: {}".format(slide_number))
b_unit.refresh()
sorted_shapes = sorted(slide.shapes, key=lambda x: (x.top, x.left))
for shape in sorted_shapes:
extract_text(shape)
b_unit.update(1)
return resp
if __name__ == '__main__':
img_loader = OCRPPTLoader(filepath='../../../data/ocr_samples/ocr_01.pptx')
doc = img_loader.load()
print(doc)
rag_qa/edu_document_loaders/edu_imgloader.py
python
from typing import Iterator
from d_multi_layer_rag.rag_qa.edu_document_loaders.edu_ocr import get_ocr
from langchain_core.documents import Document
from langchain_core.document_loaders import BaseLoader
class OCRIMGLoader(BaseLoader):
def __init__(self, img_path: str) -> None:
self.img_path = img_path
def lazy_load(self) -> Iterator[Document]:
line = self.img2text()
yield Document(page_content=line, metadata={"source": self.img_path})
def img2text(self):
resp = ""
ocr = get_ocr()
result, _ = ocr(self.img_path)
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
return resp
if __name__ == '__main__':
img_loader = OCRIMGLoader(img_path='../../../data/ocr_samples/ocr_04.png')
doc = img_loader.load()
print(doc)
rag_qa/edu_text_spliter/edu_chinese_recursive_text_splitter.py
python
import re
from typing import List, Optional, Any
from langchain.text_splitter import RecursiveCharacterTextSplitter
import logging
logger = logging.getLogger(__name__)
def _split_text_with_regex_from_end(
text: str, separator: str, keep_separator: bool
) -> List[str]:
if separator:
if keep_separator:
_splits = re.split(f"({separator})", text)
splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
if len(_splits) % 2 == 1:
splits += _splits[-1:]
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]
class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
def __init__(
self,
separators: Optional[List[str]] = None,
keep_separator: bool = True,
is_separator_regex: bool = True,
**kwargs: Any,
) -> None:
super().__init__(keep_separator=keep_separator, **kwargs)
self._separators = separators or [
"\n\n",
"\n",
"。|!|?",
"\.\s|\!\s|\?\s",
";|;\s",
",|,\s"
]
self._is_separator_regex = is_separator_regex
def _split_text(self, text: str, separators: List[str]) -> List[str]:
final_chunks = []
separator = separators[-1]
new_separators = []
for i, _s in enumerate(separators):
_separator = _s if self._is_separator_regex else re.escape(_s)
if _s == "":
separator = _s
break
if re.search(_separator, text):
separator = _s
new_separators = separators[i + 1:]
break
_separator = separator if self._is_separator_regex else re.escape(separator)
splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)
_good_splits = []
_separator = "" if self._keep_separator else separator
for s in splits:
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
_good_splits = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
return [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip() != ""]
if __name__ == "__main__":
text_splitter = ChineseRecursiveTextSplitter(
keep_separator=True,
is_separator_regex=True,
chunk_size=150,
chunk_overlap=10
)
ls = [
"""中国对外贸易形势报告(75页)。前 10 个月,一般贸易进出口 19.5 万亿元,增长 25.1%, 比整体进出口增速高出 2.9 个百分点,占进出口总额的 61.7%,较去年同期提升 1.6 个百分点。其中,一般贸易出口 10.6 万亿元,增长 25.3%,占出口总额的 60.9%,提升 1.5 个百分点;进口8.9万亿元,增长24.9%,占进口总额的 62.7%, 提升 1.8 个百分点。加工贸易进出口 6.8 万亿元,增长 11.8%, 占进出口总额的 21.5%,减少 2.0 个百分点。其中,出口增 长 10.4%,占出口总额的 24.3%,减少 2.6 个百分点;进口增 长 14.2%,占进口总额的 18.0%,减少 1.2 个百分点。此外, 以保税物流方式进出口 3.96 万亿元,增长 27.9%。其中,出 口 1.47 万亿元,增长 38.9%;进口 2.49 万亿元,增长 22.2%。前三季度,中国服务贸易继续保持快速增长态势。服务 进出口总额 37834.3 亿元,增长 11.6%;其中服务出口 17820.9 亿元,增长 27.3%;进口 20013.4 亿元,增长 0.5%,进口增 速实现了疫情以来的首次转正。服务出口增幅大于进口 26.8 个百分点,带动服务贸易逆差下降 62.9%至 2192.5 亿元。服 务贸易结构持续优化,知识密集型服务进出口 16917.7 亿元, 增长 13.3%,占服务进出口总额的比重达到 44.7%,提升 0.7 个百分点。 二、中国对外贸易发展环境分析和展望 全球疫情起伏反复,经济复苏分化加剧,大宗商品价格 上涨、能源紧缺、运力紧张及发达经济体政策调整外溢等风 险交织叠加。同时也要看到,我国经济长期向好的趋势没有 改变,外贸企业韧性和活力不断增强,新业态新模式加快发 展,创新转型步伐提速。产业链供应链面临挑战。美欧等加快出台制造业回迁计 划,加速产业链供应链本土布局,跨国公司调整产业链供应 链,全球双链面临新一轮重构,区域化、近岸化、本土化、 短链化趋势凸显。疫苗供应不足,制造业"缺芯"、物流受限、 运价高企,全球产业链供应链面临压力。 全球通胀持续高位运行。能源价格上涨加大主要经济体 的通胀压力,增加全球经济复苏的不确定性。世界银行今年 10 月发布《大宗商品市场展望》指出,能源价格在 2021 年 大涨逾 80%,并且仍将在 2022 年小幅上涨。IMF 指出,全 球通胀上行风险加剧,通胀前景存在巨大不确定性。""",
]
for inum, text in enumerate(ls):
print(inum)
chunks = text_splitter.split_text(text)
for chunk in chunks:
print(chunk)
print('*' * 80)
rag_qa/edu_text_spliter/edu_model_text_spliter.py
python
from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
from modelscope.pipelines import pipeline
class AliTextSplitter(CharacterTextSplitter):
def __init__(self, pdf: bool = False, **kwargs):
super().__init__(**kwargs)
self.pdf = pdf
def split_text(self, text: str) -> List[str]:
if self.pdf:
text = re.sub(r"\n{3,}", r"\n", text)
text = re.sub('\s', " ", text)
text = re.sub("\n\n", "", text)
p = pipeline(
task="document-segmentation",
model='../models/nlp_bert_document-segmentation_chinese-base',
device="cpu")
result = p(documents=text)
sent_list = [i for i in result["text"].split("\n\t") if i]
return sent_list
if __name__ == '__main__':
model_split = AliTextSplitter()
result = model_split.split_text(text='移动端语音唤醒模型,检测关键词为"小云小云"。模型主体为4层FSMN结构,使用CTC训练准则,参数量750K,适用于移动端设备运行。模型输入为Fbank特征,输出为基于char建模的中文全集token预测,测试工具根据每一帧的预测数据进行后处理得到输入音频的实时检测结果。模型训练采用"basetrain + finetune"的模式,basetrain过程使用大量内部移动端数据,在此基础上,使用1万条设备端录制安静场景"小云小云"数据进行微调,得到最终面向业务的模型。后续用户可在basetrain模型基础上,使用其他关键词数据进行微调,得到新的语音唤醒模型,但暂时未开放模型finetune功能。')
print(result)
rag_qa/core/document_processor.py
python
# 这个脚本讲义的代码架构图没有体现,需要进行补充
import os
from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders.markdown import UnstructuredMarkdownLoader
from langchain.text_splitter import MarkdownTextSplitter
from datetime import datetime
import sys
# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
# 获取core文件所在的目录的绝对路径
rag_qa_path = os.path.dirname(current_dir)
sys.path.insert(0, rag_qa_path)
# 获取根目录文件所在的绝对位置
project_root = os.path.dirname(rag_qa_path)
sys.path.insert(0, project_root)
from d_multi_layer_rag.rag_qa.edu_document_loaders import OCRPDFLoader, OCRDOCLoader, OCRPPTLoader, OCRIMGLoader
from d_multi_layer_rag.rag_qa.edu_text_spliter import ChineseRecursiveTextSplitter
from d_multi_layer_rag.base import logger, Config
conf = Config()
# 定义支持的文件类型及其对应的加载器字典
document_loaders = {
# 文本文件使用 TextLoader
".txt": TextLoader,
# PDF 文件使用 OCRPDFLoader
".pdf": OCRPDFLoader,
# Word 文件使用 OCRDOCLoader
".docx": OCRDOCLoader,
# PPT 文件使用 OCRPPTLoader
".ppt": OCRPPTLoader,
# PPTX 文件使用 OCRPPTLoader
".pptx": OCRPPTLoader,
# JPG 文件使用 OCRIMGLoader
".jpg": OCRIMGLoader,
# PNG 文件使用 OCRIMGLoader
".png": OCRIMGLoader,
# Markdown 文件使用 UnstructuredMarkdownLoader
".md": UnstructuredMarkdownLoader
}
# 定义函数,从指定文件夹加载多种类型文件并添加元数据
def load_documents_from_directory(directory_path):
# 初始化空列表,用于存储加载的文档
documents = []
# 获取支持的文件扩展名集合
supported_extensions = document_loaders.keys()
# 从目录名提取学科类别(如 "ai_data" -> "ai")
source = os.path.basename(directory_path).replace("_data", "")
# 遍历指定目录及其子目录
for root, _, files in os.walk(directory_path):
# 遍历当前目录下的所有文件
for file in files:
# 构造文件的完整路径
file_path = os.path.join(root, file)
# 获取文件扩展名并转换为小写
file_extension = os.path.splitext(file_path)[1].lower()
# 检查文件类型是否在支持的扩展名列表中
if file_extension in supported_extensions:
# 使用 try-except 捕获加载过程中的异常
try:
# 根据文件扩展名获取对应的加载器类
loader_class = document_loaders[file_extension]
# 实例化加载器对象,传入文件路径
if file_extension == ".txt":
loader = loader_class(file_path, encoding="utf-8")
else:
loader = loader_class(file_path)
# 调用加载器加载文档内容,返回文档列表
loaded_docs = loader.load()
for doc in loaded_docs:
# 为文档添加学科类别元数据
doc.metadata["source"] = source
# 为文档添加文件路径元数据
doc.metadata["file_path"] = file_path
# 为文档添加当前时间戳元数据
doc.metadata["timestamp"] = datetime.now().isoformat()
documents.extend(loaded_docs)
# 记录成功加载文件的日志
logger.info(f"成功加载文件: {file_path}")
except Exception as e:
logger.error(f"加载文件 {file_path} 失败: {str(e)}")
# 如果文件类型不在支持列表中
else:
# 记录警告日志,提示不支持的文件类型
logger.warning(f"不支持的文件类型: {file_path}")
# 返回加载的所有文档列表
return documents
# 定义函数,处理文档并进行分层切分,返回子块结果
def process_documents(directory_path, parent_chunk_size=conf.PARENT_CHUNK_SIZE,
child_chunk_size=conf.CHILD_CHUNK_SIZE,
chunk_overlap=conf.CHUNK_OVERLAP):
# 从指定目录加载所有文档
documents = load_documents_from_directory(directory_path)
# 记录加载的文档总数日志
logger.info(f"加载的文档数量: {len(documents)}")
# 初始化父块和子块分词器(通用)
parent_splitter = ChineseRecursiveTextSplitter(chunk_size=parent_chunk_size, chunk_overlap=chunk_overlap)
child_splitter = ChineseRecursiveTextSplitter(chunk_size=child_chunk_size, chunk_overlap=chunk_overlap)
# 初始化 Markdown 专用分词器
markdown_parent_splitter = MarkdownTextSplitter(chunk_size=parent_chunk_size, chunk_overlap=chunk_overlap)
markdown_child_splitter = MarkdownTextSplitter(chunk_size=child_chunk_size, chunk_overlap=chunk_overlap)
# 初始化空列表,用于存储所有子块
child_chunks = []
# 遍历每个原始文档,带上索引 i
for i, doc in enumerate(documents):
file_extension = os.path.splitext(doc.metadata.get("file_path", ''))[1].lower()
# 选择分词器
is_markdown = (file_extension == '.md')
parent_splitter_to_use = markdown_parent_splitter if is_markdown else parent_splitter
child_splitter_to_use = markdown_child_splitter if is_markdown else child_splitter
logger.info(f"处理文档: {doc.metadata['file_path']}, 使用切分器: {'Markdown' if is_markdown else 'ChineseRecursive'}")
# 使用父块切分器将文档切分为父块
parent_docs = parent_splitter_to_use.split_documents([doc])
# 遍历每个父块,带上索引 j
for j, parent_doc in enumerate(parent_docs):
# 为父块生成唯一 ID,格式为 "doc_i_parent_j"
parent_id = f"doc_{i}_parent_{j}"
# 使用子块分词器将父块切分为子块
sub_chunks = child_splitter_to_use.split_documents([parent_doc])
# 遍历每个子块,为子块主要添加对应的父块文档
for k, sub_chunk in enumerate(sub_chunks):
# 为子块添加父块的ID
sub_chunk.metadata["parent_id"] = parent_id
# 为子块添加对应的父块文档(元数据)
sub_chunk.metadata["parent_content"] = parent_doc.page_content
# 为子块生成一个唯一的ID,格式为"parent_id_child_k"
sub_chunk.metadata["id"] = f"{parent_id}_child_{k}"
# 将子块添加到子块列表中
child_chunks.append(sub_chunk)
# 记录子块总数日志
logger.info(f"子块数量: {len(child_chunks)}")
# 返回所有子块列表
return child_chunks
if __name__ == '__main__':
directory_path = '/path/to/your/data/ai_data'
child_chunks = process_documents(directory_path)
print(f'child_chunks--》{child_chunks[0]}')
rag_qa/core/vector_store.py
python
# -*- coding:utf-8 -*-
# 导入 BGE-M3 嵌入函数,用于生成文档和查询的向量表示
import torch.cuda
from milvus_model.hybrid import BGEM3EmbeddingFunction
# 导入 Milvus 相关类,用于操作向量数据库
from pymilvus import MilvusClient, DataType, AnnSearchRequest, WeightedRanker
# 导入 Document 类,用于创建文档对象
from langchain.docstore.document import Document
# 导入 CrossEncoder,用于重排序和 NLI 判断
from sentence_transformers import CrossEncoder
# 导入 hashlib 模块,用于生成唯一 ID 的哈希值
import hashlib
import sys
import os
# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
# 获取core文件所在的目录的绝对路径
rag_qa_path = os.path.dirname(current_dir)
core_path = os.path.join(rag_qa_path, 'core')
sys.path.insert(0, core_path)
sys.path.insert(0, rag_qa_path)
# 获取根目录文件所在的绝对位置
project_root = os.path.dirname(rag_qa_path)
sys.path.insert(0, project_root)
from document_processor import *
from d_multi_layer_rag.base import logger, Config
conf = Config()
# core/vector_store.py
# 定义 VectorStore 类,封装向量存储和检索功能
class VectorStore:
# 初始化方法,设置向量存储的基本参数
def __init__(self,
collection_name=conf.MILVUS_COLLECTION_NAME,
host=conf.MILVUS_HOST,
port=conf.MILVUS_PORT,
database=conf.MILVUS_DATABASE_NAME):
# 设置 Milvus 集合名称
self.collection_name = collection_name
# 设置 Milvus 主机地址
self.host = host
# 设置 Milvus 端口号
self.port = port
# 设置 Milvus 数据库名称
self.database = database
# 设置日志记录器
self.logger = logger
# 检查CUDA是否可用
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 日志提醒使用的是什么设备
self.logger.info(f"使用设置:{self.device}")
# 初始化 BGE-Reranker 模型,用于重排序检索结果
reranker_path = os.path.join("../", 'models', 'bge-reranker-large')
self.reranker = CrossEncoder(reranker_path, device=self.device)
# 初始化 BGE-M3 嵌入函数,使用 CPU 设备,不启用 FP16
m3_path = os.path.join("../", 'models', 'bge-m3')
self.embedding_function = BGEM3EmbeddingFunction(model_name_or_path=m3_path, use_fp16=(self.device == 'cuda'), device=self.device)
# 获取稠密向量的维度# 1024
self.dense_dim = self.embedding_function.dim["dense"]
# 初始化 Milvus 客户端,连接到指定主机和数据库
self.client = MilvusClient(uri=f"http://{self.host}:{self.port}", db_name=self.database)
# 调用方法创建或加载 Milvus 集合
self._create_or_load_collection()
# 类私有化方法
def _create_or_load_collection(self):
# 检查指定集合是否已经存在
if not self.client.has_collection(self.collection_name):
# 创建集合 Schema,禁用自动 ID,启用动态字段
schema = self.client.create_schema(auto_id=False, enable_dynamic_field=True)
# 添加 ID 字段,作为主键,VARCHAR 类型,最大长度 100
schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
# 添加文本字段,VARCHAR 类型,最大长度 65535
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
# 添加稠密向量字段,FLOAT_VECTOR 类型,维度由嵌入函数指定
schema.add_field(field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=self.dense_dim)
# 添加稀疏向量字段,SPARSE_FLOAT_VECTOR 类型
schema.add_field(field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR)
# 添加父块 ID 字段,VARCHAR 类型,最大长度 100
schema.add_field(field_name="parent_id", datatype=DataType.VARCHAR, max_length=100)
# 添加父块内容字段,VARCHAR 类型,最大长度 65535
schema.add_field(field_name="parent_content", datatype=DataType.VARCHAR, max_length=65535)
# 添加学科类别字段,VARCHAR 类型,最大长度 50
schema.add_field(field_name="source", datatype=DataType.VARCHAR, max_length=50)
# 添加时间戳字段,VARCHAR 类型,最大长度 50
schema.add_field(field_name="timestamp", datatype=DataType.VARCHAR, max_length=50)
# 创建索引参数对象
index_params = self.client.prepare_index_params()
# 为稠密向量字段添加 IVF_FLAT 索引,度量类型为内积 (IP)
index_params.add_index(
field_name="dense_vector",
index_name="dense_index",
index_type="IVF_FLAT",
metric_type="IP",
params={"nlist": 128}
)
# 为稀疏向量字段添加 SPARSE_INVERTED_INDEX 索引,度量类型为内积 (IP)
index_params.add_index(
field_name="sparse_vector",
index_name="sparse_index",
index_type="SPARSE_INVERTED_INDEX",
metric_type="IP",
params={"drop_ratio_build": 0.2}
)
# 创建 Milvus 集合,应用定义的 Schema 和索引参数
self.client.create_collection(collection_name=self.collection_name, schema=schema,
index_params=index_params)
# 记录创建集合的日志
logger.info(f"已创建集合 {self.collection_name}")
# 如果集合已存在
else:
# 记录加载集合的日志
logger.info(f"已加载集合 {self.collection_name}")
# 将集合加载到内存,确保可立即查询
self.client.load_collection(self.collection_name)
# 定义方法,向向量存储添加文档
def add_documents(self, documents):
# 提取所有文档的内容列表
texts = [doc.page_content for doc in documents]
# 使用 BGE-M3 嵌入函数生成文档的嵌入
embeddings = self.embedding_function(texts)
# 初始化空列表,存储插入的数据
data = []
# 遍历每个文档,带上索引i
for i, doc in enumerate(documents):
# 生成文档内容的哈希值作为唯一的ID
text_hash = hashlib.md5(doc.page_content.encode('utf-8')).hexdigest()
# 初始化一个稀疏向量的字典(Milvus要求存储稀疏向量的格式)
sparse_vector = {}
# 获取第i行对应的稀疏向量数据
row = embeddings["sparse"].getrow(i)
# 获取稀疏向量的非零值的索引
indices = row.indices
# 获取稀疏向量的非零值
values = row.data
# 将索引和值进行配对,存储到字典中
for idx, value in zip(indices, values):
sparse_vector[idx] = value
# 创建数据字典,包含所有字段
data.append({
"id": text_hash,
"text": doc.page_content,
"dense_vector": embeddings["dense"][i],
"sparse_vector": sparse_vector,
"parent_id": doc.metadata["parent_id"],
"parent_content": doc.metadata["parent_content"],
"source": doc.metadata.get("source", "unknown"),
"timestamp": doc.metadata.get("timestamp", "unknown")
})
# 检查是否有数据需要插入
if data:
# 使用 upsert 操作插入数据,覆盖重复 ID
self.client.upsert(collection_name=self.collection_name, data=data)
# 记录插入或更新的文档数量日志
logger.info(f"已插入或更新 {len(data)} 个文档")
# 定义方法,执行混合检索并重排序
def hybrid_search_with_rerank(self, query, k=conf.RETRIEVAL_K, source_filter=None):
# 使用 BGE-M3 嵌入函数生成查询的嵌入
query_embeddings = self.embedding_function([query])
# 获取查询的稠密向量
dense_query_vector = query_embeddings["dense"][0]
# 初始化查询的稀疏向量字典
sparse_query_vector = {}
# 获取查询稀疏向量的第 0 行数据
row = query_embeddings["sparse"].getrow(0)
# 获取稀疏向量的非零值索引
indices = row.indices
# 获取稀疏向量的非零值
values = row.data
# 将索引和值配对,填充稀疏向量字典
for idx, value in zip(indices, values):
sparse_query_vector[idx] = value
# 初始化过滤表达式,默认不过滤
filter_expr = f"source == '{source_filter}'" if source_filter else ""
# 创建稠密向量搜索请求
dense_request = AnnSearchRequest(
data=[dense_query_vector],
anns_field="dense_vector",
param={"metric_type": "IP", "params": {"nprobe": 10}},
limit=k,
expr=filter_expr
)
# 创建稀疏向量搜索请求
sparse_request = AnnSearchRequest(
data=[sparse_query_vector],
anns_field="sparse_vector",
param={"metric_type": "IP", "params": {}},
limit=k,
expr=filter_expr
)
# 创建加权排序器,稀疏向量权重 0.7,稠密向量权重 1.0
ranker = WeightedRanker(1.0, 0.7)
# 执行混合搜索,返回 Top-K 结果
results = self.client.hybrid_search(
collection_name=self.collection_name,
reqs=[dense_request, sparse_request],
ranker=ranker,
limit=k,
output_fields=["text", "parent_id", "parent_content", "source", "timestamp"]
)[0]
# 将上述搜索到的结果进行Document对象封装,便于查询使用
sub_chunks = [self._doc_from_hit(hit["entity"]) for hit in results]
# 从子块中提取去重的父文档
parent_docs = self._get_unique_parent_docs(sub_chunks)
# 如果只有1个文档或者没有,直接返回跳过重排序
if len(parent_docs) < 2:
return parent_docs[:conf.CANDIDATE_M]
# 如果有父文档,进行重排序
if parent_docs:
# 创建查询与文档内容的配对列表
pairs = [[query, doc.page_content] for doc in parent_docs]
# 使用 BGE-Reranker 计算每个配对的得分
scores = self.reranker.predict(pairs)
# 根据得分从高到低排序文档
ranked_parent_docs = [doc for _, doc in sorted(zip(scores, parent_docs), reverse=True)]
else:
ranked_parent_docs = []
# 返回前 m 个重排序后的文档
return ranked_parent_docs[:conf.CANDIDATE_M]
def _get_unique_parent_docs(self, sub_chunks):
# 初始化集合,用于存储已处理的父块内容(去重)
parent_contents = set()
# 初始化列表,用于存储唯一父文档
unique_docs = []
# 遍历所有子块
for chunk in sub_chunks:
# 获取子块的父块内容,默认为子块内容
parent_content = chunk.metadata.get("parent_content", chunk.page_content)
# 检查父块内容是否非空且未重复
if parent_content and parent_content not in parent_contents:
# 创建新的 Document 对象,包含父块内容和元数据
unique_docs.append(Document(page_content=parent_content, metadata=chunk.metadata))
# 将父块内容添加到去重集合
parent_contents.add(parent_content)
return unique_docs
# 定义类似私有方法,从 Milvus 查询结果创建 Document 对象
def _doc_from_hit(self, hit):
# 创建并返回 Document 对象,填充内容和元数据
return Document(
page_content=hit.get("text"),
metadata={
"parent_id": hit.get("parent_id"),
"parent_content": hit.get("parent_content"),
"source": hit.get("source"),
"timestamp": hit.get("timestamp")
}
)
if __name__ == "__main__":
vector_store = VectorStore()
query = "AI学科的课程内容是什么"
results = vector_store.hybrid_search_with_rerank(query, source_filter='ai')
print(f'results-->{results}')
print(f'results-->{len(results)}')
rag_qa/core/prompts.py
python
# core/prompts.py
# 导入 PromptTemplate 类,用于创建 Prompt 模板
from langchain.prompts import PromptTemplate
# 定义 RAGPrompts 类,用于管理所有 Prompt 模板
class RAGPrompts:
@staticmethod
def rag_prompt():
'''添加了历史记录,注意用在:new_rag_system'''
return PromptTemplate(
template="""
你是一个智能助手,负责帮助用户回答问题。请按照以下步骤处理:
1. **分析问题和上下文**:
- 基于提供的上下文(如果有)和你的知识回答问题。
- 如果答案来源于检索到的文档,请在回答中明确说明,例如:"根据提供的文档,......"。
2. **评估对话历史**:
- 检查对话历史是否与当前问题相关(例如,是否涉及相同的话题、实体或问题背景)。
- 如果对话历史与问题相关,请结合历史信息生成更准确的回答。
- 如果对话历史无关(例如,仅包含问候或不相关的内容),忽略历史,仅基于上下文和问题回答。
3. **生成回答**:
- 提供清晰、准确的回答,避免无关信息。
- 如果上下文和历史消息均不足以回答问题,请回复:"信息不足,无法回答,请联系人工客服,电话:{phone}。"
**对话历史**:
{history}
**上下文**:
{context}
**问题**:
{question}
**回答**:
""",
input_variables=["context", "history", "question", "phone"],
)
@staticmethod
def hyde_prompt():
return PromptTemplate(
template="""
假设你是用户,想了解以下问题,请生成一个简短的假设答案:
问题: {query}
假设答案:
""",
input_variables=["query"],
)
@staticmethod
def subquery_prompt():
return PromptTemplate(
template="""
将以下复杂查询分解为多个简单子查询,每行一个子查询,最多生成两个子查询(只保留子查询问题,其他的文本都不需要):
eg:
用户原始query:"Milvus 和 Zilliz Cloud 在功能上有什么不同?
子查询:"Milvus 有哪些功能?","Zilliz Cloud 有哪些功能?"
查询: {query}
子查询:
""",
input_variables=["query"],
)
@staticmethod
def backtracking_prompt():
return PromptTemplate(
template="""
将以下复杂查询简化为一个更简单的问题:
查询: {query}
简化问题:
""",
input_variables=["query"],
)
if __name__ == '__main__':
hyde = RAGPrompts.subquery_prompt()
result = hyde.format(query="AI和JAVA有什么区别")
print(result)
rag_qa/core/query_classifier.py
python
# -*- coding:utf-8-*-
# 导入标准库
import json
import os
# 导入 PyTorch
import torch
import sys
# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
# 获取core文件所在的目录的绝对路径
rag_qa_path = os.path.dirname(current_dir)
# 获取根目录文件所在的绝对位置
project_root = os.path.dirname(rag_qa_path)
sys.path.insert(0, project_root)
# 导入日志
from d_multi_layer_rag.base import logger
# 导入numpy
import numpy as np
# 导入 Transformers 库
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import Trainer, TrainingArguments
# 导入train_test_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
class QueryClassifier:
def __init__(self, model_path="../models/bert_query_classifier"):
# 初始化模型路径
self.model_path = model_path
# 加载 BERT 分词器
bert_path = os.path.join("../", 'models', 'bert-base-chinese')
self.tokenizer = BertTokenizer.from_pretrained(bert_path)
# 初始化模型
self.model = None
# 确定设备(GPU 或 CPU)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 记录设备信息
logger.info(f"使用设备: {self.device}")
# 定义标签映射
self.label_map = {"通用知识": 0, "专业咨询": 1}
# 加载模型
self.load_model()
def load_model(self):
# 检查模型路径是否存在
if os.path.exists(self.model_path):
# 加载预训练模型
self.model = BertForSequenceClassification.from_pretrained(self.model_path)
# 将模型移到指定设备
self.model.to(self.device)
# 记录加载成功的日志
logger.info(f"加载模型: {self.model_path}")
else:
# 初始化新模型
self.model = BertForSequenceClassification.from_pretrained("../models/bert-base-chinese", num_labels=2)
# 将模型移到指定设备
self.model.to(self.device)
# 记录初始化模型的日志
logger.info("初始化新 BERT 模型")
def save_model(self):
"""保存模型"""
self.model.save_pretrained("../models/bert_query_classifier_new")
self.tokenizer.save_pretrained("../models/bert_query_classifier_new")
logger.info(f"模型保存至: ../models/bert_query_classifier_new")
def train_model(self, data_file="training_dataset_hybrid_5000.json"):
"""训练 BERT 分类模型"""
# 加载数据集
if not os.path.exists(data_file):
logger.error(f"数据集文件 {data_file} 不存在")
raise FileNotFoundError(f"数据集文件 {data_file} 不存在")
with open(data_file, "r", encoding="utf-8") as f:
data = [json.loads(value) for value in f.readlines()]
texts = [item["query"] for item in data]
labels = [item["label"] for item in data]
# 数据划分
train_texts, val_texts, train_labels, val_labels = train_test_split(
texts, labels, test_size=0.2, random_state=42
)
# 数据预处理
train_encodings, train_labels = self.preprocess_data(train_texts, train_labels)
val_encodings, val_labels = self.preprocess_data(val_texts, val_labels)
# 得到dataset对象
train_dataset = self.create_dataset(train_encodings, train_labels)
val_dataset = self.create_dataset(val_encodings, val_labels)
# 设置训练参数
training_args = TrainingArguments(
output_dir="./bert_results",
num_train_epochs=3,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=20,
weight_decay=0.01,
logging_dir="./bert_logs",
logging_steps=10,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
save_total_limit=1,
metric_for_best_model="eval_loss",
fp16=False,
)
# 初始化Trainer
trainer = Trainer(model=self.model, args=training_args,
train_dataset=train_dataset, eval_dataset=val_dataset,
compute_metrics=self.compute_metrics)
# 开始训练模型
logger.info("开始训练BERT模型")
trainer.train()
self.save_model()
# 对验证集进行训练好的模型验证
self.evaluate_model(val_texts, val_labels)
def preprocess_data(self, texts, labels):
"""预处理数据为 BERT 输入格式"""
encodings = self.tokenizer(
texts,
truncation=True,
padding='max_length',
max_length=128,
return_tensors="pt"
)
labels = [self.label_map[label] for label in labels]
return encodings, labels
def create_dataset(self, encodings, labels):
# 自定义Dataset类
class Dataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
super().__init__()
self.encodings = encodings
self.labels = labels
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
dicts = {key: value[idx] for key, value in self.encodings.items()}
dicts["labels"] = torch.tensor(self.labels[idx])
return dicts
return Dataset(encodings, labels)
def compute_metrics(self, eval_pred):
"""计算评估指标"""
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
accuracy = (predictions == labels).mean()
return {"accuracy": accuracy}
def evaluate_model(self, texts, labels):
"""评估模型性能"""
encodings = self.tokenizer(
texts,
truncation=True,
padding="max_length",
max_length=128,
return_tensors="pt"
)
dataset = self.create_dataset(encodings, labels)
trainer = Trainer(model=self.model)
predictions = trainer.predict(dataset)
pred_labels = np.argmax(predictions.predictions, axis=-1)
true_labels = labels
logger.info("分类报告:")
logger.info(classification_report(
true_labels,
pred_labels,
target_names=["通用知识", "专业咨询"]
))
logger.info("混淆矩阵:")
logger.info(confusion_matrix(true_labels, pred_labels))
def predict_category(self, query):
# 检查模型是否加载
if self.model is None:
# 模型未加载,记录错误
logger.error("模型未训练或加载")
# 默认返回通用知识
return "通用知识"
# 对查询进行编码
encoding = self.tokenizer(query, truncation=True, padding=True, max_length=128, return_tensors="pt")
# 将编码移到指定设备
encoding = {k: v.to(self.device) for k, v in encoding.items()}
# 不计算梯度,进行预测
with torch.no_grad():
# 获取模型输出
outputs = self.model(**encoding)
prediction = torch.argmax(outputs.logits, dim=1).item()
# 根据预测结果返回类别
return "专业咨询" if prediction == 1 else "通用知识"
if __name__ == '__main__':
query_classify = QueryClassifier()
result = query_classify.predict_category(query="AI的课程大纲是什么")
print(result)
rag_qa/core/strategy_selector.py
python
# -*-coding:utf-8-*-
# core/strategy_selector.py 源码
import sys
import os
# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
# 获取core文件所在的目录的绝对路径
rag_qa_path = os.path.dirname(current_dir)
sys.path.insert(0, rag_qa_path)
# 获取根目录文件所在的绝对位置
project_root = os.path.dirname(rag_qa_path)
sys.path.insert(0, project_root)
# 导入 LangChain 提示模板
from langchain.prompts import PromptTemplate
# 导入日志和配置
from d_multi_layer_rag.base import logger, Config
# 导入 OpenAI
from openai import OpenAI
class StrategySelector:
def __init__(self):
# 初始化 OpenAI 客户端
self.client = OpenAI(api_key=Config().DASHSCOPE_API_KEY,
base_url=Config().DASHSCOPE_BASE_URL)
# 获取策略选择提示模板
self.strategy_prompt_template = self._get_strategy_prompt()
def call_dashscope(self, prompt):
import ollama
result = "直接检索"
# 调用 DashScope API
try:
# 创建聊天完成请求
response = ollama.chat(model='qwen2.5:7b',
messages=[
{"role": "system", "content": "你是一个有用的助手。"},
{"role": "user", "content": prompt},
],
options={
"temperature": 0.5
}
)
result = response['message']['content']
return result
except Exception as e:
# 记录 API 调用失败
logger.error(f"DashScope API 调用失败: {e}")
# 默认返回直接检索
return result
def _get_strategy_prompt(self):
return PromptTemplate(
template="""
你是一个智能助手,负责分析用户查询 {query},并从以下四种检索增强策略中选择一个最适合的策略,直接返回策略名称,不需要解释过程。
以下是几种检索增强策略及其适用场景:
1. **直接检索:**
* 描述:对用户查询直接进行检索,不进行任何增强处理。
* 适用场景:适用于查询意图明确,需要从知识库中检索**特定信息**的问题,例如:
* 示例:
* 查询:AI 学科学费是多少?
* 策略:直接检索
* 查询:JAVA的课程大纲是什么?
* 策略:直接检索
2. **假设问题检索(HyDE):**
* 描述:使用 LLM 生成一个假设的答案,然后基于假设答案进行检索。
* 适用场景:适用于查询较为抽象,直接检索效果不佳的问题,例如:
* 示例:
* 查询:人工智能在教育领域的应用有哪些?
* 策略:假设问题检索
3. **子查询检索:**
* 描述:将复杂的用户查询拆分为多个简单的子查询,分别检索并合并结果。
* 适用场景:适用于查询涉及多个实体或方面,需要分别检索不同信息的问题,例如:
* 示例:
* 查询:比较 Milvus 和 Zilliz Cloud 的优缺点。
* 策略:子查询检索
4. **回溯问题检索:**
* 描述:将复杂的用户查询转化为更基础、更易于检索的问题,然后进行检索。
* 适用场景:适用于查询较为复杂,需要简化后才能有效检索的问题,例如:
* 示例:
* 查询:我有一个包含 100 亿条记录的数据集,想把它存储到 Milvus 中进行查询。可以吗?
* 策略:回溯问题检索
根据用户查询 {query},直接返回最适合的策略名称,例如 "直接检索"。不要输出任何分析过程或其他内容。
""",
input_variables=["query"],
)
def select_strategy(self, query):
strategy = self.call_dashscope(self.strategy_prompt_template.format(query=query)).strip()
logger.info(f"为查询 '{query}' 选择的检索策略:{strategy}")
return strategy
if __name__ == '__main__':
ss = StrategySelector()
ss.select_strategy(query="Mysql数据库能不能支持100w个样本的插入")
rag_qa/core/rag_system.py
python
# -*-coding:utf-8-*-
# core/rag_system.py 源码
import sys
import os
# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
# 获取core文件所在的目录的绝对路径
rag_qa_path = os.path.dirname(current_dir)
sys.path.insert(0, rag_qa_path)
# 获取根目录文件所在的绝对位置
project_root = os.path.dirname(rag_qa_path)
sys.path.insert(0, project_root)
from prompts import RAGPrompts
# 导入 time 模块,用于计算时间
import time
from d_multi_layer_rag.base import logger, Config
from query_classifier import QueryClassifier # 导入查询分类器
from strategy_selector import StrategySelector # 导入策略选择器
from vector_store import VectorStore # 导入向量数据库对象
conf = Config()
# 定义 RAGSystem 类,封装 RAG 系统的核心逻辑
class RAGSystem:
# 初始化方法,设置 RAG 系统的基本参数
def __init__(self, vector_store, llm):
# 设置向量数据库对象
self.vector_store = vector_store
# 设置大语言模型调用函数
self.llm = llm
# 获取 RAG 提示模板
self.rag_prompt = RAGPrompts.rag_prompt()
# 初始化查询分类器
classifier_path = os.path.join("../", 'models', 'bert_query_classifier')
self.query_classifier = QueryClassifier(model_path=classifier_path)
# 初始化策略选择器
self.strategy_selector = StrategySelector()
# 定义类似私有方法,使用回溯问题进行检索
def _retrieve_with_backtracking(self, query, source_filter):
logger.info(f"使用回溯问题策略进行检索 (查询: '{query}')")
backtrack_prompt_template = RAGPrompts.backtracking_prompt()
try:
simplified_query = self.llm(backtrack_prompt_template.format(query=query)).strip()
logger.info(f"生成的回溯问题: '{simplified_query}'")
return self.vector_store.hybrid_search_with_rerank(
simplified_query, k=conf.RETRIEVAL_K, source_filter=source_filter
)
except Exception as e:
logger.error(f"回溯问题策略执行失败: {e}")
return []
# 定义类似私有方法,使用子查询进行检索
def _retrieve_with_subqueries(self, query, source_filter):
logger.info(f"使用子查询策略进行检索 (查询: '{query}')")
subquery_prompt_template = RAGPrompts.subquery_prompt()
try:
subqueries_text = self.llm(subquery_prompt_template.format(query=query)).strip()
subqueries = [q.strip() for q in subqueries_text.split("\n") if q.strip()]
logger.info(f"生成的子查询: {subqueries}")
if not subqueries:
logger.warning("未能生成有效的子查询")
return []
all_docs = []
for sub_q in subqueries:
docs = self.vector_store.hybrid_search_with_rerank(
sub_q, k=conf.CANDIDATE_M // 2, source_filter=source_filter
)
all_docs.extend(docs)
logger.info(f"子查询 '{sub_q}' 检索到 {len(docs)} 个文档")
unique_docs_dict = {doc.page_content: doc for doc in all_docs}
unique_docs = list(unique_docs_dict.values())
logger.info(f"所有子查询共检索到 {len(all_docs)} 个文档, 去重后剩 {len(unique_docs)} 个")
return unique_docs
except Exception as e:
logger.error(f'子查询存在错误:{e}')
return []
# 定义私有方法,使用假设文档进行检索(HyDE)
def _retrieve_with_hyde(self, query, source_filter):
logger.info(f"使用 HyDE 策略进行检索 (查询: '{query}')")
hyde_prompt_template = RAGPrompts.hyde_prompt()
try:
hypo_answer = self.llm(hyde_prompt_template.format(query=query)).strip()
logger.info(f"HyDE 生成的假设答案: '{hypo_answer}'")
return self.vector_store.hybrid_search_with_rerank(
hypo_answer, k=conf.RETRIEVAL_K, source_filter=source_filter
)
except Exception as e:
logger.error(f"HyDE 策略执行失败: {e}")
return []
def retrieve_and_merge(self, query, source_filter=None, strategy=None):
if not strategy:
strategy = self.strategy_selector.select_strategy(query)
ranked_chunks = []
if strategy == "回溯问题检索":
ranked_chunks = self._retrieve_with_backtracking(query, source_filter)
elif strategy == '子查询检索':
ranked_chunks = self._retrieve_with_subqueries(query, source_filter)
elif strategy == "假设问题检索":
ranked_chunks = self._retrieve_with_hyde(query, source_filter)
else:
logger.info(f"使用直接检索策略 (查询: '{query}')")
ranked_chunks = self.vector_store.hybrid_search_with_rerank(
query, k=conf.RETRIEVAL_K, source_filter=source_filter
)
logger.info(f"策略 '{strategy}' 检索到 {len(ranked_chunks)} 个候选文档 (可能已是父文档)")
final_context_docs = ranked_chunks[:conf.CANDIDATE_M]
logger.info(f"最终选取 {len(final_context_docs)} 个文档作为上下文")
return final_context_docs
def generate_answer(self, query, source_filter=None):
start_time = time.time()
logger.info(f"开始处理查询: '{query}', 学科过滤: {source_filter}")
query_category = self.query_classifier.predict_category(query)
logger.info(f"查询分类结果:{query_category} (查询: '{query}')")
if query_category == "通用知识":
logger.info("查询为通用知识,直接调用 LLM")
prompt_input = self.rag_prompt.format(
context="", history="", question=query, phone=conf.CUSTOMER_SERVICE_PHONE
)
try:
answer = self.llm(prompt_input)
except Exception as e:
logger.error(f"直接调用 LLM 失败: {e}")
answer = f"抱歉,处理您的通用知识问题时出错。请联系人工客服:{conf.CUSTOMER_SERVICE_PHONE}"
processing_time = time.time() - start_time
logger.info(f"通用知识查询处理完成 (耗时: {processing_time:.2f}s, 查询: '{query}')")
return answer
logger.info("查询为专业咨询,执行 RAG 流程")
strategy = self.strategy_selector.select_strategy(query)
context_docs = self.retrieve_and_merge(query, source_filter=source_filter, strategy=strategy)
if context_docs:
context = "\n\n".join([doc.page_content for doc in context_docs])
logger.info(f"构建上下文完成,包含 {len(context_docs)} 个文档块")
else:
context = ""
logger.info("未检索到相关文档,上下文为空")
prompt_input = self.rag_prompt.format(
context=context, history="", question=query, phone=conf.CUSTOMER_SERVICE_PHONE
)
try:
answer = self.llm(prompt_input)
except Exception as e:
logger.error(f"调用 LLM 生成最终答案失败: {e}")
answer = f"抱歉,处理您的专业咨询问题时出错。请联系人工客服:{conf.CUSTOMER_SERVICE_PHONE}"
processing_time = time.time() - start_time
logger.info(f"查询处理完成 (耗时: {processing_time:.2f}s, 查询: '{query}')")
return answer
if __name__ == '__main__':
vector_store = VectorStore()
llm = StrategySelector().call_dashscope
rag_system = RAGSystem(vector_store, llm)
answer = rag_system.generate_answer(query="AI学科的课程大纲内容有什么", source_filter="ai")
print(answer)
rag_qa/core/new_rag_system.py
python
'''
todo: 和之前的rag_system不一样的地方是:生成答案时,考虑了历史对话记录,以及我们大模型输出结果时stream流式输出结果
'''
# -*-coding:utf-8-*-
# core/rag_system.py 源码
import sys
import os
# 导入 OpenAI 客户端,用于调用 DashScope API
from openai import OpenAI
# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
# 获取core文件所在的目录的绝对路径
rag_qa_path = os.path.dirname(current_dir)
sys.path.insert(0, rag_qa_path)
# 获取根目录文件所在的绝对位置
project_root = os.path.dirname(rag_qa_path)
sys.path.insert(0, project_root)
from prompts import RAGPrompts
# 导入 time 模块,用于计算时间
import time
from d_multi_layer_rag.base import logger, Config
from query_classifier import QueryClassifier # 导入查询分类器
from strategy_selector import StrategySelector # 导入策略选择器
from vector_store import VectorStore # 导入向量数据库对象
conf = Config()
# 定义 RAGSystem 类,封装 RAG 系统的核心逻辑
class RAGSystem:
# 初始化方法,设置 RAG 系统的基本参数
def __init__(self, vector_store, llm):
# 设置向量数据库对象
self.vector_store = vector_store
# 设置大语言模型调用函数
self.llm = llm
# 获取 RAG 提示模板
self.rag_prompt = RAGPrompts.rag_prompt()
# 初始化查询分类器
classifier_path = os.path.join(rag_qa_path, 'core', 'bert_query_classifier')
self.query_classifier = QueryClassifier(model_path=classifier_path)
# 初始化策略选择器
self.strategy_selector = StrategySelector()
# 定义类似私有方法,使用回溯问题进行检索
def _retrieve_with_backtracking(self, query, source_filter):
logger.info(f"使用回溯问题策略进行检索 (查询: '{query}')")
backtrack_prompt_template = RAGPrompts.backtracking_prompt()
try:
simplified_query = self.llm(backtrack_prompt_template.format(query=query)).strip()
logger.info(f"生成的回溯问题: '{simplified_query}'")
return self.vector_store.hybrid_search_with_rerank(
simplified_query, k=conf.RETRIEVAL_K, source_filter=source_filter
)
except Exception as e:
logger.error(f"回溯问题策略执行失败: {e}")
return []
# 定义类似私有方法,使用子查询进行检索
def _retrieve_with_subqueries(self, query, source_filter):
logger.info(f"使用子查询策略进行检索 (查询: '{query}')")
subquery_prompt_template = RAGPrompts.subquery_prompt()
try:
subqueries_text = self.llm(subquery_prompt_template.format(query=query)).strip()
subqueries = [q.strip() for q in subqueries_text.split("\n") if q.strip()]
logger.info(f"生成的子查询: {subqueries}")
if not subqueries:
logger.warning("未能生成有效的子查询")
return []
all_docs = []
for sub_q in subqueries:
docs = self.vector_store.hybrid_search_with_rerank(
sub_q, k=conf.CANDIDATE_M // 2, source_filter=source_filter
)
all_docs.extend(docs)
logger.info(f"子查询 '{sub_q}' 检索到 {len(docs)} 个文档")
unique_docs_dict = {doc.page_content: doc for doc in all_docs}
unique_docs = list(unique_docs_dict.values())
logger.info(f"所有子查询共检索到 {len(all_docs)} 个文档, 去重后剩 {len(unique_docs)} 个")
return unique_docs
except Exception as e:
logger.error(f'子查询存在错误:{e}')
return []
# 定义私有方法,使用假设文档进行检索(HyDE)
def _retrieve_with_hyde(self, query, source_filter):
logger.info(f"使用 HyDE 策略进行检索 (查询: '{query}')")
hyde_prompt_template = RAGPrompts.hyde_prompt()
try:
hypo_answer = self.llm(hyde_prompt_template.format(query=query)).strip()
logger.info(f"HyDE 生成的假设答案: '{hypo_answer}'")
return self.vector_store.hybrid_search_with_rerank(
hypo_answer, k=conf.RETRIEVAL_K, source_filter=source_filter
)
except Exception as e:
logger.error(f"HyDE 策略执行失败: {e}")
return []
def retrieve_and_merge(self, query, source_filter=None, strategy=None):
if not strategy:
strategy = self.strategy_selector.select_strategy(query)
ranked_chunks = []
if strategy == "回溯问题检索":
ranked_chunks = self._retrieve_with_backtracking(query, source_filter)
elif strategy == '子查询检索':
ranked_chunks = self._retrieve_with_subqueries(query, source_filter)
elif strategy == "假设问题检索":
ranked_chunks = self._retrieve_with_hyde(query, source_filter)
else:
logger.info(f"使用直接检索策略 (查询: '{query}')")
ranked_chunks = self.vector_store.hybrid_search_with_rerank(
query, k=conf.RETRIEVAL_K, source_filter=source_filter
)
logger.info(f"策略 '{strategy}' 检索到 {len(ranked_chunks)} 个候选文档 (可能已是父文档)")
final_context_docs = ranked_chunks[:conf.CANDIDATE_M]
logger.info(f"最终选取 {len(final_context_docs)} 个文档作为上下文")
return final_context_docs
def generate_answer(self, query, source_filter=None, history=None):
start_time = time.time()
logger.info(f"开始处理查询: '{query}', 学科过滤: {source_filter}")
if history is not None and not isinstance(history, list):
logger.warning(f'无效的历史格式:{type(history)},忽略历史')
history = []
elif history:
history = history[-5:]
history_context = ''
if history:
history_context = "\n".join([f"Q:{h['question']}\nA:{h['answer']}" for h in history])
logger.info(f'使用对话历史:{history_context[:50]}')
query_category = self.query_classifier.predict_category(query)
logger.info(f"查询分类结果:{query_category} (查询: '{query}')")
if query_category == "通用知识":
logger.info("查询为通用知识,直接调用 LLM")
context = ''
else:
logger.info("查询为专业咨询,执行 RAG 流程")
strategy = self.strategy_selector.select_strategy(query)
context_docs = self.retrieve_and_merge(query, source_filter=source_filter, strategy=strategy)
if context_docs:
context = "\n\n".join([doc.page_content for doc in context_docs])
logger.info(f"构建上下文完成,包含 {len(context_docs)} 个文档块")
else:
context = ""
logger.info("未检索到相关文档,上下文为空")
prompt_input = self.rag_prompt.format(context=context,
question=query,
history=history_context,
phone=conf.CUSTOMER_SERVICE_PHONE)
try:
for token in self.llm(prompt_input):
yield token
process_time = time.time() - start_time
logger.info(f'LLM查询处理完成(耗时:{process_time:.2f}s, 查询:{query})')
except Exception as e:
logger.error(f'调用LLM失败:{e}')
yield f'抱歉,处理问题时出错,请你联系人工客服:{conf.CUSTOMER_SERVICE_PHONE}'
if __name__ == '__main__':
vector_store = VectorStore()
def call_dashscope(prompt):
client = OpenAI(api_key=Config().DASHSCOPE_API_KEY,
base_url=Config().DASHSCOPE_BASE_URL)
try:
completion = client.chat.completions.create(
model=Config().LLM_MODEL,
messages=[
{"role": "system", "content": "你是一个有用的助手。"},
{"role": "user", "content": prompt},
],
timeout=30,
stream=True
)
for chunk in completion:
if chunk.choices and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield content
except Exception as e:
logger.error(f"LLM调用失败: {e}")
return f"错误:LLM调用失败 - {e}"
rag_system = RAGSystem(vector_store, call_dashscope)
answer = rag_system.generate_answer(query="AI学科的课程大纲内容有什么", source_filter="ai")
for value in answer:
print(value)
new_main.py(新版集成问答系统入口)
python
# -*- coding:utf-8 -*-
# 导入 MySQL 和 Redis 客户端,管理数据库和缓存
from d_multi_layer_rag.mysql_qa import MySQLClient, RedisClient, BM25Search
# 导入 RAG 系统组件,用于知识库检索和答案生成
from d_multi_layer_rag.rag_qa.core.new_rag_system import VectorStore, RAGSystem
# 导入配置和日志工具,用于系统配置和日志记录
from d_multi_layer_rag.base import logger, Config
# 导入 OpenAI 客户端,用于调用 DashScope API
from openai import OpenAI
# 导入时间库,用于记录处理时间
import time
# 导入 UUID 库,生成唯一会话 ID
import uuid
# 导入 pymysql 错误处理,用于数据库操作的异常捕获
import pymysql
import ollama
class IntegratedQASystem:
def __init__(self):
# 初始化日志工具,用于记录系统运行信息
self.logger = logger
# 初始化配置对象,加载系统参数
self.config = Config()
# 初始化 MySQL 客户端,用于数据库操作
self.mysql_client = MySQLClient()
# 初始化 Redis 客户端,用于缓存管理
self.redis_client = RedisClient()
# 初始化 BM25 搜索模块,结合 MySQL 和 Redis
self.bm25_search = BM25Search(self.redis_client, self.mysql_client)
try:
# 初始化 OpenAI 客户端,连接 DashScope API
self.client = OpenAI(api_key=self.config.DASHSCOPE_API_KEY,
base_url=self.config.DASHSCOPE_BASE_URL)
except Exception as e:
# 记录 OpenAI 初始化失败的错误日志
self.logger.error(f"OpenAI 客户端初始化失败: {e}")
# 抛出异常,终止初始化
raise
# 初始化向量存储,用于 RAG 系统的知识库管理
self.vector_store = VectorStore()
# 初始化 RAG 系统,传入向量存储和 DashScope API 调用函数
self.rag_system = RAGSystem(self.vector_store, self.call_dashscope)
# 初始化对话历史表,用于存储会话记录
self.init_conversation_table()
def init_conversation_table(self):
"""初始化MySQL中的conversations表,用于存储对话历史"""
try:
# 创建 conversations 表,包含会话 ID、问题、答案和时间戳
self.mysql_client.cursor.execute("""
CREATE TABLE IF NOT EXISTS conversations (
id INT AUTO_INCREMENT PRIMARY KEY,
session_id VARCHAR(36) NOT NULL,
question TEXT NOT NULL,
answer TEXT NOT NULL,
timestamp DATETIME NOT NULL,
INDEX idx_session_id (session_id)
)
""")
# 提交数据库事务
self.mysql_client.connection.commit()
# 记录表初始化成功的日志
self.logger.info("对话历史表初始化成功")
except pymysql.MySQLError as e:
# 记录表初始化失败的错误日志
self.logger.error(f"初始化对话历史表失败: {e}")
# 抛出异常,终止初始化
raise
def call_dashscope(self, prompt):
result = "错误: LLM返回无效响应"
try:
response = ollama.chat(model='qwen2.5:7b',
messages=[
{"role": "system", "content": "你是一个有用的助手。"},
{"role": "user", "content": prompt},
],
options={
"temperature": 0.5
},
stream=True
)
for chunk in response:
yield chunk['message']['content']
except Exception as e:
logger.error(f"LLM API (call_dashscope) 调用失败: {e}")
return f"错误: 调用LLM失败 - {e}"
def _fetch_recent_history(self, session_id):
"""获取最近5轮对话历史"""
try:
self.mysql_client.cursor.execute("""
SELECT question, answer
FROM conversations
WHERE session_id = %s
ORDER BY timestamp DESC
LIMIT %s
""", (session_id, 5))
history = [{"question": row[0], "answer": row[1]} for row in self.mysql_client.cursor.fetchall()]
return history[::-1]
except pymysql.MySQLError as e:
self.logger.error(f"获取对话历史失败: {e}")
return []
def get_session_history(self, session_id):
"""从MySQL获取会话历史"""
return self._fetch_recent_history(session_id)
def update_session_history(self, session_id: str, question: str, answer: str) -> list:
"""更新会话历史到MySQL,保留最近5轮对话"""
try:
self.mysql_client.cursor.execute("""
INSERT INTO conversations (session_id, question, answer, timestamp)
VALUES (%s, %s, %s, NOW())
""", (session_id, question, answer))
history = self._fetch_recent_history(session_id)
self.mysql_client.cursor.execute("""
DELETE FROM conversations
WHERE session_id = %s AND id NOT IN (
SELECT id FROM (
SELECT id
FROM conversations
WHERE session_id = %s
ORDER BY timestamp DESC
LIMIT %s
) AS sub
)
""", (session_id, session_id, 5))
self.mysql_client.connection.commit()
self.logger.info(f"会话 {session_id} 历史更新成功")
return history
except pymysql.MySQLError as e:
self.logger.error(f"更新会话历史失败: {e}")
self.mysql_client.connection.rollback()
raise
except Exception as e:
self.logger.error(f"更新会话历史意外错误: {e}")
self.mysql_client.connection.rollback()
raise
def clear_session_history(self, session_id: str) -> bool:
"""清除指定会话历史"""
try:
self.mysql_client.cursor.execute("""
DELETE FROM conversations
WHERE session_id = %s
""", (session_id,))
self.mysql_client.connection.commit()
self.logger.info(f"会话 {session_id} 历史已清除")
return True
except pymysql.MySQLError as e:
self.logger.error(f"清除会话历史失败: {e}")
self.mysql_client.connection.rollback()
return False
def query(self, query, source_filter=None, session_id=None):
"""查询集成系统,支持对话历史和流式输出"""
start_time = time.time()
self.logger.info(f"处理查询: '{query}' (会话ID: {session_id})")
history = self.get_session_history(session_id) if session_id else []
answer, need_rag = self.bm25_search.search(query, threshold=0.85)
if answer:
self.logger.info(f"MySQL答案: {answer}")
if session_id:
self.update_session_history(session_id, query, answer)
processing_time = time.time() - start_time
self.logger.info(f"查询处理耗时 {processing_time:.2f}秒")
yield answer, True
elif need_rag:
self.logger.info("无可靠MySQL答案,回退到RAG")
collected_answer = ""
for token in self.rag_system.generate_answer(query, source_filter=source_filter, history=history):
collected_answer += token
yield token, False
if session_id:
self.update_session_history(session_id, query, collected_answer)
processing_time = time.time() - start_time
self.logger.info(f"查询处理耗时 {processing_time:.2f}秒")
yield "", True
else:
self.logger.info("未找到答案")
processing_time = time.time() - start_time
self.logger.info(f"查询处理耗时 {processing_time:.2f}秒")
yield "未找到答案", True
def main():
new_qa_system = IntegratedQASystem()
try:
print("\n欢迎使用集成问答系统!")
print(f"支持的来源: {new_qa_system.config.VALID_SOURCES}")
print("输入查询进行问答,输入 'exit' 退出。")
while True:
session_id = input("\n请您输入对话ID: ").strip()
while not session_id:
session_id = input("\n请您输入对话ID: ").strip()
query = input("\n输入查询: ").strip()
if query.lower() == "exit":
logger.info("退出系统")
print("再见!")
break
source_filter = input(f"输入来源过滤 ({'/'.join(new_qa_system.config.VALID_SOURCES)}) (按 Enter 跳过): ").strip()
if source_filter and source_filter not in new_qa_system.config.VALID_SOURCES:
logger.warning(f"无效来源 '{source_filter}',忽略过滤")
print(f"无效来源 '{source_filter}',继续无过滤。")
source_filter = None
answer = new_qa_system.query(query, source_filter, session_id=session_id)
for value in answer:
if value[1] == "False":
print(value[0], end="")
else:
print(value[0])
except Exception as e:
logger.error(f"系统错误: {e}")
print(f"发生错误: {e}")
finally:
new_qa_system.mysql_client.close()
if __name__ == "__main__":
main()
app.py(FastAPI 服务)
python
# -*- coding: utf-8 -*-
# 导入 FastAPI 核心类
from fastapi import FastAPI, HTTPException, Request
# 导入 StreamingResponse,用于支持流式响应
from fastapi.responses import StreamingResponse
# 导入 json 模块
import json
# 导入 uuid 模块
import uuid
from new_main import IntegratedQASystem
# 创建一个 FastAPI 应用实例
app = FastAPI(title="集成问答系统 API", description="基于 RAG + MySQL + Redis 的问答系统 FastAPI 接口")
# 全局初始化一个问答系统实例
qa_system = IntegratedQASystem()
@app.post("/query")
async def handle_query(request: Request):
"""
接收客户端发送的 JSON 请求,支持流式返回答案。
请求体示例:
{
"query": "什么是人工智能?",
"source_filter": "ai",
"session_id": "a1b2c3d4-..."
}
响应为 SSE 流式格式
"""
try:
body = await request.json()
except Exception:
raise HTTPException(status_code=400, detail="无效的 JSON 数据")
query = body.get("query", "").strip()
source_filter = body.get("source_filter", None)
session_id = body.get("session_id", None)
if not query:
raise HTTPException(status_code=400, detail="查询内容不能为空")
if not session_id:
session_id = str(uuid.uuid4())
valid_sources = qa_system.config.VALID_SOURCES
if source_filter and source_filter not in valid_sources:
raise HTTPException(
status_code=400,
detail=f"无效的学科类别。支持: {valid_sources}"
)
def generate_response():
try:
for token, is_complete in qa_system.query(
query=query,
source_filter=source_filter,
session_id=session_id
):
message = {
"token": token,
"is_complete": is_complete,
"session_id": session_id
}
yield f"data: {json.dumps(message, ensure_ascii=False)}\n\n"
except Exception as e:
error_msg = f"处理查询时发生错误: {str(e)}"
qa_system.logger.error(error_msg)
message = {
"error": error_msg,
"is_complete": True
}
yield f"data: {json.dumps(message, ensure_ascii=False)}\n\n"
return StreamingResponse(
generate_response(),
media_type="text/event-stream"
)
if __name__ == '__main__':
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8002)
兄弟们,共赏!!!