一、为什么要进行代码重构?
1.1 "脚本式" 代码的三大致命问题
问题 1:可维护性差
- 所有代码都写在一个或几个大文件中,动辄几千行
- 变量和函数命名混乱,没有统一规范
- 缺乏注释和文档,过一段时间自己都看不懂
- 修改一个功能可能会影响其他不相关的功能
问题 2:可扩展性差
- 模块之间耦合度高,添加新功能需要修改大量现有代码
- 没有统一的接口规范,替换组件非常困难
- 无法方便地支持不同的模型、不同的向量库、不同的检索策略
问题 3:调试和排错困难
- 没有统一的错误处理机制,错误信息不明确
- 没有日志系统,无法追溯问题发生的过程
- 没有单元测试,修改代码后无法快速验证是否引入了新的 bug
1.2 工业级 RAG 系统标准分层架构
经过多年的实践,工业界已经形成了一套成熟的 RAG 系统分层架构,它将系统划分为多个独立的模块,每个模块有明确的职责,模块之间通过统一的接口进行交互。
RAG系统标准分层架构
┌─────────────────────────────────────────────────────┐
│ 服务层 (Service Layer) │
│ 对外提供API接口、Web界面、处理用户请求和会话管理 │
├─────────────────────────────────────────────────────┤
│ 核心层 (Core Layer) │
│ 实现RAG的核心业务逻辑:检索、生成、重排序、查询优化 │
├─────────────────────────────────────────────────────┤
│ 模型层 (Models Layer) │
│ 负责所有模型的加载、管理和推理:嵌入、重排序、大模型 │
├─────────────────────────────────────────────────────┤
│ 工具层 (Utils Layer) │
│ 通用工具函数:分块、解析、缓存、日志、异常处理等 │
├─────────────────────────────────────────────────────┤
│ 配置层 (Config Layer) │
│ 所有配置的集中管理,支持环境变量和配置文件 │
└─────────────────────────────────────────────────────┘
各层职责详解
- 配置层:所有可配置的参数都集中在这里,避免硬编码。支持从环境变量、.env 文件、配置文件中加载配置。
- 工具层:提供通用的工具函数,不包含业务逻辑。例如:文档解析、文本分块、缓存管理、日志记录等。
- 模型层:封装所有与模型相关的逻辑,提供统一的推理接口。上层模块不需要关心模型的具体实现细节。
- 核心层:实现 RAG 的核心业务逻辑,将各个模块组合起来,完成完整的 RAG 流程。
- 服务层:对外提供服务,将核心层的功能暴露给用户。例如:RESTful API、Web 界面、命令行界面等。
1.3 代码规范化的四大原则
原则 1:单一职责原则
每个模块、每个类、每个函数都应该只负责一个功能。例如:
DocumentParser类只负责文档解析TextSplitter类只负责文本分块Retriever类只负责检索
原则 2:开闭原则
对扩展开放,对修改关闭。当需要添加新功能时,应该通过添加新的代码来实现,而不是修改现有的代码。例如:
- 添加新的检索策略时,应该创建一个新的
Retriever子类,而不是修改现有的HybridRetriever类
原则 3:依赖倒置原则
高层模块不应该依赖低层模块,两者都应该依赖抽象。例如:
- 核心层不应该直接依赖具体的向量库实现,而是依赖一个抽象的
VectorStore接口 - 这样可以方便地在 Chroma、FAISS、Milvus 等不同的向量库之间切换
原则 4:接口隔离原则
客户端不应该依赖它不需要的接口。每个接口应该只包含客户端需要的方法。
二、核心代码实现
2.1 第一步:重构项目结构
首先,需要按照标准分层架构重新组织项目的目录结构。将原来的零散文件移动到对应的目录中。
最终项目结构
rag_project/
├── config/ # 配置层
│ ├── __init__.py
│ └── settings.py # 全局配置
├── core/ # 核心层
│ ├── __init__.py
│ ├── rag_system.py # RAG系统核心类
│ ├── retriever.py # 检索器(已修复除以零错误)
│ ├── generator.py # 生成器(已放宽反幻觉)
│ ├── reranker.py # 重排序器
│ └── query_optimizer.py # 查询优化器
├── models/ # 模型层
│ ├── __init__.py
│ ├── embedding_model.py # 嵌入模型
│ ├── reranker_model.py # 重排序模型
│ └── llm_model.py # 大语言模型(已加入容错)
├── utils/ # 工具层
│ ├── __init__.py
│ ├── document_parser.py # 文档解析器
│ ├── text_splitter.py # 文本分块器(纯Python,无NLTK)
│ ├── cache_utils.py # 缓存工具
│ ├── logger.py # 日志工具(已修复所有问题)
│ └── exceptions.py # 自定义异常
├── web/ # 服务层(Web界面)
│ └── app.py # Streamlit前端
├── data/ # 数据目录
│ ├── documents/ # 原始文档
│ ├── chunks/ # 分块后的文档
│ ├── vector_db/ # 向量数据库
│ └── cache/ # 缓存文件
├── logs/ # 日志目录
├── tests/ # 测试目录
│ ├── __init__.py
│ ├── test_utils.py # 工具层测试
│ ├── test_models.py # 模型层测试
│ └── test_core.py # 核心层测试
├── .env # 环境变量文件
├── pyproject.toml # 项目配置文件
├── requirements.txt # 依赖列表
└── start.py # 一键启动脚本
目录创建脚本
创建一个create_structure.py文件,自动创建所有目录和空文件:
python
import os
dirs = [
"config",
"core",
"models",
"utils",
"web",
"data/documents",
"data/chunks",
"data/vector_db",
"data/cache",
"logs",
"tests"
]
for dir_path in dirs:
os.makedirs(dir_path, exist_ok=True)
# 创建__init__.py文件
init_file = os.path.join(dir_path, "__init__.py")
if not os.path.exists(init_file) and not dir_path.startswith("data/") and dir_path != "logs":
with open(init_file, 'w') as f:
pass
# 创建空文件
files = [
"config/settings.py",
"core/rag_system.py",
"core/retriever.py",
"core/generator.py",
"core/reranker.py",
"core/query_optimizer.py",
"models/embedding_model.py",
"models/reranker_model.py",
"models/llm_model.py",
"utils/document_parser.py",
"utils/text_splitter.py",
"utils/cache_utils.py",
"utils/logger.py",
"utils/exceptions.py",
"web/app.py",
"tests/test_utils.py",
"tests/test_models.py",
"tests/test_core.py",
".env",
"pyproject.toml",
"requirements.txt",
"start.py"
]
for file_path in files:
if not os.path.exists(file_path):
with open(file_path, 'w') as f:
pass
运行这个脚本,自动创建完整的项目结构。
2.2 第二步:实现集中式配置管理
我们将使用pydantic-settings来实现集中式配置管理。它支持从环境变量、.env 文件、配置文件中加载配置,并且提供类型检查和自动验证。
2.2.1 编写配置类
打开config/settings.py,复制以下代码:
python
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field
from pathlib import Path
class Settings(BaseSettings):
"""全局配置类"""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="forbid" # 禁止未定义的配置项
)
# 项目基本配置
app_name: str = Field(default="本地RAG系统", description="应用名称")
debug: bool = Field(default=False, description="是否开启调试模式")
# 模型配置
embedding_model_path: str = Field(
default=r"C:\Users\87624\.cache\modelscope\hub\models\AI-ModelScope\bge-large-zh-v1.5",
description="BGE嵌入模型路径"
)
reranker_model_path: str = Field(
default=r"C:\Users\87624\.cache\modelscope\hub\models\AI-ModelScope\bge-reranker-v2-m3",
description="BGE重排序模型路径"
)
llm_model_path: str = Field(
default=r"C:\Users\87624\.cache\modelscope\hub\models\deepseek-ai\deepseek-llm-7b-chat",
description="本地大模型路径"
)
device: str = Field(default="cpu", description="运行设备:cpu/cuda")
# 检索配置
top_k: int = Field(default=5, description="检索返回的文档数量")
retrieval_method: str = Field(default="hybrid", description="检索方法:sentence_window/parent_document/hybrid")
window_size: int = Field(default=5, description="句子窗口大小")
# 生成配置
max_new_tokens: int = Field(default=1024, description="最大生成长度")
temperature: float = Field(default=0.1, description="温度参数")
# 缓存配置
cache_dir: Path = Field(default=Path("./data/cache"), description="缓存目录")
cache_ttl: int = Field(default=86400, description="缓存过期时间(秒)")
max_memory_cache_size: int = Field(default=1000, description="内存缓存最大条目数")
# 日志配置
log_dir: Path = Field(default=Path("./logs"), description="日志目录")
log_level: str = Field(default="INFO", description="日志级别:DEBUG/INFO/WARNING/ERROR/CRITICAL")
log_max_bytes: int = Field(default=10*1024*1024, description="单个日志文件最大大小(字节)")
log_backup_count: int = Field(default=5, description="日志文件备份数量")
# 向量库配置
vector_db_path: Path = Field(default=Path("./data/vector_db"), description="向量数据库路径")
vector_db_collection: str = Field(default="rag_collection", description="向量库集合名称")
# 全局配置实例
settings = Settings()
2.2.2 编写.env 文件
打开.env文件,复制以下内容,根据你的实际情况修改:
python
# 项目配置
DEBUG=True
# 模型路径(请修改为你的实际路径)
EMBEDDING_MODEL_PATH=C:\Users\87624\.cache\modelscope\hub\models\AI-ModelScope\bge-large-zh-v1.5
RERANKER_MODEL_PATH=C:\Users\87624\.cache\modelscope\hub\models\AI-ModelScope\bge-reranker-v2-m3
LLM_MODEL_PATH=C:\Users\87624\.cache\modelscope\hub\models\deepseek-ai\deepseek-llm-7b-chat
# 设备配置
DEVICE=cpu
# 检索配置
TOP_K=5
RETRIEVAL_METHOD=hybrid
WINDOW_SIZE=5
2.3 第三步:实现工业级日志系统
python
# 必须在导入任何transformers相关库之前设置
import os
os.environ["TRANSFORMERS_SKIP_DYNAMIC_MODULES_SCAN"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
import logging
from logging.handlers import RotatingFileHandler
from pathlib import Path
from config.settings import settings
class SensitiveDataFilter(logging.Filter):
"""敏感信息过滤器,自动脱敏日志中的敏感信息"""
def filter(self, record):
# 创建record的副本,避免修改原始对象
msg = str(record.getMessage())
# 脱敏密码
if "password" in msg.lower():
msg = msg.replace(msg[msg.lower().find("password")+9:], "***")
# 脱敏API密钥
if "api_key" in msg.lower() or "apikey" in msg.lower():
msg = msg.replace(msg[msg.lower().find("api_key")+8:], "***")
# 脱敏路径中的用户名
if "Users" in msg:
import re
msg = re.sub(r'Users\\[^\\]+\\', r'Users\***\\', msg)
# 将脱敏后的消息赋值给record
record.msg = msg
record.args = ()
return True
class SafeFormatter(logging.Formatter):
"""安全格式化器,处理第三方库的日志格式化错误"""
def format(self, record):
try:
return super().format(record)
except TypeError:
# 当格式化失败时,返回原始消息和参数
return f"{record.asctime} - {record.name} - {record.levelname} - {record.filename}:{record.lineno} - {record.msg} {record.args}"
class TransformersWarningFilter(logging.Filter):
"""专门过滤transformers的__path__警告"""
def filter(self, record):
msg = str(record.getMessage())
if "Accessing `__path__` from" in msg:
return False
if "Behavior may be different and this alias will be removed" in msg:
return False
return True
def setup_logger():
"""配置全局日志系统"""
# 创建日志目录
settings.log_dir.mkdir(exist_ok=True)
# 获取根日志器
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.handlers.clear() # 清除默认处理器
# 定义日志格式
detailed_formatter = SafeFormatter(
'%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
simple_formatter = logging.Formatter(
'%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# 控制台处理器(只输出INFO及以上级别)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(simple_formatter)
console_handler.addFilter(SensitiveDataFilter())
# 普通文件处理器(输出所有级别)
file_handler = RotatingFileHandler(
filename=settings.log_dir / "app.log",
maxBytes=settings.log_max_bytes,
backupCount=settings.log_backup_count,
encoding='utf-8'
)
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(detailed_formatter)
file_handler.addFilter(SensitiveDataFilter())
file_handler.addFilter(TransformersWarningFilter())
# 错误文件处理器(只输出ERROR及以上级别)
error_file_handler = RotatingFileHandler(
filename=settings.log_dir / "error.log",
maxBytes=settings.log_max_bytes,
backupCount=settings.log_backup_count,
encoding='utf-8'
)
error_file_handler.setLevel(logging.ERROR)
error_file_handler.setFormatter(detailed_formatter)
error_file_handler.addFilter(SensitiveDataFilter())
# 添加处理器到根日志器
logger.addHandler(console_handler)
logger.addHandler(file_handler)
logger.addHandler(error_file_handler)
# 设置第三方库的日志级别
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("chromadb").setLevel(logging.WARNING)
logging.getLogger("uvicorn").setLevel(logging.WARNING)
logging.getLogger("starlette").setLevel(logging.WARNING)
logging.getLogger("streamlit").setLevel(logging.WARNING)
logging.getLogger("python_multipart").setLevel(logging.INFO)
logging.getLogger("multipart").setLevel(logging.INFO)
# 为transformers日志器添加专门的过滤器
transformers_logger = logging.getLogger("transformers")
transformers_logger.addFilter(TransformersWarningFilter())
logger.info("日志系统初始化完成")
return logger
# 全局日志器实例
logger = setup_logger()
2.4 第四步:实现统一的错误处理机制
我们将定义一系列自定义异常类,实现全局异常捕获,确保所有错误都有明确的错误信息和处理方式。
python
from utils.logger import logger
class RAGBaseException(Exception):
"""RAG系统基础异常类"""
def __init__(self, message: str, error_code: int = 1000):
self.message = message
self.error_code = error_code
super().__init__(self.message)
# 自动记录异常日志
logger.error(f"[{self.error_code}] {self.message}")
class ModelLoadError(RAGBaseException):
"""模型加载失败异常"""
def __init__(self, model_name: str, details: str = ""):
message = f"模型加载失败:{model_name}"
if details:
message += f",详细信息:{details}"
super().__init__(message, error_code=1001)
class DocumentParseError(RAGBaseException):
"""文档解析失败异常"""
def __init__(self, file_path: str, details: str = ""):
message = f"文档解析失败:{file_path}"
if details:
message += f",详细信息:{details}"
super().__init__(message, error_code=1002)
class VectorDBError(RAGBaseException):
"""向量数据库操作异常"""
def __init__(self, operation: str, details: str = ""):
message = f"向量数据库操作失败:{operation}"
if details:
message += f",详细信息:{details}"
super().__init__(message, error_code=1003)
class RetrievalError(RAGBaseException):
"""检索失败异常"""
def __init__(self, query: str, details: str = ""):
message = f"检索失败:查询='{query}'"
if details:
message += f",详细信息:{details}"
super().__init__(message, error_code=1004)
class GenerationError(RAGBaseException):
"""生成失败异常"""
def __init__(self, details: str = ""):
message = "回答生成失败"
if details:
message += f",详细信息:{details}"
super().__init__(message, error_code=1005)
class ConfigurationError(RAGBaseException):
"""配置错误异常"""
def __init__(self, config_name: str, details: str = ""):
message = f"配置错误:{config_name}"
if details:
message += f",详细信息:{details}"
super().__init__(message, error_code=1006)