LLM如何辅助RAG从大量文档中筛选目标文档

RAG可以从大量文档中筛选出相关文档,解决LLM上下文窗口有限的问题。

然而,由于语义相似特性,在召回相关文档同时,也同时会带入不相关或仅边缘相关的文档。

这里尝试基于LLM的理解能力,示例如何辅助RAG过滤掉不相关或仅边缘相关的文档。

所用实验方案和代码示例,均参考和修改自网络资料。

1 数据和环境

1.1 数据准备

这里用LLM模拟生成400多条文档测试数据,单条记录示例如下

id: 438

file_name: "空压机-信然-维护记录-20250620.pdf"

source: "设备维护记录"

content: "空压机(EQ-5678901234)2025-06-20维护:更换油细分离器,检查安全阀,清洁进气阀。"

以下是数据测试文件链接

https://github.com/LeeXYZABC/llm_learning/blob/main/doc_bench.json

1.2 json解析

未更好处理LLM输出,一般会要求LLM输出json。

jsonfinder是专门提取字符串中的 JSON的第三方库。

安装实例如下,这里使用jsonfinder从llm输出中解析json。

pip install jsonfinder -i https://pypi.tuna.tsinghua.edu.cn/simple

参考链接如下

https://blog.csdn.net/liliang199/article/details/160257030

1.3 LLM配置

这里采用openai兼容方式访问LLM,一般情况需要提供api_key、base_url、model_name等。

可以是本地Ollama、vLLM部署的大模型,也可以是由Deepseek、qwen、kimi等提供的大模型。

复制代码
api_key =  "sk-xxxxxx"
base_url = "https://llm_provier.com/xxx/v1"
model_name = "qwen3-xxxx"

2 RAG准备

这里初始化文档RAG库,使用使用Sqlite3存储和管理文档的元数据,Faiss存储文档的向量数据。

2.1 Sqlite3元数据

Sqlite3存储和管理文档元数据,核心数据梳理如下

id: 文档标识

title:文档标题

content: 文档内容概要

source:文档来源

代码示例如下。

复制代码
import os
import json
import sqlite3
import hashlib
from datetime import datetime
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict
from sentence_transformers import SentenceTransformer

import numpy as np
import faiss


# ============================================================================
# 第一部分:数据模型与基础设施
# ============================================================================

@dataclass
class Document:
    """知识库文档模型"""
    id: str
    title: str
    content: str
    source: str
    created_at: str
    embedding: Optional[List[float]] = None

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


