III. RAG Core Engine Development (Weeks 6-7)
- Hybrid retrieval module
  - BM25 keyword retrieval
  - Semantic vector retrieval
  - Multi-path recall + fusion strategy (see the RRF sketch after this list)
- Reranking (Rerank) module
  - Integrate the bge-reranker model
  - Re-order the recalled results to improve precision
- Prompt engineering and LLM invocation
  - Design fund-domain-specific prompts
  - Integrate a Qwen/DeepSeek-class LLM (the demo code below actually calls Doubao through the Volcano Ark OpenAI-compatible API)
  - Implement the QA chain: retrieve → prompt → generate → output
- Basic QA demo validation
  - Test common questions on fund definitions, risk, fee rates, and performance
  - Compare against a general-purpose LLM to verify the accuracy gain (see the evaluation sketch after the code)
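As a concrete example of the fusion-strategy item above, one common approach is Reciprocal Rank Fusion (RRF), which merges the BM25 and vector rankings by rank position instead of trying to compare their raw scores. The sketch below is only illustrative: `rrf_fuse`, its `k` smoothing constant, and the usage comment are assumptions, not part of the project code, whose `hybrid_search` currently merges candidates by simple de-duplication.

```python
# Minimal Reciprocal Rank Fusion sketch (assumed helper, not part of the project code).
# Each input is a ranked list of document dicts keyed by their text content.
def rrf_fuse(ranked_lists, top_k=10, k=60):
    """Merge several ranked lists: score(doc) = sum over lists of 1 / (k + rank)."""
    scores, docs = {}, {}
    for ranked in ranked_lists:
        for rank, doc in enumerate(ranked, start=1):
            key = doc["content"]
            docs[key] = doc
            scores[key] = scores.get(key, 0.0) + 1.0 / (k + rank)
    best = sorted(scores, key=scores.get, reverse=True)[:top_k]
    return [docs[key] for key in best]

# Usage (assuming vector_search and bm25_search from the module below):
# fused = rrf_fuse([vector_search(db, q), bm25_search(all_docs, q)], top_k=10)
```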
```python
# Module: fund QA system based on RAG (retrieval-augmented generation).
# Combines vector retrieval, reranking and the Doubao LLM (Volcano Ark
# OpenAI-compatible API) to answer questions from a private knowledge base.
import os
import warnings

try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

import numpy as np
import pandas as pd
import torch
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Milvus
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from rank_bm25 import BM25Okapi
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from build_knowledge_base import resolve_embedding_device

warnings.filterwarnings("ignore")

# ===================== Fixed project configuration (do not modify) =====================
MILVUS_HOST = "localhost"
MILVUS_PORT = 19530
COLLECTION = "fund_knowledge_base"
EMBEDDING_MODEL = "BAAI/bge-large-zh"
RERANK_MODEL = "BAAI/bge-reranker-large"
CSV_PATH = "./fund_knowledge_index.csv"
EMBED_BATCH_SIZE = int(os.environ.get("EMBED_BATCH_SIZE", "32"))

# ===================== Doubao / Volcano Ark (OpenAI-compatible) =====================
# Configure these as environment variables; never commit real keys to the repository:
#   ARK_API_KEY or DOUBAO_API_KEY
#   DOUBAO_ENDPOINT_ID (the Ark "inference endpoint" ID, e.g. ep-xxxx)
DOUBAO_API_KEY = os.environ.get("ARK_API_KEY") or os.environ.get("DOUBAO_API_KEY", "")
DOUBAO_BASE_URL = os.environ.get(
    "DOUBAO_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3"
)
DOUBAO_MODEL = os.environ.get("DOUBAO_ENDPOINT_ID") or os.environ.get(
    "DOUBAO_MODEL", ""
)
DOUBAO_TEMPERATURE = float(os.environ.get("DOUBAO_TEMPERATURE", "0.1"))
# =====================================================================

# Global singletons (avoid reloading large models and speed up repeated queries)
_embeddings = None
_milvus_db = None
_rerank_tokenizer = None
_rerank_model = None
_llm = None
_rag_chain = None
def get_embeddings():
    """Singleton embedding model, so BGE-large is not reloaded for every question."""
    global _embeddings
    if _embeddings is not None:
        return _embeddings
    device, _ = resolve_embedding_device()
    _embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={"device": device},
        encode_kwargs={
            "normalize_embeddings": True,
            "batch_size": EMBED_BATCH_SIZE,
        },
    )
    return _embeddings

def load_milvus():
    """Load the Milvus vector store (global singleton)."""
    global _milvus_db
    if _milvus_db is not None:
        return _milvus_db
    _milvus_db = Milvus(
        embedding_function=get_embeddings(),
        collection_name=COLLECTION,
        connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT},
    )
    return _milvus_db
# ---------------------- 1. Semantic vector retrieval ----------------------
def vector_search(db, query, top_k=10):
    """Semantic vector retrieval: strong on semantic understanding and fuzzy questions."""
    ret = db.similarity_search(query, k=top_k)
    return [{"content": r.page_content, "meta": r.metadata} for r in ret]

# ---------------------- 2. BM25 keyword retrieval ----------------------
def bm25_search(docs, query, top_k=10):
    """
    BM25 keyword retrieval: strong on hard keywords such as fund terms, fee rates,
    codes and risk levels. Note: whitespace splitting is a weak tokenizer for Chinese;
    a word segmenter (e.g. jieba) would give more meaningful BM25 scores.
    """
    corpus = [d["content"] for d in docs]
    tokenized = [str(c).split(" ") for c in corpus]
    bm25 = BM25Okapi(tokenized)
    q_tokens = query.split(" ")
    scores = bm25.get_scores(q_tokens)
    top_idx = np.argsort(scores)[::-1][:top_k]
    return [docs[i] for i in top_idx]

# ---------------------- 3. Multi-path recall: hybrid retrieval fusion ----------------------
def hybrid_search(db, query, top=10):
    """
    Hybrid retrieval = vector retrieval plus BM25 keyword scoring over the recalled
    candidates, merged with de-duplication.
    """
    v_res = vector_search(db, query, top)
    b_res = bm25_search(v_res, query, top)
    combined = []
    contents = set()
    for item in v_res + b_res:
        if item["content"] not in contents:
            contents.add(item["content"])
            combined.append(item)
    return combined[:top]
# ---------------------- 4. Rerank module ----------------------
def rerank(query, docs):
    """
    Use bge-reranker-large to re-order the recalled documents by relevance to the
    query, improving the precision of the fund passages fed to the LLM.
    """
    global _rerank_tokenizer, _rerank_model
    if _rerank_tokenizer is None:
        _rerank_tokenizer = AutoTokenizer.from_pretrained(RERANK_MODEL)
    if _rerank_model is None:
        _rerank_model = AutoModelForSequenceClassification.from_pretrained(RERANK_MODEL)
        _rerank_model.eval()
    pairs = [[query, d["content"]] for d in docs]
    with torch.no_grad():
        inputs = _rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors="pt", max_length=512)
        scores = _rerank_model(**inputs).logits.view(-1).float()
    ranked = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
    return [r[0] for r in ranked]
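
# (Optional extension, not in the original plan.) The reranker's raw logits can be
# squashed through a sigmoid into 0-1 relevance scores, which makes it possible to
# drop clearly irrelevant chunks instead of only pushing them to the bottom.
# `rerank_with_threshold` and the 0.3 cut-off are illustrative assumptions that
# would need tuning on real fund questions.
def rerank_with_threshold(query, docs, min_score=0.3):
    """Sketch: rerank, then keep only documents whose sigmoid(logit) >= min_score."""
    global _rerank_tokenizer, _rerank_model
    if _rerank_tokenizer is None:
        _rerank_tokenizer = AutoTokenizer.from_pretrained(RERANK_MODEL)
    if _rerank_model is None:
        _rerank_model = AutoModelForSequenceClassification.from_pretrained(RERANK_MODEL)
        _rerank_model.eval()
    pairs = [[query, d["content"]] for d in docs]
    with torch.no_grad():
        inputs = _rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors="pt", max_length=512)
        probs = torch.sigmoid(_rerank_model(**inputs).logits.view(-1).float())
    kept = [(d, p.item()) for d, p in zip(docs, probs) if p.item() >= min_score]
    kept.sort(key=lambda x: x[1], reverse=True)
    return [d for d, _ in kept]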
# ---------------------- 5. Doubao: LangChain ChatOpenAI + prompt chain ----------------------
# The system prompt is kept in Chinese because the knowledge base and end users are
# Chinese. It instructs the model to answer strictly from the provided context, stay on
# fund and investment-risk topics, and say explicitly when the knowledge base has no data.
RAG_CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "你是专业严谨的基金智能问答助手,只根据用户给出的「知识库」片段回答,"
            "不编造、不杜撰、不脱离知识库谈无关话题;仅讨论基金与投资风险相关内容。"
            "回答要通俗易懂、可分段分点;涉及风险等级、费率、收益、投资范围须与知识库一致。"
            "若知识库中没有相关信息,明确说明:知识库暂无该数据。",
        ),
        ("user", "知识库:\n{context}\n\n用户问题:{question}"),
    ]
)
def get_llm():
    """Doubao / Volcano Ark via the OpenAI-compatible API (singleton)."""
    global _llm
    if _llm is not None:
        return _llm
    if not DOUBAO_API_KEY.strip():
        raise ValueError(
            "未配置豆包 API 密钥:请设置环境变量 ARK_API_KEY 或 DOUBAO_API_KEY"
        )
    if not DOUBAO_MODEL.strip():
        raise ValueError(
            "未配置推理接入点:请设置环境变量 DOUBAO_ENDPOINT_ID(方舟控制台中的 ep-xxx)"
        )
    _llm = ChatOpenAI(
        api_key=DOUBAO_API_KEY,
        base_url=DOUBAO_BASE_URL,
        model=DOUBAO_MODEL,
        temperature=DOUBAO_TEMPERATURE,
        timeout=60,
    )
    return _llm

def get_rag_chain():
    """Prompt | LLM | output parser (singleton)."""
    global _rag_chain
    if _rag_chain is not None:
        return _rag_chain
    _rag_chain = RAG_CHAT_PROMPT | get_llm() | StrOutputParser()
    return _rag_chain

def llm_answer_from_context(question: str, context_docs: list, top_n: int = 5):
    """
    Join the top reranked documents into a context string and call Doubao through
    the LangChain chain to generate the answer.
    """
    ctx_text = "\n\n".join([d["content"] for d in context_docs[:top_n]])
    chain = get_rag_chain()
    try:
        answer = chain.invoke({"context": ctx_text, "question": question})
        print("\n🤖 【豆包大模型回答】:\n" + answer)
        return answer
    except Exception as e:
        print(f"❌ 大模型调用失败:{e}")
        return "当前大模型服务异常,请检查网络、接入点 ID 与 API 密钥。"
# ---------------------- 6. Full RAG pipeline: retrieve - rerank - prompt - generate ----------------------
def rag_qa(query):
    print("🔍 1/4 正在加载向量库...")
    db = load_milvus()
    print("🔍 2/4 执行混合检索(BM25+向量)...")
    combined = hybrid_search(db, query)
    print("🔍 3/4 Rerank重排序优化文档...")
    reranked = rerank(query, combined)
    print("🔍 4/4 调用豆包大模型生成回答...")
    ans = llm_answer_from_context(query, reranked)
    return ans
# ===================== Graduation-project demo =====================
if __name__ == "__main__":
    print("=" * 65)
    print(" 基金智能问答RAG系统|毕设最终演示版本")
    print("=" * 65)
    print("✅ 混合检索(BM25+向量) | ✅ Rerank重排序 | ✅ 豆包大模型")
    print("=" * 65)
    while True:
        q = input("\n💬 请输入基金问题(输入 q 退出系统):")
        if q.lower() == "q":
            print("\n✅ 系统退出成功")
            break
        rag_qa(q)
```
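
For the demo-validation item in the plan (comparing against a general-purpose LLM), a small side-by-side script such as the sketch below can be used. It assumes the file above is saved as `rag_qa_system.py`; the module name and the question list are placeholders, and the answers still have to be judged manually against the knowledge base.

```python
# Minimal side-by-side check: the same questions go through the RAG pipeline and
# through the bare Doubao model without any retrieved context.
from rag_qa_system import get_llm, rag_qa  # assumed module name for the code above

TEST_QUESTIONS = [  # illustrative questions covering definitions, fees, risk, performance
    "什么是货币基金?",
    "基金的申购费率和赎回费率一般如何计算?",
    "股票型基金的风险等级是什么?",
    "如何评估一只基金的历史业绩?",
]

if __name__ == "__main__":
    llm = get_llm()
    for q in TEST_QUESTIONS:
        print("=" * 60)
        print("问题:", q)
        print("--- RAG(检索增强)回答 ---")
        rag_qa(q)  # retrieves from the fund knowledge base before generating
        print("--- 通用大模型(无检索)回答 ---")
        print(llm.invoke(q).content)  # same Doubao model, but with no context provided
```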