三、RAG 核心引擎开发(第 6–7 周)

三、RAG 核心引擎开发(第 6--7 周)

  1. 混合检索模块实现

    • BM25 关键词检索

    • 语义向量检索

    • 多路召回 + 融合策略

  2. 重排序(Rerank)模块

    • 接入 bge-reranker 模型

    • 对召回结果排序,提升精准度

  3. Prompt 工程与大模型调用

    • 设计基金领域专属提示词

    • 接入 Qwen/DeepSeek 大模型

    • 实现问答链:检索 → 提示 → 生成 → 输出

  4. 基础问答 Demo 验证

    • 测试基金定义、风险、费率、业绩等常见问题

    • 对比通用大模型,验证准确率提升

python 复制代码
# 文件功能:基于 RAG (检索增强生成) 的基金智能问答系统
# 整合向量检索、Rerank 与豆包大模型(火山方舟 OpenAI 兼容接口),基于私有知识库回答。
import os
import warnings

try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass

import numpy as np
import pandas as pd
import torch
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Milvus
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from rank_bm25 import BM25Okapi
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from build_knowledge_base import resolve_embedding_device

warnings.filterwarnings("ignore")

# ===================== 毕设固定配置(不要改动) =====================
MILVUS_HOST = "localhost"
MILVUS_PORT = 19530
COLLECTION = "fund_knowledge_base"
EMBEDDING_MODEL = "BAAI/bge-large-zh"
RERANK_MODEL = "BAAI/bge-reranker-large"
CSV_PATH = "./fund_knowledge_index.csv"
EMBED_BATCH_SIZE = int(os.environ.get("EMBED_BATCH_SIZE", "32"))

# ===================== 豆包 / 火山方舟(OpenAI 兼容)=====================
# 推荐在系统环境变量中配置,勿将真实密钥写入代码仓库:
#   ARK_API_KEY 或 DOUBAO_API_KEY
#   DOUBAO_ENDPOINT_ID(方舟「推理接入点」ID,形如 ep-xxxx)
DOUBAO_API_KEY = os.environ.get("ARK_API_KEY") or os.environ.get("DOUBAO_API_KEY", "")
DOUBAO_BASE_URL = os.environ.get(
    "DOUBAO_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3"
)
DOUBAO_MODEL = os.environ.get("DOUBAO_ENDPOINT_ID") or os.environ.get(
    "DOUBAO_MODEL", ""
)
DOUBAO_TEMPERATURE = float(os.environ.get("DOUBAO_TEMPERATURE", "0.1"))
# =====================================================================

# 全局单例(防止重复加载大模型,加速运行)
_embeddings = None
_milvus_db = None
_rerank_tokenizer = None
_rerank_model = None
_llm = None
_rag_chain = None


def get_embeddings():
    """单例嵌入模型,避免每次问答重复加载 BGE-large。"""
    global _embeddings
    if _embeddings is not None:
        return _embeddings
    device, _ = resolve_embedding_device()
    _embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={"device": device},
        encode_kwargs={
            "normalize_embeddings": True,
            "batch_size": EMBED_BATCH_SIZE,
        },
    )
    return _embeddings


def load_milvus():
    """加载Milvus向量数据库(全局单例)"""
    global _milvus_db
    if _milvus_db is not None:
        return _milvus_db
    _milvus_db = Milvus(
        embedding_function=get_embeddings(),
        collection_name=COLLECTION,
        connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT},
    )
    return _milvus_db


# ---------------------- 1.向量语义检索 ----------------------
def vector_search(db, query, top_k=10):
    """
    语义向量检索:擅长语义理解、模糊提问
    """
    ret = db.similarity_search(query, k=top_k)
    return [{"content": r.page_content, "meta": r.metadata} for r in ret]


# ---------------------- 2.BM25关键词检索 ----------------------
def bm25_search(docs, query, top_k=10):
    """
    BM25关键词检索:擅长专业名词、费率、代码、风险等硬性关键词
    """
    corpus = [d["content"] for d in docs]
    tokenized = [str(c).split(" ") for c in corpus]
    bm25 = BM25Okapi(tokenized)
    q_tokens = query.split(" ")
    scores = bm25.get_scores(q_tokens)
    top_idx = np.argsort(scores)[::-1][:top_k]
    return [docs[i] for i in top_idx]


# ---------------------- 3.多路召回:混合检索融合策略 ----------------------
def hybrid_search(db, query, top=10):
    """
    混合检索 = 向量检索 + BM25关键词检索
    去重合并,实现多路召回
    """
    v_res = vector_search(db, query, top)
    b_res = bm25_search(v_res, query, top)
    combined = []
    contents = set()
    for item in v_res + b_res:
        if item["content"] not in contents:
            contents.add(item["content"])
            combined.append(item)
    return combined[:top]