class SQLiteStore:
    """SQLite 数据存储层 - 管理文档元数据和问答记录"""

    def __init__(self, db_path: str = "knowledge_base.db"):
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """初始化数据库表结构"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            # 文档表
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS documents (
                    id TEXT PRIMARY KEY,
                    title TEXT NOT NULL,
                    content TEXT NOT NULL,
                    source TEXT,
                    created_at TEXT NOT NULL,
                    embedding_hash TEXT
                )
            """)

            # 问答记录表 - 用于追踪 LLM 调用历史
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS qa_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    question TEXT NOT NULL,
                    answer TEXT NOT NULL,
                    retrieved_docs TEXT,
                    created_at TEXT NOT NULL,
                    latency_ms INTEGER
                )
            """)

            conn.commit()

    def insert_document(self, doc: Document) -> bool:
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT OR REPLACE INTO documents
                    (id, title, content, source, created_at, embedding_hash)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, (doc.id, doc.title, doc.content, doc.source,
                      doc.created_at, self._hash_embedding(doc.embedding)))
                conn.commit()
            return True
        except Exception as e:
            print(f"Failed to insert document: {e}")
            return False

    def get_document(self, doc_id: str) -> Optional[Dict]:
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT id, title, content, source, created_at FROM documents WHERE id = ?",
                (doc_id,))
            row = cursor.fetchone()
            if row:
                return {
                    "id": row[0], "title": row[1], "content": row[2],
                    "source": row[3], "created_at": row[4]
                }
        return None

    def insert_qa_record(self, question: str, answer: str,
                         retrieved_docs: List[str], latency_ms: int) -> int:
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO qa_records (question, answer, retrieved_docs, created_at, latency_ms)
                VALUES (?, ?, ?, ?, ?)
            """, (question, answer, json.dumps(retrieved_docs),
                  datetime.now().isoformat(), latency_ms))
            conn.commit()
            return cursor.lastrowid

    def get_recent_qa(self, limit: int = 5) -> List[Dict]:
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT question, answer, retrieved_docs, created_at, latency_ms
                FROM qa_records ORDER BY created_at DESC LIMIT ?
            """, (limit,))
            rows = cursor.fetchall()
            return [{
                "question": r[0], "answer": r[1], "retrieved_docs": json.loads(r[2]),
                "created_at": r[3], "latency_ms": r[4]
            } for r in rows]

    def get_stats(self) -> Dict[str, Any]:
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM documents")
            doc_count = cursor.fetchone()[0]
            cursor.execute("SELECT COUNT(*) FROM qa_records")
            qa_count = cursor.fetchone()[0]
            cursor.execute("SELECT AVG(latency_ms) FROM qa_records")
            avg_latency = cursor.fetchone()[0] or 0
        return {
            "total_documents": doc_count,
            "total_qa_records": qa_count,
            "avg_latency_ms": round(avg_latency, 2)
        }

    @staticmethod
    def _hash_embedding(emb: Optional[List[float]]) -> str:
        if emb is None:
            return ""
        return hashlib.md5(np.array(emb).tobytes()).hexdigest()[:16]

2.2 Faiss向量库

这里使用Faiss存储文档的向量数据,主要对文档向量进行入库、查询。

代码示例如下

复制代码
class FAISSIndex:
    """FAISS 向量索引 - 语义检索核心"""

    def __init__(self, dimension: int = 768, index_path: str = "kb_index.faiss"):
        self.dimension = dimension
        self.index_path = index_path
        self.index: Optional[faiss.IndexFlatIP] = None
        self.doc_ids: List[str] = []
        self._load_or_create()

    def _load_or_create(self):
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
            # 尝试加载对应的 ID 映射文件
            id_map_path = self.index_path.replace(".faiss", "_ids.json")
            if os.path.exists(id_map_path):
                with open(id_map_path, "r") as f:
                    self.doc_ids = json.load(f)
        else:
            self.index = faiss.IndexFlatIP(self.dimension)

    def add(self, doc_id: str, embedding: List[float]):
        if self.index is None:
            self.index = faiss.IndexFlatIP(self.dimension)
        vec = np.array(embedding, dtype=np.float32).reshape(1, -1)
        faiss.normalize_L2(vec)
        self.index.add(vec)
        self.doc_ids.append(doc_id)

    def search(self, query_embedding: List[float], k: int = 5) -> List[tuple]:
        if self.index is None or self.index.ntotal == 0:
            return []
        vec = np.array(query_embedding, dtype=np.float32).reshape(1, -1)
        faiss.normalize_L2(vec)
        distances, indices = self.index.search(vec, min(k, self.index.ntotal))
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            if idx >= 0 and idx < len(self.doc_ids):
                results.append((self.doc_ids[idx], float(dist)))
        return results

    def save(self):
        if self.index:
            faiss.write_index(self.index, self.index_path)
            id_map_path = self.index_path.replace(".faiss", "_ids.json")
            with open(id_map_path, "w") as f:
                json.dump(self.doc_ids, f)

2.3 向量模型

仅有Faiss向量库不能支持文档的入库和查询,换需要将文档向量化。

这里采用本地部署的Ollama模型,生成文档向量。

参考链接如下

https://blog.csdn.net/liliang199/article/details/152125310

代码示例如下。

复制代码
import openai
from openai import OpenAI

client = OpenAI(
    base_url='http://localhost:11434/v1/',
    api_key='ollama',   # API key 无所谓,任意填写即可
)

class EmbeddingService:
    """嵌入服务 - 使用Sentence-Transformer调用真实 LLM 生成向量"""

    def __init__(self, model: str = "qllama/bce-embedding-base_v1:latest"):
        self.model = model
   
    def generate(self, text: str) -> List[float]:
        """
        生成文本嵌入向量。
        """
        try:
            # 调用嵌入模型
            embedding_response = client.embeddings.create(
                model=self.model,   # 这是你在 Ollama 里下载好的模型名
                input=text
            )
            return embedding_response.data[0].embedding
        except Exception as e:
            print(f"Embedding generation failed: {e}")
            # 降级:返回随机向量(仅用于演示,生产环境应使用本地模型)
            return np.random.randn(768)

2.4 文档导入

然后就是文档数据的导入。

这里仅对文档的filename和content进行联合向量化,文档导入示例代码如下。

复制代码
def add_document(title: str, content: str, source: str = "manual") -> str:
    """
    向知识库添加新文档。

    Args:
        title: 文档标题
        content: 文档内容
        source: 文档来源(如 URL、文件名等)

    Returns:
        添加结果
    """
    doc_id = hashlib.md5(f"{title}:{source}".encode()).hexdigest()[:16]

    # 生成嵌入向量
    embedding = emb_service.generate(title + " " + content)

    doc = Document(
        id=doc_id,
        title=title,
        content=content,
        source=source,
        created_at=datetime.now().isoformat(),
        embedding=embedding
    )

    # 保存到 SQLite
    if not db.insert_document(doc):
        return f"❌ 文档保存失败"

    # 添加到 FAISS 索引
    faiss_index.add(doc_id, embedding)
    faiss_index.save()

    return f"✅ 文档已成功添加到知识库,ID: {doc_id}"

2.5 RAG初始化

这里初始化Sqlite数据库、Faiss向量库、以及向量Embedding模型,示例代码如下。

复制代码
db = SQLiteStore(db_path="simulated_doc_v1.db")
faiss_index = FAISSIndex(dimension=768, index_path="simulated_doc_v1.faiss")

model_embedding="qllama/bce-embedding-base_v1:latest"
emb_service = EmbeddingService(model=model_embedding)
embedding = emb_service.generate("hello!")
print((embedding[:10]))

然后就是文档测试苏剧的导入,先load测试数据,然后通过add_document导入文档。

测试数据链接来自

https://github.com/LeeXYZABC/llm_learning/blob/main/doc_bench.json

示例代码如下。

复制代码
import json 
with open("./doc_bench.json") as f:
    doc_list = json.load(f)
doc_list.sort(key=lambda a:a["file_name"])
print(len(doc_list))
print(f"doc_list: {len(doc_list)}")

for doc in doc_list:
    filename = doc["file_name"]
    content = doc["content"]
    source = doc["source"]
    result = add_document(filename, content, source)

3 LLM辅助

这里使用LLM优化RAG基础检索结果,使得系统仅返回最相关文档,过滤不太相关文档。

3.1 基础RAG检索

这里首先模拟基础RAG检索,

即为在知识库进行语义检索,返回最相关的文档内容。

示例代码如下

复制代码
def search_knowledge_base(query: str, top_k: int = 3) -> str:
    """
    在知识库中进行语义检索,返回最相关的文档内容。

    Args:
        query: 用户查询的自然语言问题
        top_k: 返回的文档数量,默认 3 条,最大 10 条

    Returns:
        检索到的文档内容,以结构化格式呈现
    """
    # 生成查询向量
    query_emb = emb_service.generate(query)

    # FAISS 向量检索
    results = faiss_index.search(query_emb, k=min(top_k, 500))

    if not results:
        return []

    # 获取文档详情
    records = []
    for i, (doc_id, score) in enumerate(results, 1):
        doc = db.get_document(doc_id)
        records.append(doc.copy())
    
    return records

然后运行实际检索

复制代码
results = search_knowledge_base("成都安检报告", top_k=10)
if results:
    print("search_knowledge_base: {len(results)}")
    for i, result in enumerate(results):
        print(f"  结果{i}: {result}")

输出如下,可见除返回两个最相关的文档1和2外,还返回了8个不太相关的文档。

这些不太相关的文档,不仅会消耗token,还有可能影响系统的最终输出。

search_knowledge_base: 10

结果0: {'id': '60a15e0bc908d5de', 'title': '成都销售用房装修改造安检-20250604.docx', 'content': '成都销售用房装修改造项目:临边防护缺失,吊篮配重不足,责令停工整改。整改后复查通过。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:57.733501'}

结果1: {'id': 'f93ce4bc330d598f', 'title': '成都超算中心项目安检-20250928.pdf', 'content': '成都超算中心项目:液冷系统管道压力测试合格,防静电地板接地可靠,机房洁净度达标。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:57.506278'}

结果2: {'id': '590473feb0037213', 'title': '天津新能源厂房建设安检-20251101.docx', 'content': '天津新能源厂房建设:高处作业人员安全带佩戴不规范,安全通道堵塞,已现场纠正并处罚。其余部分合格。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:44.523279'}

结果3: {'id': 'ee4d69d2c462cf36', 'title': '重庆两江新区工厂安检-20251015.pdf', 'content': '重庆两江新区工厂:安全出口指示灯常亮,应急疏散图张贴齐全,员工安全培训记录完整。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:48.252167'}

结果4: {'id': 'c5c5a209ba386dc9', 'title': '重庆长寿化工项目安检-20251101.pdf', 'content': '重庆长寿化工项目:反应釜夹套蒸汽压力正常,紧急冷却系统测试通过,有毒气体报警联动正常。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:48.696135'}

结果5: {'id': '833ad8e64ec5514d', 'title': '泰州医药城项目安检-20251020.pdf', 'content': '泰州医药城项目:洁净车间压差符合规范,FFU运行正常,人员净化程序正确执行。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:15.916728'}

结果6: {'id': 'f4f6c9ec7810787d', 'title': '泰州医药项目安检-20250918.pdf', 'content': '泰州医药项目:高活性药物生产区负压保持良好,BSC生物安全柜检测合格,废弃物处理合规。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:16.402014'}

结果7: {'id': 'b67e07e90448ec6d', 'title': '义乌国际商贸城改造安检-20250705.pdf', 'content': '义乌国际商贸城改造:施工区域围挡完好,警示标识明显。高空作业人员系安全带。合格。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:23.629302'}

结果8: {'id': '8779a78e397e4daa', 'title': '重庆数据中心建设项目安检-20250915.pdf', 'content': '重庆数据中心建设项目:消防系统测试合格,气体灭火钢瓶压力正常。电缆桥架接地可靠,未发现安全隐患。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:48.477538'}

结果9: {'id': 'befa795f92ad5e6c', 'title': '郑州航空港项目安检-20251012.pdf', 'content': '郑州航空港项目:物流分拣中心消防通道畅通,自动卷帘门测试正常,应急广播系统清晰。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:40:45.976177'}

3.2 LLM辅助检索

这里我们进一步引入LLM,对RAG基础检索结果进行处理,仅保留最相关的文档。

首先是openai格式配置聊天LLM,示例代码如下。

复制代码
chat_client = OpenAI(api_key=api_key, base_url=base_url)

然后基于LLM提示词构建过滤器。

提示词示例和过滤代码如下。

复制代码
sys_prompt =  """
你是一位专业的知识库检索专家,能从候选记录中挑出真正契合和回答问题的候选项。
请将候选记录组织成list原封不动的输出。
"""

user_tp = """
用户问题: 
{}

