文章导读
本文将从一个完整的项目实战出发,详细讲解如何构建一个基于传统检索技术的智能问答系统。我们会逐行分析代码,深入理解每个模块的设计思想、实现细节和优化思路。
让我们一起看看RAG的来时路!!!
一、系统整体架构
1.1 系统要解决什么问题?
想象一下,你有一个包含数千条问答对的知识库(比如:Java面试题、学科知识问答、客服常见问题等)。当用户输入一个问题时,系统需要:
-
理解用户问的是什么
-
在知识库中找到最相似的问题
-
返回对应的答案
这就是QA系统的核心需求。
1.2 技术选型理由
| 组件 | 为什么选择它? |
|---|---|
| MySQL | 数据持久化存储,支持SQL查询,稳定可靠 |
| Redis | 内存级缓存,读写速度快,适合存储热点数据 |
| BM25 | 经典的信息检索算法,对文本相似度计算效果好,可解释性强 |
| jieba | 中文分词工具,支持精确模式和搜索引擎模式 |
1.3 整体数据流
数据加载
python
"""
┌─────────────────────────────────────────────────────────────────┐
│ 系统启动 │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 1. 检查Redis缓存 │
│ ├── Key: qa_original_questions(原始问题列表) │
│ └── Key: qa_tokenized_questions(分词后问题列表) │
└─────────────────────────────────────────────────────────────────┘
│
┌───────────────┴───────────────┐
│ 命中 │ 未命中
▼ ▼
┌───────────────┐ ┌───────────────┐
│ 直接从Redis │ │ 从MySQL查询 │
│ 加载数据 │ │ 所有问答对 │
└───────────────┘ └───────────────┘
│
▼
┌───────────────┐
│ jieba分词 │
│ 处理每个问题 │
└───────────────┘
│
▼
┌───────────────┐
│ 写入Redis缓存 │
└───────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 2. 构建BM25索引 │
│ └── 使用分词后的问题列表初始化 BM25Okapi 模型 │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 3. 系统就绪,等待用户查询 │
└─────────────────────────────────────────────────────────────────┘
"""
查询阶段
python
"""
┌─────────────────────────────────────────────────────────────────┐
│ 用户输入问题 Query │
│ 例如:"如何在磁盘中新建文件?" │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 步骤1:检查Redis缓存 │
│ Key: answer:{query} │
│ 例如:answer:如何在磁盘中新建文件? │
└─────────────────────────────────────────────────────────────────┘
│
┌───────────────┴───────────────┐
│ 命中 │ 未命中
▼ ▼
┌───────────────┐ ┌───────────────┐
│ 直接返回缓存 │ │ 进入BM25检索 │
│ 中的答案 │ │ 流程 │
└───────────────┘ └───────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 步骤2:Query预处理(jieba分词) │
│ 输入:"如何在磁盘中新建文件?" │
│ 输出:['如何', '磁盘', '新建', '文件'] │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 步骤3:BM25相似度计算 │
│ 调用 bm25.get_scores(query_tokens) │
│ 输出:每个问题的原始BM25分数 │
│ 例:[12.5, 8.3, 15.2, 4.1, ...] │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 步骤4:Softmax归一化 │
│ 将原始分数转换为概率分布(总和为1) │
│ 例:[0.12, 0.03, 0.82, 0.01, ...] │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 步骤5:取Top1 + 阈值判断 │
│ 最高分:0.82 │
│ 阈值:0.85 │
│ 判断:0.82 < 0.85 → 不通过 │
└─────────────────────────────────────────────────────────────────┘
│
┌───────────────┴───────────────┐
│ 通过(≥阈值) │ 不通过(<阈值)
▼ ▼
┌───────────────┐ ┌───────────────┐
│ 获取最佳匹配 │ │ 返回None │
│ 问题的索引 │ │ 需要降级处理 │
└───────────────┘ └───────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 步骤6:获取答案 │
│ 方式A(当前实现):用问题去MySQL查询答案 │
│ 方式B(优化建议):从内存问答对直接取答案 │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 步骤7:写入Redis缓存 │
│ Key: answer:{query} │
│ Value: 答案文本 │
│ 设置过期时间(可选) │
└─────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ 步骤8:返回答案给用户 │
└─────────────────────────────────────────────────────────────────┘
"""
二、项目文件结构详解
python
"""
b_traditional_qa/ # 项目根目录
│
├── base/ # 【基础层】公共基础设施
│ ├── __init__.py # Python包标识文件
│ ├── config.py # 配置管理(读取config.ini)
│ └── logger.py # 日志封装(统一日志处理)
│
├── logs/ # 【日志层】运行日志存储
│ └── app.log # 应用程序运行日志文件
│
├── mysql_qa/ # 【业务层】核心业务逻辑
│ ├── __init__.py
│ │
│ ├── cache/ # 【子模块】缓存服务
│ │ └── redis_client.py # Redis客户端封装
│ │
│ ├── db/ # 【子模块】数据库服务
│ │ └── mysql_client.py # MySQL客户端封装
│ │
│ ├── retrieval/ # 【子模块】检索服务
│ │ └── bm25_search.py # BM25检索核心逻辑
│ │
│ └── utils/ # 【子模块】工具函数
│ └── preprocess.py # 文本预处理(jieba分词)
│
├── config.ini # 【配置文件】数据库/缓存配置
├── create_database.py # 【初始化脚本】建表+导入数据
└── mysql_main.py # 【主入口】启动问答系统
"""
目录设计思想:采用分层架构,base层提供基础能力,mysql_qa层实现业务逻辑,各子模块职责单一、相互独立。
三、配置模块深度解析(base/config.py)
3.1 为什么要单独做配置模块?
在实际项目中,配置信息(数据库密码、Redis地址等)不应该硬编码在代码中。原因:
-
安全性:密码等敏感信息不应该提交到代码仓库
-
环境隔离:开发、测试、生产环境配置不同
-
可维护性:修改配置不需要改代码
3.2 完整代码逐行分析
python
# -*- coding:utf-8 -*-
# 导入配置ini文件的解析库
import configparser
# 导入路径操作库
import os
# ========== 路径解析部分 ==========
# 这里为什么要做这么多路径操作?
# 因为Python脚本可能从任何目录被执行,我们需要动态找到config.ini的位置
# 获取当前文件的绝对路径
# __file__是Python内置变量,表示当前文件路径
# 例如:/Users/xxx/project/b_traditional_qa/base/config.py
current_file_path = os.path.abspath(__file__)
# 获取当前文件所在目录的绝对路径
# 例如:/Users/xxx/project/b_traditional_qa/base
current_dir_path = os.path.dirname(current_file_path)
# 获取项目根目录的绝对路径(当前目录的父目录)
# 例如:/Users/xxx/project/b_traditional_qa
project_root = os.path.dirname(current_dir_path)
# 拼接配置文件的完整路径
# 例如:/Users/xxx/project/b_traditional_qa/config.ini
config_file_path = os.path.join(project_root, 'config.ini')
class Config():
"""配置管理类,负责读取config.ini中的配置信息"""
def __init__(self, config_file=config_file_path):
"""
初始化配置对象
设计要点:
1. config_file有默认值,调用时可以不传参
2. 使用fallback机制,配置缺失时不会崩溃
"""
# 步骤1:创建配置文件解析器
# ConfigParser是Python内置的INI文件解析器
self.config = configparser.ConfigParser()
# 步骤2:读取配置文件
# read()方法可以读取文件路径,也支持读取文件对象
self.config.read(config_file)
# ========== MySQL配置读取 ==========
# self.config.get(section, key, fallback=default)
# - section: INI文件中的段落名,如[mysql]
# - key: 段落中的键名,如host
# - fallback: 如果键不存在,返回这个默认值
self.MYSQL_HOST = self.config.get('mysql', 'host', fallback='localhost')
self.MYSQL_USER = self.config.get('mysql', 'user', fallback='root')
self.MYSQL_PASSWORD = self.config.get('mysql', 'password', fallback='123456')
self.MYSQL_DATABASE = self.config.get('mysql', 'database', fallback='subjects_kg')
# ========== Redis配置读取 ==========
# getint()方法会自动将字符串转为整数类型
# 例如:port = 6379 会被转为 int 6379
self.REDIS_HOST = self.config.get('redis', 'host', fallback='localhost')
self.REDIS_PORT = self.config.getint('redis', 'port', fallback=6379)
self.REDIS_PASSWORD = self.config.get('redis', 'password', fallback='1234')
self.REDIS_DB = self.config.getint('redis', 'db', fallback=0)
# ========== 日志配置读取 ==========
self.LOG_FILE = self.config.get('logger', 'log_file', fallback='logs/app.log')
if __name__ == '__main__':
# 测试代码:只有直接运行此文件时才执行
# 这样设计的好处:可以作为模块导入,也可以独立测试
conf = Config()
print(conf.MYSQL_HOST)
print(conf.LOG_FILE)
3.3 config.ini文件内容
python
# MySQL 配置
[mysql]
host = localhost
user = root
password = 123456
database = subjects_kg
# Redis 配置
[redis]
host = localhost
port = 6379
password = 1234
db = 0
# 日志配置
[logger]
log_file = logs/app.log
3.4 设计亮点总结
| 设计点 | 作用 | 好处 |
|---|---|---|
| 动态路径解析 | 自动找到项目根目录 | 脚本可以从任意位置执行 |
| fallback默认值 | 配置项缺失时有兜底 | 提高系统健壮性 |
| 类型转换方法 | getint()自动转换 | 减少类型错误 |
if __name__ == '__main__' |
独立测试代码 | 模块可测试、可导入 |
四、日志模块深度解析(base/logger.py)
4.1 日志的重要性
在生产环境中,日志是排查问题的唯一线索。一个好的日志系统应该具备:
-
同时输出到文件和控制台
-
格式统一、信息完整
-
不会重复记录
-
支持中文
4.2 完整代码逐行分析
python
# -*- coding:utf-8 -*-
import logging
import os
from config import Config
# ========== 路径解析 ==========
# 与config.py类似,动态获取日志文件路径
current_file_path = os.path.abspath(__file__)
current_dir_path = os.path.dirname(current_file_path)
project_root = os.path.dirname(current_dir_path)
log_file_path = os.path.join(project_root, Config().LOG_FILE)
def setup_logging(log_file=log_file_path):
"""
配置日志系统
这个函数的核心任务是:
1. 创建日志目录(如果不存在)
2. 创建logger对象
3. 添加文件处理器和控制台处理器
4. 设置统一的日志格式
"""
# ========== 步骤1:创建日志目录 ==========
# os.path.dirname(log_file) 获取日志文件所在的目录路径
# exist_ok=True 表示如果目录已存在,不报错
os.makedirs(os.path.dirname(log_file), exist_ok=True)
# ========== 步骤2:创建logger对象 ==========
# getLogger()的参数是logger的名称,可以通过名称获取同一个logger
# 同名logger在整个程序中是单例的
logger = logging.getLogger("EduRAG")
# 设置日志级别:INFO及以上级别的日志才会被记录
# 级别从低到高:DEBUG < INFO < WARNING < ERROR < CRITICAL
logger.setLevel(logging.INFO)
# ========== 步骤3:添加处理器(关键!) ==========
# handlers是logger的处理器列表
# 如果不做这个判断,每次调用setup_logging都会添加新的处理器
# 导致同一条日志被输出多次
if not logger.handlers:
# 3.1 文件处理器:将日志写入文件
# encoding='utf-8' 保证中文不乱码
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setLevel(logging.INFO)
# 3.2 控制台处理器:将日志输出到终端
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# 3.3 设置日志格式
# 格式说明:
# %(asctime)s - 时间,如:2026-01-15 10:30:45,123
# %(name)s - logger名称,如:EduRAG
# %(levelname)s - 日志级别,如:INFO、ERROR
# %(message)s - 日志消息内容
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 为两个处理器分别设置格式
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
# 将处理器添加到logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
return logger
# ========== 模块级别的logger ==========
# 在模块加载时就初始化好logger
# 其他模块可以直接:from base.logger import logger
logger = setup_logging()
4.3 日志效果展示
运行系统后,控制台输出:
2026-01-15 10:30:45,123 - EduRAG - INFO - Redis 连接成功 2026-01-15 10:30:45,456 - EduRAG - INFO - MySQL 连接成功 2026-01-15 10:30:45,789 - EduRAG - INFO - BM25 模型初始化完成
同时,这些内容也会写入logs/app.log文件,方便事后排查。
4.4 设计亮点总结
| 设计点 | 作用 | 为什么重要 |
|---|---|---|
if not logger.handlers |
防止重复添加处理器 | 避免日志重复输出 |
| 双处理器(文件+控制台) | 同时输出到文件和终端 | 开发调试看控制台,生产排查看文件 |
encoding='utf-8' |
支持中文 | 日志中出现中文时不会乱码 |
| 模块级logger实例 | 在模块加载时初始化 | 其他模块直接导入使用,无需重复配置 |
五、Redis客户端模块深度解析(mysql_qa/cache/redis_client.py)
5.1 Redis在系统中的作用
这个系统中,Redis承担了三个缓存职责:
| 缓存类型 | Key格式 | 存储内容 | 作用 |
|---|---|---|---|
| 问题列表缓存 | qa_original_questions |
所有原始问题 | 避免每次都查MySQL |
| 分词结果缓存 | qa_tokenized_questions |
分词后的问题列表 | 避免重复分词 |
| 查询结果缓存 | answer:{用户问题} |
用户问题的答案 | 相同问题直接返回 |
5.2 完整代码逐行分析
python
# -*- coding:utf-8 -*-
import redis
import json
import os
import sys
# ========== 路径处理 ==========
# 这部分代码的目的是:无论从哪里运行,都能正确导入base模块
current_dir = os.path.dirname(os.path.abspath(__file__)) # .../mysql_qa/cache
module_dir = os.path.dirname(current_dir) # .../mysql_qa
project_root = os.path.dirname(module_dir) # .../b_traditional_qa
sys.path.insert(0, project_root) # 将项目根目录加入Python路径
# 现在可以正常导入base模块了
from b_traditional_qa.base import Config, logger
class RedisClient:
"""Redis客户端封装类
这个封装类的设计目标:
1. 统一管理Redis连接
2. 封装序列化/反序列化逻辑
3. 统一异常处理和日志记录
"""
def __init__(self):
"""初始化Redis连接"""
self.logger = logger
try:
# StrictRedis是官方推荐的Redis客户端
# decode_responses=True 是关键!
# 如果不设置,返回的数据是bytes类型,需要手动decode
self.client = redis.StrictRedis(
host=Config().REDIS_HOST, # 从配置读取主机地址
port=Config().REDIS_PORT, # 从配置读取端口
password=Config().REDIS_PASSWORD, # 从配置读取密码
db=Config().REDIS_DB, # 从配置读取数据库编号
decode_responses=True # 自动解码为字符串
)
self.logger.info("Redis 连接成功")
except redis.RedisError as e:
self.logger.error(f"Redis 连接失败: {e}")
raise # 连接失败时抛出异常,让上层处理
def set_data(self, key, value):
"""存储数据到Redis
参数:
key: Redis键名
value: 要存储的值(可以是任意可JSON序列化的Python对象)
"""
try:
# json.dumps() 将Python对象转为JSON字符串
# ensure_ascii=False 保证中文不会被转义为\uXXXX格式
# 例如:{"name": "张三"} 而不是 {"name": "\u5f20\u4e09"}
self.client.set(key, json.dumps(value, ensure_ascii=False))
self.logger.info(f"存储数据到 Redis: {key}")
except redis.RedisError as e:
self.logger.error(f"Redis 存储失败: {e}")
def get_data(self, key):
"""从Redis获取数据
返回:
JSON解析后的Python对象,如果key不存在则返回None
"""
try:
data = self.client.get(key)
# 如果data存在,用json.loads解析;否则返回None
return json.loads(data) if data else None
except redis.RedisError as e:
self.logger.error(f"Redis 获取失败: {e}")
return None
def get_answer(self, query):
"""获取查询的缓存答案(专用方法)
这是针对答案缓存的特化方法,key格式固定为 answer:{query}
"""
try:
# 构建固定格式的key
answer = self.client.get(f"answer:{query}")
if answer:
self.logger.info(f"从 Redis 获取答案: {query}")
return answer
return None
except redis.RedisError as e:
self.logger.error(f"Redis 查询失败: {e}")
return None
def delete_data(self, key):
"""删除Redis中的指定key"""
try:
self.client.delete(key)
self.logger.info(f"删除 Redis 数据: {key}")
except redis.RedisError as e:
self.logger.error(f"Redis 删除失败: {e}")
if __name__ == '__main__':
# 测试代码
redcli = RedisClient()
print(redcli.get_data(key="user2"))
print(redcli.get_answer(query="我爱写代码"))
5.3 关键知识点
1. decode_responses=True 的作用
python
# 不设置 decode_responses=True
client.get("name") # 返回 b'张三' (bytes类型)
# 设置 decode_responses=True
client.get("name") # 返回 '张三' (str类型)
2. JSON序列化中的 ensure_ascii=False
python
# ensure_ascii=True (默认)
json.dumps({"name": "张三"}) # 返回 '{"name": "\\u5f20\\u4e09"}'
# ensure_ascii=False
json.dumps({"name": "张三"}, ensure_ascii=False) # 返回 '{"name": "张三"}'
5.4 为什么答案缓存使用独立方法 get_answer()?
虽然功能上可以用get_data()实现,但独立方法的好处:
-
语义更清晰:方法名直接表达意图
-
key格式统一 :所有答案缓存都使用
answer:{query}格式 -
便于扩展:未来可以在方法中添加特殊逻辑(如过期时间)
六、MySQL客户端模块深度解析(mysql_qa/db/mysql_client.py)
6.1 MySQL在系统中的角色
MySQL作为持久化存储,存储问答对数据。表结构:
sql
CREATE TABLE IF NOT EXISTS jpkb (
id INT AUTO_INCREMENT PRIMARY KEY, -- 自增主键
subject_name VARCHAR(20), -- 学科名称
question VARCHAR(1000), -- 问题(最长1000字符)
answer VARCHAR(1000) -- 答案(最长1000字符)
)
6.2 完整代码逐行分析
python
# -*- coding:utf-8 -*-
import pymysql
import pandas as pd
import sys
import os
# ========== 路径处理 ==========
current_dir = os.path.dirname(os.path.abspath(__file__)) # .../mysql_qa/db
module_dir = os.path.dirname(current_dir) # .../mysql_qa
project_root = os.path.dirname(module_dir) # .../b_traditional_qa
sys.path.insert(0, project_root)
from b_traditional_qa.base import Config, logger
class MySQLClient:
"""MySQL客户端封装类
职责:
1. 管理数据库连接
2. 提供建表、插入、查询等基础操作
3. 统一异常处理和日志
"""
def __init__(self):
"""初始化MySQL连接"""
self.logger = logger
try:
# pymysql.connect() 建立数据库连接
# 注意:这里没有设置自动提交,需要手动commit
self.connection = pymysql.connect(
host=Config().MYSQL_HOST,
user=Config().MYSQL_USER,
password=Config().MYSQL_PASSWORD,
database=Config().MYSQL_DATABASE
)
# 创建游标:用于执行SQL语句
# 默认返回元组,每行数据是一个元组
self.cursor = self.connection.cursor()
self.logger.info("MySQL 连接成功")
except pymysql.MySQLError as e:
self.logger.error(f"MySQL 连接失败: {e}")
raise
def create_table(self):
"""创建数据表(如果不存在)"""
create_table_query = '''
CREATE TABLE IF NOT EXISTS jpkb (
id INT AUTO_INCREMENT PRIMARY KEY,
subject_name VARCHAR(20),
question VARCHAR(1000),
answer VARCHAR(1000)
)
'''
try:
self.cursor.execute(create_table_query)
self.connection.commit() # 提交事务
self.logger.info("表创建成功")
except pymysql.MySQLError as e:
self.logger.error(f"表创建失败: {e}")
raise
def insert_data(self, csv_path):
"""从CSV文件批量插入数据
参数:
csv_path: CSV文件路径,期望包含三列:学科名称、问题、答案
"""
try:
# 使用pandas读取CSV文件
data = pd.read_csv(csv_path)
print(data.head()) # 打印前5行,便于确认数据正确性
# iterrows() 遍历DataFrame的每一行
# _ 是行索引,row是行数据(Series对象)
for _, row in data.iterrows():
insert_query = "INSERT INTO jpkb (subject_name, question, answer) VALUES (%s, %s, %s)"
# 参数化查询:用%s作为占位符,防止SQL注入
self.cursor.execute(insert_query, (row["学科名称"], row["问题"], row["答案"]))
self.connection.commit()
self.logger.info("Mysql数据插入成功")
except Exception as e:
self.logger.error(f'Mysql数据插入失败:{e}')
# rollback() 回滚事务,取消当前事务的所有操作
# 让数据库回到执行这些操作之前的状态
self.connection.rollback()
raise
def fetch_questions(self):
"""获取所有问题
返回:
问题列表,每个元素是包含问题的元组,如 [('问题1',), ('问题2',), ...]
"""
try:
self.cursor.execute("SELECT question FROM jpkb")
results = self.cursor.fetchall() # 获取所有结果
self.logger.info("成功获取问题")
return results
except pymysql.MySQLError as e:
self.logger.error(f"查询失败: {e}")
return [] # 查询失败返回空列表
def fetch_answer(self, question):
"""获取指定问题的答案
参数:
question: 问题文本
返回:
答案文本,如果不存在则返回None
"""
try:
# 参数化查询:用%s占位符,第二个参数是元组
self.cursor.execute("SELECT answer FROM jpkb WHERE question=%s", (question,))
result = self.cursor.fetchone() # 获取第一条结果
print(f'result--》{result}')
# 如果result存在,返回第一个元素(答案);否则返回None
return result[0] if result else None
except pymysql.MySQLError as e:
self.logger.error(f"答案获取失败: {e}")
return None
def close(self):
"""关闭数据库连接"""
try:
self.connection.close()
self.logger.info("MySQL 连接已关闭")
except pymysql.MySQLError as e:
self.logger.error(f"关闭连接失败: {e}")
def drop_table(self):
"""删除数据表(危险操作,谨慎使用)"""
drop_table_query = '''drop table jpkb;'''
try:
self.cursor.execute(drop_table_query)
self.connection.commit()
self.logger.info("表删除成功")
except pymysql.MySQLError as e:
self.logger.error(f"表删除失败: {e}")
raise
if __name__ == '__main__':
mysql_client = MySQLClient()
# mysql_client.create_table()
# mysql_client.insert_data(csv_path='../../../data/JP学科知识问答.csv')
# results = mysql_client.fetch_questions()
# print(f'results--》{results}')
# a = mysql_client.fetch_answer(question="在磁盘中无法新建文本文档")
# print(f'a--》{a}')
mysql_client.close()
6.3 关键知识点
1. 为什么使用参数化查询?
python
# 危险写法:字符串拼接
cursor.execute(f"SELECT answer FROM jpkb WHERE question='{question}'")
# 如果question = "'; DROP TABLE jpkb; --",就会发生SQL注入攻击
# 安全写法:参数化查询
cursor.execute("SELECT answer FROM jpkb WHERE question=%s", (question,))
# pymysql会自动转义特殊字符,防止SQL注入
2. fetchone() vs fetchall()
| 方法 | 返回值 | 适用场景 |
|---|---|---|
fetchone() |
单个元组 | 预期只有一条结果 |
fetchall() |
元组列表 | 可能有多个结果 |
3. 事务管理
python
# 执行修改操作后需要commit
cursor.execute("INSERT ...")
connection.commit() # 提交事务
# 出错时需要rollback
try:
# 操作...
except Exception:
connection.rollback() # 回滚事务
七、文本预处理模块(mysql_qa/utils/preprocess.py)
7.1 为什么需要中文分词?
BM25算法是基于词的算法,需要将文本切分成词(token)。英文有天然的空格分隔,但中文需要专门的分词工具。
例如:"我喜欢编程" → ['我', '喜欢', '编程']
7.2 完整代码分析
python
# -*- coding:utf-8 -*-
import jieba # 中文分词库
import os
import sys
# 路径处理,以便导入base模块
current_dir = os.path.dirname(os.path.abspath(__file__))
module_dir = os.path.dirname(current_dir)
project_root = os.path.dirname(module_dir)
sys.path.insert(0, project_root)
from b_traditional_qa.base import logger
def preprocess_text(text):
"""
预处理文本:进行中文分词
参数:
text: 原始字符串,如 "程序员是什么"
返回:
分词后的列表,如 [ '程序员', '是', '什么']
注意:
这里将文本转为小写(.lower()),对中文没有影响
但可以兼容英文提问的情况
"""
logger.info("开始预处理文本")
try:
# jieba.lcut() 返回列表,l表示list
# .lower() 将英文转为小写
return jieba.lcut(text.lower())
except Exception as e:
logger.error(f"文本预处理失败: {e}", exc_info=True) # exc_info=True 打印堆栈信息
return [] # 预处理失败返回空列表
if __name__ == '__main__':
print(preprocess_text(text="我是程序员"))
# 输出:['程序员']
7.3 jieba分词的三种模式
| 模式 | 方法 | 示例 | 特点 |
|---|---|---|---|
| 精确模式 | jieba.lcut(text) |
['我', '喜欢', '编程'] |
最常用,适合文本分析 |
| 全模式 | jieba.lcut(text, cut_all=True) |
['我', '喜欢', '编程', '程'] |
输出所有可能的词 |
| 搜索引擎模式 | jieba.lcut_for_search(text) |
['我', '喜欢', '编程'] |
在精确模式基础上增加长词切分 |
八、BM25检索核心模块(mysql_qa/retrieval/bm25_search.py)
8.1 BM25算法原理
BM25(Okapi BM25)是一种用于信息检索的排序函数,用于评估文档与查询的相关性。
核心思想:
-
词频(TF):查询词在文档中出现的次数越多,相关性越高
-
逆文档频率(IDF):查询词在所有文档中出现越少,越能区分文档
-
文档长度归一化:长文档不应该因为词多而获得不公平的高分
8.2 完整代码分析
python
# -*- coding:utf-8 -*-
from rank_bm25 import BM25Okapi
import numpy as np
import sys
import os
# 路径处理
current_dir = os.path.dirname(os.path.abspath(__file__))
module_dir = os.path.dirname(current_dir)
sys.path.insert(0, module_dir)
project_root = os.path.dirname(module_dir)
sys.path.insert(0, project_root)
from b_traditional_qa.mysql_qa.utils.preprocess import preprocess_text
from b_traditional_qa.mysql_qa.db.mysql_client import MySQLClient
from b_traditional_qa.mysql_qa.cache.redis_client import RedisClient
from b_traditional_qa.base import logger
class BM25Search:
"""BM25检索类
核心职责:
1. 加载问答数据,构建BM25索引
2. 使用Redis缓存数据,避免重复加载
3. 接收用户查询,返回最匹配的答案
"""
def __init__(self, redis_client, mysql_client):
"""初始化BM25检索器
参数:
redis_client: Redis客户端实例
mysql_client: MySQL客户端实例
"""
self.logger = logger
self.redis_client = redis_client
self.mysql_client = mysql_client
self.bm25 = None
self.questions = None
self.original_questions = None
self._load_data() # 初始化时立即加载数据
def _load_data(self):
"""加载数据:优先从Redis读取,没有则从MySQL读取
数据加载流程:
1. 尝试从Redis获取原始问题列表和分词后的问题列表
2. 如果Redis中没有,从MySQL读取所有问题
3. 对问题进行jieba分词
4. 将原始问题和分词结果存入Redis(供下次使用)
5. 构建BM25索引
"""
original_key = "qa_original_questions"
tokenized_key = "qa_tokenized_questions"
# 步骤1:尝试从Redis获取
self.original_questions = self.redis_client.get_data(original_key)
tokenized_questions = self.redis_client.get_data(tokenized_key)
# 步骤2:如果Redis中没有,从MySQL加载
if not self.original_questions or not tokenized_questions:
# 从MySQL获取所有问题
self.original_questions = self.mysql_client.fetch_questions()
if not self.original_questions:
self.logger.warning("未加载问题")
return
# 对每个问题进行分词
# 注意:fetch_questions()返回的是元组列表,每个元组第一个元素是问题文本
tokenized_questions = [preprocess_text(q[0]) for q in self.original_questions]
# 步骤3:存入Redis缓存
self.redis_client.set_data(original_key, [(q[0]) for q in self.original_questions])
self.redis_client.set_data(tokenized_key, tokenized_questions)
# 步骤4:设置问题列表并构建BM25索引
self.questions = tokenized_questions
self.bm25 = BM25Okapi(self.questions)
self.logger.info("BM25 模型初始化完成")
def _softmax(self, scores):
"""Softmax归一化函数
作用:将任意实数向量转换为概率分布(所有值在0-1之间,总和为1)
参数:
scores: 原始分数列表,如 [2.5, 1.3, 0.8]
返回:
归一化后的概率分布,如 [0.65, 0.25, 0.10]
为什么要减去最大值?
防止数值过大导致exp溢出(数值稳定性)
"""
exp_scores = np.exp(scores - np.max(scores))
return exp_scores / exp_scores.sum()
def search(self, query, threshold=0.85):
"""执行查询搜索
参数:
query: 用户输入的问题字符串
threshold: 相似度阈值,默认0.85(85%)
返回:
(answer, need_fallback) 元组
- answer: 找到的答案,没有则返回None
- need_fallback: 是否需要降级处理(True表示需要调用其他系统)
"""
# ========== 步骤1:参数校验 ==========
if not query or not isinstance(query, str):
self.logger.error("无效查询")
return None, False
# ========== 步骤2:检查Redis缓存 ==========
# 如果同一个问题被问过,直接从缓存返回
cached_answer = self.redis_client.get_answer(query)
if cached_answer:
return cached_answer, False
try:
# ========== 步骤3:对查询进行分词 ==========
query_tokens = preprocess_text(query)
# ========== 步骤4:BM25计算分数 ==========
# get_scores() 返回每个文档与查询的BM25分数
scores = self.bm25.get_scores(query_tokens)
# ========== 步骤5:Softmax归一化 ==========
softmax_score = self._softmax(scores)
# ========== 步骤6:获取最佳匹配 ==========
best_idx = softmax_score.argmax() # 最高分对应的索引
best_score = softmax_score[best_idx] # 最高分
# ========== 步骤7:阈值判断 ==========
if best_score >= threshold:
# 获取原始问题文本
original_question = self.original_questions[best_idx]
# 从MySQL获取答案
answer = self.mysql_client.fetch_answer(original_question)
if answer:
# 将结果存入Redis缓存
self.redis_client.set_data(f'answer:{query}', answer)
self.logger.info(f'搜索成功,Softmax相似度:{best_score:.3f}')
return answer, False
# 没有找到可靠答案
self.logger.info(f"未找到可靠答案,最高 Softmax 相似度: {best_score:.3f}")
return None, True
except Exception as e:
self.logger.error(f'搜索查询失败:{e}')
return None, True
if __name__ == "__main__":
redis_client = RedisClient()
mysql_client = MySQLClient()
bm25_search = BM25Search(redis_client, mysql_client)
8.3 Softmax的作用
假设BM25返回的原始分数为:[15.2, 12.8, 10.5, 3.2]
这些分数难以解释,也不方便设定阈值。经过Softmax归一化后:
原始分数: [15.2, 12.8, 10.5, 3.2]
Softmax后: [0.85, 0.10, 0.04, 0.01]
现在可以直观理解:最匹配的问题占85%的"概率",超过了阈值0.85,所以返回答案。
8.4 为什么使用BM25而不是其他算法?
| 算法 | 优点 | 缺点 |
|---|---|---|
| BM25 | 效果好、可解释、速度快 | 需要分词、不考虑语义 |
| TF-IDF | 简单快速 | 效果不如BM25 |
| 余弦相似度 | 适合向量化文本 | 需要构建词向量 |
| 语义模型(BERT) | 理解语义 | 计算量大、需要GPU |
对于中小规模问答系统,BM25是性价比最高的选择。
九、系统主入口(mysql_main.py)
python
# -*- coding:utf-8 -*-
from b_traditional_qa.mysql_qa.db.mysql_client import MySQLClient
from b_traditional_qa.mysql_qa.cache.redis_client import RedisClient
from b_traditional_qa.mysql_qa.retrieval.bm25_search import BM25Search
from base import logger
import time
class MySQLQASystem:
"""MySQL问答系统主类
整合所有模块,提供统一的问答接口
"""
def __init__(self):
"""初始化系统:创建所有客户端和检索器"""
self.logger = logger
self.mysql_client = MySQLClient()
self.redis_client = RedisClient()
self.bm25_search = BM25Search(self.redis_client, self.mysql_client)
def query(self, query):
"""处理用户查询
参数:
query: 用户输入的问题
返回:
答案字符串
"""
start_time = time.time()
self.logger.info(f"处理查询: '{query}'")
# 执行搜索
answer, need_fallback = self.bm25_search.search(query, threshold=0.85)
if answer:
self.logger.info(f"MySQL 答案: {answer}")
else:
self.logger.info("SQL中未找到答案, 需要调用RAG系统")
answer = "SQL未找到答案"
processing_time = time.time() - start_time
self.logger.info(f"查询处理耗时 {processing_time:.2f}秒")
return answer
def main():
"""主函数:启动交互式问答系统"""
mysql_system = MySQLQASystem()
try:
print("\n欢迎使用 MySQL 问答系统!")
print("输入查询进行问答,输入 'exit' 退出。")
while True:
query = input("\n输入查询: ").strip()
if query.lower() == "exit":
logger.info("退出 MySQL 系统")
print("再见!")
break
answer = mysql_system.query(query)
print(f"\n答案: {answer}")
except Exception as e:
logger.error(f"系统错误: {e}")
print(f"发生错误: {e}")
finally:
mysql_system.mysql_client.close()
if __name__ == "__main__":
main()
9.1 运行效果示例
欢迎使用 MySQL 问答系统!
输入查询进行问答,输入 'exit' 退出。
输入查询: 如何在磁盘中新建文本文档?
答案: 右键点击空白处,选择"新建"→"文本文档"
输入查询: 硬盘无法新建文件怎么办?
答案: 检查磁盘格式是否为NTFS,如果不是需要格式化成NTFS格式...
输入查询: exit
再见!
十、总结与优化建议
10.1 系统优点
| 方面 | 说明 |
|---|---|
| 架构清晰 | 分层设计,职责分明 |
| 缓存完善 | Redis缓存问题列表、分词结果、查询答案 |
| 日志健全 | 文件+控制台双输出,便于排查 |
| 安全可靠 | 参数化查询防SQL注入,异常处理完善 |
10.2 可优化方向
| 问题 | 当前实现 | 优化建议 |
|---|---|---|
| 内存存储 | 只存问题,答案仍需查MySQL | 同时缓存问答对,匹配成功直接取答案 |
| 检索效率 | 每次全量计算BM25分数 O(N) | 使用倒排索引或向量检索(FAISS) |
| 阈值固定 | 固定0.85 | 支持动态阈值或返回Top-K |
| 数据库索引 | question字段无索引 | 添加索引加速查询 |
10.3 扩展思路
-
多路召回:结合BM25和向量检索,取并集后重排序
-
LLM增强:将检索结果作为上下文,用大模型生成更自然的回答
-
意图识别:先分类用户意图,再选择不同的检索策略
以上就是整个系统的完整解析。希望这篇文章能帮助你理解如何从零构建一个实用的问答系统。如有问题,欢迎在评论区交流讨论!