# ---------------------- 4.Rerank重排序模块 ----------------------
def rerank(query, docs):
    """
    使用 bge-reranker-large 对召回文档精细排序
    提升基金资料精准度,过滤无关文本
    """
    global _rerank_tokenizer, _rerank_model
    if _rerank_tokenizer is None:
        _rerank_tokenizer = AutoTokenizer.from_pretrained(RERANK_MODEL)
    if _rerank_model is None:
        _rerank_model = AutoModelForSequenceClassification.from_pretrained(RERANK_MODEL)
        _rerank_model.eval()

    pairs = [[query, d["content"]] for d in docs]
    with torch.no_grad():
        inputs = _rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors="pt", max_length=512)
        scores = _rerank_model(**inputs).logits.view(-1).float()
    ranked = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
    return [r[0] for r in ranked]


# ---------------------- 5.豆包:LangChain ChatOpenAI + Prompt 链 ----------------------
RAG_CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "你是专业严谨的基金智能问答助手,只根据用户给出的「知识库」片段回答,"
            "不编造、不杜撰、不脱离知识库谈无关话题;仅讨论基金与投资风险相关内容。"
            "回答要通俗易懂、可分段分点;涉及风险等级、费率、收益、投资范围须与知识库一致。"
            "若知识库中没有相关信息,明确说明:知识库暂无该数据。",
        ),
        ("user", "知识库:\n{context}\n\n用户问题:{question}"),
    ]
)


def get_llm():
    """豆包 / 方舟:OpenAI 兼容接口(单例)。"""
    global _llm
    if _llm is not None:
        return _llm
    if not DOUBAO_API_KEY.strip():
        raise ValueError(
            "未配置豆包 API 密钥:请设置环境变量 ARK_API_KEY 或 DOUBAO_API_KEY"
        )
    if not DOUBAO_MODEL.strip():
        raise ValueError(
            "未配置推理接入点:请设置环境变量 DOUBAO_ENDPOINT_ID(方舟控制台中的 ep-xxx)"
        )
    _llm = ChatOpenAI(
        api_key=DOUBAO_API_KEY,
        base_url=DOUBAO_BASE_URL,
        model=DOUBAO_MODEL,
        temperature=DOUBAO_TEMPERATURE,
        timeout=60,
    )
    return _llm


def get_rag_chain():
    """Prompt | LLM | 解析器(单例)。"""
    global _rag_chain
    if _rag_chain is not None:
        return _rag_chain
    _rag_chain = RAG_CHAT_PROMPT | get_llm() | StrOutputParser()
    return _rag_chain


def llm_answer_from_context(question: str, context_docs: list, top_n: int = 5):
    """
    将 Rerank 后的文档拼成 context,经 LangChain 链调用豆包生成回答。
    """
    ctx_text = "\n\n".join([d["content"] for d in context_docs[:top_n]])
    chain = get_rag_chain()
    try:
        answer = chain.invoke({"context": ctx_text, "question": question})
        print("\n🤖 【豆包大模型回答】:\n" + answer)
        return answer
    except Exception as e:
        print(f"❌ 大模型调用失败:{e}")
        return "当前大模型服务异常,请检查网络、接入点 ID 与 API 密钥。"


# ---------------------- 6.完整RAG链路:检索-排序-提示-生成 ----------------------
def rag_qa(query):
    print("🔍 1/4 正在加载向量库...")
    db = load_milvus()
    print("🔍 2/4 执行混合检索(BM25+向量)...")
    combined = hybrid_search(db, query)
    print("🔍 3/4 Rerank重排序优化文档...")
    reranked = rerank(query, combined)
    print("🔍 4/4 调用豆包大模型生成回答...")
    ans = llm_answer_from_context(query, reranked)
    return ans


# ===================== 毕业设计演示Demo =====================
if __name__ == "__main__":
    print("=" * 65)
    print("         基金智能问答RAG系统|毕设最终演示版本")
    print("=" * 65)
    print("✅ 混合检索(BM25+向量) | ✅ Rerank重排序 | ✅ 豆包大模型")
    print("=" * 65)

    while True:
        q = input("\n💬 请输入基金问题(输入 q 退出系统):")
        if q.lower() == "q":
            print("\n✅ 系统退出成功")
            break
        rag_qa(q)
相关推荐
H_unique3 小时前
Trae实现Web UI自动化测试
python·ui·ai编程·trae
陆业聪3 小时前
AI编码提效实战:Skill、Rule与上下文工程
android·ai编程·claude code
yuanlaile3 小时前
2026 AI 大模型原生应用开发全栈实战|从Prompt到DeepSeek+Dify+RAG,吃透企业级落地
prompt·ai编程·ai大模型应用开发·dify教程
嵌入式小企鹅3 小时前
国产算力突破、RISC-V车规生态成型、AI编程工具免费化浪潮
学习·开源·ai编程·risc-v·昇腾·deepseek v4
csdn2015_4 小时前
github copilot 在 IDEA里面怎么使用
ai·ai编程
Goboy13 小时前
「我的第一次移动端 AI 办公」TRAE SOLO 三端联动, 通勤路上就把活干了,这设计,老罗看了都想当场退役
人工智能·ai编程·trae
05候补工程师14 小时前
[实战复盘] 拒绝 AI 屎山!我从设计模式中学到的“调教”AI 新范式
人工智能·python·设计模式·ai·ai编程