RAG可以从大量文档中筛选出相关文档,解决LLM上下文窗口有限的问题。
然而,由于语义相似特性,在召回相关文档同时,也同时会带入不相关或仅边缘相关的文档。
这里尝试基于LLM的理解能力,示例如何辅助RAG过滤掉不相关或仅边缘相关的文档。
所用实验方案和代码示例,均参考和修改自网络资料。
1 数据和环境
1.1 数据准备
这里用LLM模拟生成400多条文档测试数据,单条记录示例如下
id: 438
file_name: "空压机-信然-维护记录-20250620.pdf"
source: "设备维护记录"
content: "空压机(EQ-5678901234)2025-06-20维护:更换油细分离器,检查安全阀,清洁进气阀。"
以下是数据测试文件链接
https://github.com/LeeXYZABC/llm_learning/blob/main/doc_bench.json
1.2 json解析
未更好处理LLM输出,一般会要求LLM输出json。
jsonfinder是专门提取字符串中的 JSON的第三方库。
安装实例如下,这里使用jsonfinder从llm输出中解析json。
pip install jsonfinder -i https://pypi.tuna.tsinghua.edu.cn/simple
参考链接如下
https://blog.csdn.net/liliang199/article/details/160257030
1.3 LLM配置
这里采用openai兼容方式访问LLM,一般情况需要提供api_key、base_url、model_name等。
可以是本地Ollama、vLLM部署的大模型,也可以是由Deepseek、qwen、kimi等提供的大模型。
api_key = "sk-xxxxxx"
base_url = "https://llm_provier.com/xxx/v1"
model_name = "qwen3-xxxx"
2 RAG准备
这里初始化文档RAG库,使用使用Sqlite3存储和管理文档的元数据,Faiss存储文档的向量数据。
2.1 Sqlite3元数据
Sqlite3存储和管理文档元数据,核心数据梳理如下
id: 文档标识
title:文档标题
content: 文档内容概要
source:文档来源
代码示例如下。
import os
import json
import sqlite3
import hashlib
from datetime import datetime
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss
# ============================================================================
# 第一部分:数据模型与基础设施
# ============================================================================
@dataclass
class Document:
"""知识库文档模型"""
id: str
title: str
content: str
source: str
created_at: str
embedding: Optional[List[float]] = None
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
class SQLiteStore:
"""SQLite 数据存储层 - 管理文档元数据和问答记录"""
def __init__(self, db_path: str = "knowledge_base.db"):
self.db_path = db_path
self._init_db()
def _init_db(self):
"""初始化数据库表结构"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
# 文档表
cursor.execute("""
CREATE TABLE IF NOT EXISTS documents (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
content TEXT NOT NULL,
source TEXT,
created_at TEXT NOT NULL,
embedding_hash TEXT
)
""")
# 问答记录表 - 用于追踪 LLM 调用历史
cursor.execute("""
CREATE TABLE IF NOT EXISTS qa_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
question TEXT NOT NULL,
answer TEXT NOT NULL,
retrieved_docs TEXT,
created_at TEXT NOT NULL,
latency_ms INTEGER
)
""")
conn.commit()
def insert_document(self, doc: Document) -> bool:
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
INSERT OR REPLACE INTO documents
(id, title, content, source, created_at, embedding_hash)
VALUES (?, ?, ?, ?, ?, ?)
""", (doc.id, doc.title, doc.content, doc.source,
doc.created_at, self._hash_embedding(doc.embedding)))
conn.commit()
return True
except Exception as e:
print(f"Failed to insert document: {e}")
return False
def get_document(self, doc_id: str) -> Optional[Dict]:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT id, title, content, source, created_at FROM documents WHERE id = ?",
(doc_id,))
row = cursor.fetchone()
if row:
return {
"id": row[0], "title": row[1], "content": row[2],
"source": row[3], "created_at": row[4]
}
return None
def insert_qa_record(self, question: str, answer: str,
retrieved_docs: List[str], latency_ms: int) -> int:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
INSERT INTO qa_records (question, answer, retrieved_docs, created_at, latency_ms)
VALUES (?, ?, ?, ?, ?)
""", (question, answer, json.dumps(retrieved_docs),
datetime.now().isoformat(), latency_ms))
conn.commit()
return cursor.lastrowid
def get_recent_qa(self, limit: int = 5) -> List[Dict]:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT question, answer, retrieved_docs, created_at, latency_ms
FROM qa_records ORDER BY created_at DESC LIMIT ?
""", (limit,))
rows = cursor.fetchall()
return [{
"question": r[0], "answer": r[1], "retrieved_docs": json.loads(r[2]),
"created_at": r[3], "latency_ms": r[4]
} for r in rows]
def get_stats(self) -> Dict[str, Any]:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM documents")
doc_count = cursor.fetchone()[0]
cursor.execute("SELECT COUNT(*) FROM qa_records")
qa_count = cursor.fetchone()[0]
cursor.execute("SELECT AVG(latency_ms) FROM qa_records")
avg_latency = cursor.fetchone()[0] or 0
return {
"total_documents": doc_count,
"total_qa_records": qa_count,
"avg_latency_ms": round(avg_latency, 2)
}
@staticmethod
def _hash_embedding(emb: Optional[List[float]]) -> str:
if emb is None:
return ""
return hashlib.md5(np.array(emb).tobytes()).hexdigest()[:16]
2.2 Faiss向量库
这里使用Faiss存储文档的向量数据,主要对文档向量进行入库、查询。
代码示例如下
class FAISSIndex:
"""FAISS 向量索引 - 语义检索核心"""
def __init__(self, dimension: int = 768, index_path: str = "kb_index.faiss"):
self.dimension = dimension
self.index_path = index_path
self.index: Optional[faiss.IndexFlatIP] = None
self.doc_ids: List[str] = []
self._load_or_create()
def _load_or_create(self):
if os.path.exists(self.index_path):
self.index = faiss.read_index(self.index_path)
# 尝试加载对应的 ID 映射文件
id_map_path = self.index_path.replace(".faiss", "_ids.json")
if os.path.exists(id_map_path):
with open(id_map_path, "r") as f:
self.doc_ids = json.load(f)
else:
self.index = faiss.IndexFlatIP(self.dimension)
def add(self, doc_id: str, embedding: List[float]):
if self.index is None:
self.index = faiss.IndexFlatIP(self.dimension)
vec = np.array(embedding, dtype=np.float32).reshape(1, -1)
faiss.normalize_L2(vec)
self.index.add(vec)
self.doc_ids.append(doc_id)
def search(self, query_embedding: List[float], k: int = 5) -> List[tuple]:
if self.index is None or self.index.ntotal == 0:
return []
vec = np.array(query_embedding, dtype=np.float32).reshape(1, -1)
faiss.normalize_L2(vec)
distances, indices = self.index.search(vec, min(k, self.index.ntotal))
results = []
for dist, idx in zip(distances[0], indices[0]):
if idx >= 0 and idx < len(self.doc_ids):
results.append((self.doc_ids[idx], float(dist)))
return results
def save(self):
if self.index:
faiss.write_index(self.index, self.index_path)
id_map_path = self.index_path.replace(".faiss", "_ids.json")
with open(id_map_path, "w") as f:
json.dump(self.doc_ids, f)
2.3 向量模型
仅有Faiss向量库不能支持文档的入库和查询,换需要将文档向量化。
这里采用本地部署的Ollama模型,生成文档向量。
参考链接如下
https://blog.csdn.net/liliang199/article/details/152125310
代码示例如下。
import openai
from openai import OpenAI
client = OpenAI(
base_url='http://localhost:11434/v1/',
api_key='ollama', # API key 无所谓,任意填写即可
)
class EmbeddingService:
"""嵌入服务 - 使用Sentence-Transformer调用真实 LLM 生成向量"""
def __init__(self, model: str = "qllama/bce-embedding-base_v1:latest"):
self.model = model
def generate(self, text: str) -> List[float]:
"""
生成文本嵌入向量。
"""
try:
# 调用嵌入模型
embedding_response = client.embeddings.create(
model=self.model, # 这是你在 Ollama 里下载好的模型名
input=text
)
return embedding_response.data[0].embedding
except Exception as e:
print(f"Embedding generation failed: {e}")
# 降级:返回随机向量(仅用于演示,生产环境应使用本地模型)
return np.random.randn(768)
2.4 文档导入
然后就是文档数据的导入。
这里仅对文档的filename和content进行联合向量化,文档导入示例代码如下。
def add_document(title: str, content: str, source: str = "manual") -> str:
"""
向知识库添加新文档。
Args:
title: 文档标题
content: 文档内容
source: 文档来源(如 URL、文件名等)
Returns:
添加结果
"""
doc_id = hashlib.md5(f"{title}:{source}".encode()).hexdigest()[:16]
# 生成嵌入向量
embedding = emb_service.generate(title + " " + content)
doc = Document(
id=doc_id,
title=title,
content=content,
source=source,
created_at=datetime.now().isoformat(),
embedding=embedding
)
# 保存到 SQLite
if not db.insert_document(doc):
return f"❌ 文档保存失败"
# 添加到 FAISS 索引
faiss_index.add(doc_id, embedding)
faiss_index.save()
return f"✅ 文档已成功添加到知识库,ID: {doc_id}"
2.5 RAG初始化
这里初始化Sqlite数据库、Faiss向量库、以及向量Embedding模型,示例代码如下。
db = SQLiteStore(db_path="simulated_doc_v1.db")
faiss_index = FAISSIndex(dimension=768, index_path="simulated_doc_v1.faiss")
model_embedding="qllama/bce-embedding-base_v1:latest"
emb_service = EmbeddingService(model=model_embedding)
embedding = emb_service.generate("hello!")
print((embedding[:10]))
然后就是文档测试苏剧的导入,先load测试数据,然后通过add_document导入文档。
测试数据链接来自
https://github.com/LeeXYZABC/llm_learning/blob/main/doc_bench.json
示例代码如下。
import json
with open("./doc_bench.json") as f:
doc_list = json.load(f)
doc_list.sort(key=lambda a:a["file_name"])
print(len(doc_list))
print(f"doc_list: {len(doc_list)}")
for doc in doc_list:
filename = doc["file_name"]
content = doc["content"]
source = doc["source"]
result = add_document(filename, content, source)
3 LLM辅助
这里使用LLM优化RAG基础检索结果,使得系统仅返回最相关文档,过滤不太相关文档。
3.1 基础RAG检索
这里首先模拟基础RAG检索,
即为在知识库进行语义检索,返回最相关的文档内容。
示例代码如下
def search_knowledge_base(query: str, top_k: int = 3) -> str:
"""
在知识库中进行语义检索,返回最相关的文档内容。
Args:
query: 用户查询的自然语言问题
top_k: 返回的文档数量,默认 3 条,最大 10 条
Returns:
检索到的文档内容,以结构化格式呈现
"""
# 生成查询向量
query_emb = emb_service.generate(query)
# FAISS 向量检索
results = faiss_index.search(query_emb, k=min(top_k, 500))
if not results:
return []
# 获取文档详情
records = []
for i, (doc_id, score) in enumerate(results, 1):
doc = db.get_document(doc_id)
records.append(doc.copy())
return records
然后运行实际检索
results = search_knowledge_base("成都安检报告", top_k=10)
if results:
print("search_knowledge_base: {len(results)}")
for i, result in enumerate(results):
print(f" 结果{i}: {result}")
输出如下,可见除返回两个最相关的文档1和2外,还返回了8个不太相关的文档。
这些不太相关的文档,不仅会消耗token,还有可能影响系统的最终输出。
search_knowledge_base: 10
结果0: {'id': '60a15e0bc908d5de', 'title': '成都销售用房装修改造安检-20250604.docx', 'content': '成都销售用房装修改造项目:临边防护缺失,吊篮配重不足,责令停工整改。整改后复查通过。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:57.733501'}
结果1: {'id': 'f93ce4bc330d598f', 'title': '成都超算中心项目安检-20250928.pdf', 'content': '成都超算中心项目:液冷系统管道压力测试合格,防静电地板接地可靠,机房洁净度达标。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:57.506278'}
结果2: {'id': '590473feb0037213', 'title': '天津新能源厂房建设安检-20251101.docx', 'content': '天津新能源厂房建设:高处作业人员安全带佩戴不规范,安全通道堵塞,已现场纠正并处罚。其余部分合格。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:44.523279'}
结果3: {'id': 'ee4d69d2c462cf36', 'title': '重庆两江新区工厂安检-20251015.pdf', 'content': '重庆两江新区工厂:安全出口指示灯常亮,应急疏散图张贴齐全,员工安全培训记录完整。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:48.252167'}
结果4: {'id': 'c5c5a209ba386dc9', 'title': '重庆长寿化工项目安检-20251101.pdf', 'content': '重庆长寿化工项目:反应釜夹套蒸汽压力正常,紧急冷却系统测试通过,有毒气体报警联动正常。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:48.696135'}
结果5: {'id': '833ad8e64ec5514d', 'title': '泰州医药城项目安检-20251020.pdf', 'content': '泰州医药城项目:洁净车间压差符合规范,FFU运行正常,人员净化程序正确执行。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:15.916728'}
结果6: {'id': 'f4f6c9ec7810787d', 'title': '泰州医药项目安检-20250918.pdf', 'content': '泰州医药项目:高活性药物生产区负压保持良好,BSC生物安全柜检测合格,废弃物处理合规。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:16.402014'}
结果7: {'id': 'b67e07e90448ec6d', 'title': '义乌国际商贸城改造安检-20250705.pdf', 'content': '义乌国际商贸城改造:施工区域围挡完好,警示标识明显。高空作业人员系安全带。合格。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:23.629302'}
结果8: {'id': '8779a78e397e4daa', 'title': '重庆数据中心建设项目安检-20250915.pdf', 'content': '重庆数据中心建设项目:消防系统测试合格,气体灭火钢瓶压力正常。电缆桥架接地可靠,未发现安全隐患。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:48.477538'}
结果9: {'id': 'befa795f92ad5e6c', 'title': '郑州航空港项目安检-20251012.pdf', 'content': '郑州航空港项目:物流分拣中心消防通道畅通,自动卷帘门测试正常,应急广播系统清晰。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:45.976177'}
3.2 LLM辅助检索
这里我们进一步引入LLM,对RAG基础检索结果进行处理,仅保留最相关的文档。
首先是openai格式配置聊天LLM,示例代码如下。
chat_client = OpenAI(api_key=api_key, base_url=base_url)
然后基于LLM提示词构建过滤器。
提示词示例和过滤代码如下。
sys_prompt = """
你是一位专业的知识库检索专家,能从候选记录中挑出真正契合和回答问题的候选项。
请将候选记录组织成list原封不动的输出。
"""
user_tp = """
用户问题:
{}
---
候选记录:
{}
"""
from jsonfinder import jsonfinder
def json_parse(text):
# 使用 jsonfinder 迭代提取 JSON 对象
# 每次迭代会返回 (start_index, end_index, parsed_json_object)
for start, end, obj in jsonfinder(text):
# obj 是已经用 json.loads() 解析好的 Python 对象(dict 或 list)
if obj:
# print(text[start:end])
# print(f"找到了 JSON:{obj}")
return obj
return None
def llm_records_filter(question, records):
system_msg = {"role": "system", "content": sys_prompt}
user_prompt = user_tp.format(
question,
json.dumps(records, ensure_ascii=False)
)
user_msg = {"role": "user", "content": user_prompt}
# print(f"llm filter, q={question}, searched {len(records)} records!\ndetails={records}")
resp = chat_client.chat.completions.create(
model=model_name,
messages=[system_msg, user_msg],
temperature=0.7,
extra_body={"enable_thinking": False}
)
content = resp.choices[0].message.content
data = json_parse(content)
return data, content
def search_knowledge_llm(query: str, top_k: int = 3) -> str:
records = search_knowledge_base(query, top_k)
json_data, json_txt = llm_records_filter(query, records)
return json_data
检索代码示例如下
results = search_knowledge_llm("成都安检报告", top_k=10)
if results:
print(f"search_knowledge_llm: {len(results)}")
for i, result in enumerate(results):
print(f" 结果{i}: {result}")
输出示例如下,LLM精确过滤掉不太相关的文档,仅保留了最相关的文档。
search_knowledge_llm: 2
结果0: {'id': '60a15e0bc908d5de', 'title': '成都销售用房装修改造安检 -20250604.docx', 'content': '成都销售用房装修改造项目:临边防护缺失,吊篮配重不足,责令停工整改。整改后复查通过。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:57.733501'}
结果1: {'id': 'f93ce4bc330d598f', 'title': '成都超算中心项目安检 -20250928.pdf', 'content': '成都超算中心项目:液冷系统管道压力测试合格,防静电地板接地可靠,机房洁净度达标。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:57.506278'}
reference
bce-embedding-base_v1
https://hf-mirror.com/maidalun1020/bce-embedding-base_v1
从长字符串中解析合法json结构的示例
https://blog.csdn.net/liliang199/article/details/160257030
docker ollama部署轻量级嵌入模型 - EmbeddingGemma