---

候选记录: 
{}
"""

from jsonfinder import jsonfinder

def json_parse(text):
    # 使用 jsonfinder 迭代提取 JSON 对象
    # 每次迭代会返回 (start_index, end_index, parsed_json_object)
    for start, end, obj in jsonfinder(text):
        # obj 是已经用 json.loads() 解析好的 Python 对象(dict 或 list)
        if obj:
            # print(text[start:end])
            # print(f"找到了 JSON:{obj}")
            return obj
    return None
    
def llm_records_filter(question, records):
    system_msg = {"role": "system", "content": sys_prompt}
    user_prompt = user_tp.format(
        question, 
        json.dumps(records, ensure_ascii=False)
    )
    user_msg = {"role": "user", "content": user_prompt}
    # print(f"llm filter, q={question}, searched {len(records)} records!\ndetails={records}")
    resp = chat_client.chat.completions.create(
            model=model_name,
            messages=[system_msg, user_msg],
            temperature=0.7,
            extra_body={"enable_thinking": False}
    )
    content = resp.choices[0].message.content
    data = json_parse(content)
    return data, content

def search_knowledge_llm(query: str, top_k: int = 3) -> str:
    records = search_knowledge_base(query, top_k)
    json_data, json_txt = llm_records_filter(query, records)
    return json_data

检索代码示例如下

复制代码
results = search_knowledge_llm("成都安检报告", top_k=10)
if results:
    print(f"search_knowledge_llm: {len(results)}")
    for i, result in enumerate(results):
        print(f"  结果{i}: {result}")

输出示例如下,LLM精确过滤掉不太相关的文档,仅保留了最相关的文档。

search_knowledge_llm: 2

结果0: {'id': '60a15e0bc908d5de', 'title': '成都销售用房装修改造安检 -20250604.docx', 'content': '成都销售用房装修改造项目:临边防护缺失,吊篮配重不足,责令停工整改。整改后复查通过。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:57.733501'}

结果1: {'id': 'f93ce4bc330d598f', 'title': '成都超算中心项目安检 -20250928.pdf', 'content': '成都超算中心项目:液冷系统管道压力测试合格,防静电地板接地可靠,机房洁净度达标。', 'source': '项目安全检查报告', 'created_at': '2026-05-03T20:39:57.506278'}

reference


bce-embedding-base_v1

https://hf-mirror.com/maidalun1020/bce-embedding-base_v1

从长字符串中解析合法json结构的示例

https://blog.csdn.net/liliang199/article/details/160257030

docker ollama部署轻量级嵌入模型 - EmbeddingGemma

https://blog.csdn.net/liliang199/article/details/152125310

相关推荐
Magic-Yuan1 小时前
泰勒制的崩塌 - 上
人工智能·管理
咚咚王者1 小时前
人工智能之提示词工程 第七章 行业场景深度落地案例
人工智能
无忧.芙桃1 小时前
C++IO库的超详细讲解
开发语言·c++
feasibility.1 小时前
量化:LLM与CV模型的极致压缩艺术
人工智能·科技·llm·边缘计算·量化·cv·压缩
β添砖java1 小时前
深度学习(15)卷积层
人工智能·深度学习·计算机视觉
β添砖java1 小时前
深度学习(14)确认GPU
人工智能·深度学习
浔川python社1 小时前
浔川社团第一次福利数据公布
人工智能·python·deepseek
薛定e的猫咪1 小时前
强化学习中的OOD检测:从状态异常到分布偏移
论文阅读·人工智能·深度学习
朗迹 - 张伟2 小时前
用AI开发QT——Qt与Trae开发环境搭建
开发语言·qt·策略模式