Table of Contents
1.1 Business Scenario
1.2 Core Requirements Analysis
1.3 Technology Selection
1.4 Expected Results
2.1 Core Feature Modules
2.2 Feature Details
3.1 Project Structure
3.2 Environment Configuration
3.3 Data Model Design
3.4 Database Initialization
3.5 Vector Database Wrapper
3.6 Document Processing Utilities
3.7 Embedding Service
3.8 LLM Service
3.9 RAG Core Service
3.10 Asynchronous Task Processing
3.11 API Endpoints
4.1 Project Structure
4.2 Environment Configuration
4.3 Main Entry File
4.4 API Wrapper
4.5 State Management
4.6 Core Components
5.1 Technical Highlights Recap
5.2 System Strengths
5.3 Extended Application Scenarios
5.4 Optimization Directions
5.5 Learning Suggestions
5.6 Closing Remarks
1. PyTorch Installation Failures
2. ChromaDB Installation Issues
3. Slow Node Dependency Installation
4. Redis Connection Problems
Preface
With the rapid development of artificial intelligence, large language models (LLMs) have demonstrated powerful text understanding and generation capabilities. However, when facing domain-specific questions, general-purpose models often suffer from outdated knowledge, hallucinations, and a lack of grounding in private data. Retrieval-Augmented Generation (RAG) offers an elegant solution to these problems.
The core idea of RAG is to combine information retrieval with text generation: when a user asks a question, the system first retrieves relevant document fragments from a knowledge base, then submits those fragments as context together with the question to the LLM, producing answers that are more accurate and up to date. This architecture preserves the model's generation ability while enabling dynamic knowledge updates and domain customization through an external knowledge base.
This article walks through building a complete RAG-based question-answering system from scratch. We adopt a front-end/back-end split: the backend uses Python and FastAPI and is responsible for document processing, vector storage, retrieval, and model invocation; the frontend uses Vue 3 to provide a user-friendly interface. Along the way you will learn not only the core principles of RAG but also the full engineering workflow: document chunking, vector database usage, embedding model selection, LLM API integration, asynchronous task processing, and other key techniques.
The article targets beginner-to-intermediate readers; whether you are a backend developer, frontend engineer, ops engineer, or AI enthusiast, you should find practical knowledge and skills here. The full text runs to roughly 20,000 characters in the original Chinese, with complete, annotated code; every module can be run and tested directly. As a running example, we build an internal technical-documentation Q&A assistant, showing how to turn technical documents into an intelligent knowledge base with precise question answering.
Let's begin this hands-on journey and build a smart, efficient, and extensible AI question-answering system.
1. Requirements and Expected Results
1.1 Business Scenario
Suppose we are a tech company with a large body of technical documents, product manuals, API docs, and FAQs. These documents are scattered across departments and systems, so employees spend considerable time hunting for information. We want an intelligent question-answering system that lets employees ask questions in natural language, automatically retrieves the relevant material from the document base, and has an AI generate an accurate, readable answer.
1.2 Core Requirements Analysis
Functional requirements:
Document management: upload documents in multiple formats (TXT, Markdown, PDF, etc.), with parsing and preprocessing of their content
Knowledge-base construction: split documents into reasonably sized text chunks, generate embeddings, and store them in a vector database
Intelligent Q&A: given a user question, retrieve relevant chunks and generate an answer with an LLM
Conversation history: persist dialogues and support context-aware multi-turn conversation
Answer attribution: show the original document fragments behind each answer to build trust
Admin backend: create/read/update/delete operations on the knowledge base, document updates, and other management features
Non-functional requirements:
Response time: Q&A responses within 3 seconds (excluding LLM generation time)
Accuracy: retrieval precision of at least 85%
Extensibility: room to add new document types and model backends later
Usability: a clean, intuitive interface with smooth interactions
1.3 Technology Selection
Backend stack:
Web framework: FastAPI (high performance, async support, auto-generated API docs)
Vector database: ChromaDB (lightweight, easy to deploy, supports persistence)
Embedding model: text2vec-large-chinese (open-source Chinese embedding model; the sample configuration below defaults to the lighter paraphrase-multilingual-MiniLM-L12-v2)
LLM: the OpenAI API or a Chinese provider (e.g. Zhipu AI, Baidu ERNIE)
Document parsing: PyPDF2 (PDF), python-docx (Word), markdown (Markdown)
Task queue: Celery + Redis (for long-running document-processing jobs)
Frontend stack:
Framework: Vue 3 + Composition API
UI component library: Element Plus
State management: Pinia
HTTP client: Axios
Markdown rendering: markdown-it
Deployment:
Containerization: Docker + Docker Compose
Reverse proxy: Nginx
1.4 Expected Results
Once the system is complete, users get the following experience:
- On the knowledge-base management page, administrators upload documents; the system processes them automatically and shows progress.
- On the Q&A page, a user asks "How do I use JWT for user authentication?" and the system:
  - retrieves the relevant fragments (e.g. the JWT configuration guide, the API security standard);
  - has the AI synthesize an answer such as: "According to the technical documentation, JWT authentication involves the following steps: 1. Install the dependencies... 2. Configure the secret key... 3. Issue a token... See ... for detailed configuration";
  - lists the source documents, which the user can click to view in full.
- Multi-turn conversation is supported: a follow-up such as "How do I set the token expiry?" gets an accurate answer grounded in the conversation context.
2. Feature List
2.1 Core Feature Modules
| Module | Feature | Priority | Notes |
|---|---|---|---|
| Document management | Document upload | High | Supports txt, md, pdf, docx |
| | Document list | High | Lists uploaded documents with search and filtering |
| | Document deletion | High | Removes a document and its vector data |
| | Document update | Medium | Re-processes an updated document |
| | Progress query | Medium | Shows processing status in real time |
| Knowledge base | Vectorization | High | Chunking, embedding, and storage |
| | Statistics | Low | Shows document and chunk counts |
| | Chunk viewer | Medium | Inspect a document's chunks |
| Intelligent Q&A | Single-turn Q&A | High | Answers questions from the knowledge base |
| | Multi-turn dialogue | High | Context-aware conversations |
| | Conversation history | Medium | Save and browse past dialogues |
| | Answer attribution | High | Shows the sources behind an answer |
| | Similarity tuning | Medium | Adjustable retrieval similarity threshold |
| System management | API key configuration | High | Configure LLM API keys |
| | Model selection | Medium | Switch between different LLMs |
| | System logs | Low | View operation and error logs |
2.2 Feature Details
Document processing flow:
1. User uploads a document → 2. Backend receives and stores it → 3. An async task picks it up → 4. Document content is parsed → 5. Text is chunked (by paragraph, fixed size, etc.) → 6. The embedding model generates vectors → 7. Vectors are stored in the vector database → 8. Document status is updated
Q&A flow:
1. User enters a question → 2. Frontend sends the request → 3. Backend embeds the question → 4. Similar chunks are retrieved from the vector database → 5. A prompt is built (including the retrieved context) → 6. The LLM API is called → 7. The response is parsed → 8. Answer and sources are returned
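The Q&A flow above (steps 3 to 5) can be sketched end to end in a few lines. This is a toy, self-contained illustration: it uses bag-of-words term counts in place of a real embedding model, and all names (`embed`, `retrieve`, `build_prompt`) are illustrative, not part of the project code.

```python
# Toy sketch of the retrieval + prompt-building steps of a RAG pipeline.
import math
from collections import Counter

def embed(text: str) -> Counter:
    """Toy 'embedding': a bag-of-words term-frequency vector."""
    return Counter(text.lower().split())

def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse term-count vectors."""
    dot = sum(a[t] * b[t] for t in a)
    na = math.sqrt(sum(v * v for v in a.values()))
    nb = math.sqrt(sum(v * v for v in b.values()))
    return dot / (na * nb) if na and nb else 0.0

def retrieve(question: str, chunks: list[str], top_k: int = 2) -> list[str]:
    """Rank chunks by similarity to the question and keep the top k."""
    q = embed(question)
    ranked = sorted(chunks, key=lambda c: cosine(q, embed(c)), reverse=True)
    return ranked[:top_k]

def build_prompt(question: str, context: list[str]) -> str:
    """Assemble the retrieved context and the question into one prompt."""
    ctx = "\n".join(f"[{i}] {c}" for i, c in enumerate(context, 1))
    return f"Answer using only the context below.\n{ctx}\n\nQuestion: {question}"

chunks = [
    "JWT tokens are signed with a secret key",
    "Redis is used as the Celery broker",
    "Set the JWT token expiry via the exp claim",
]
context = retrieve("how to set jwt token expiry", chunks)
print(build_prompt("how to set jwt token expiry", context))
```

A real system replaces `embed` with a sentence-embedding model and the list scan with a vector database, but the shape of the pipeline is the same.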
3. Backend Implementation
3.1 Project Structure
```
rag-ai-system-backend/
├── app/
│   ├── __init__.py
│   ├── main.py                  # FastAPI application entry point
│   ├── config.py                # Configuration
│   ├── models/
│   │   ├── __init__.py
│   │   ├── document.py          # Document data models
│   │   ├── conversation.py      # Conversation data models
│   │   └── schemas.py           # Pydantic models
│   ├── services/
│   │   ├── __init__.py
│   │   ├── document_service.py  # Document processing service
│   │   ├── embedding_service.py # Embedding service
│   │   ├── rag_service.py       # RAG Q&A service
│   │   └── llm_service.py       # LLM invocation service
│   ├── api/
│   │   ├── __init__.py
│   │   ├── v1/
│   │   │   ├── __init__.py
│   │   │   ├── documents.py     # Document management endpoints
│   │   │   ├── chat.py          # Q&A endpoints
│   │   │   └── knowledge.py     # Knowledge-base management endpoints
│   ├── core/
│   │   ├── __init__.py
│   │   ├── database.py          # Database connection
│   │   ├── vector_store.py      # Vector database wrapper
│   │   └── tasks.py             # Celery async tasks
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── file_parser.py       # File parsing utilities
│   │   ├── text_splitter.py     # Text chunking utilities
│   │   └── logger.py            # Logging utilities
│   └── static/                  # Static files
├── requirements.txt
├── .env.example
├── docker-compose.yml
└── README.md
```
3.2 Environment Configuration
requirements.txt
```text
# Web framework
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6

# Database and vector storage
sqlalchemy==2.0.23
chromadb==0.4.18
pymysql==1.1.0

# Async tasks
celery==5.3.4
redis==5.0.1

# AI
openai==1.3.0
sentence-transformers==2.2.2
torch==2.1.0

# Document parsing
PyPDF2==3.0.1
python-docx==1.1.0
markdown==3.5.1
beautifulsoup4==4.12.2

# Utilities
python-dotenv==1.0.0
pydantic==2.5.0
pydantic-settings==2.1.0
httpx==0.25.1
loguru==0.7.2
```
.env.example
```ini
# Application
APP_NAME="RAG AI System"
APP_VERSION="1.0.0"
DEBUG=True
SECRET_KEY="your-secret-key-here"

# Database
DATABASE_URL="mysql+pymysql://user:password@localhost:3306/rag_ai"
REDIS_URL="redis://localhost:6379/0"

# Vector database
CHROMA_PERSIST_DIR="./chroma_data"
CHROMA_COLLECTION_NAME="rag_knowledge_base"

# Embedding model
EMBEDDING_MODEL_NAME="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
EMBEDDING_DEVICE="cpu"  # cpu or cuda

# LLM
LLM_PROVIDER="openai"  # openai, zhipu, qwen
OPENAI_API_KEY="your-openai-api-key"
OPENAI_BASE_URL="https://api.openai.com/v1"
OPENAI_MODEL="gpt-3.5-turbo"

# Zhipu AI
ZHIPU_API_KEY="your-zhipu-api-key"
ZHIPU_MODEL="glm-4"

# Alibaba Qwen
QWEN_API_KEY="your-qwen-api-key"
QWEN_MODEL="qwen-turbo"

# Document processing
CHUNK_SIZE=500
CHUNK_OVERLAP=50
MAX_FILE_SIZE=10485760  # 10 MB
# pydantic-settings parses list-valued fields as JSON
ALLOWED_EXTENSIONS='[".txt", ".md", ".pdf", ".docx"]'

# Retrieval
TOP_K=5
SIMILARITY_THRESHOLD=0.7

# Celery
CELERY_BROKER_URL="redis://localhost:6379/0"
CELERY_RESULT_BACKEND="redis://localhost:6379/1"
```
app/config.py
```python
"""
Configuration management module.
Uses pydantic-settings to manage application settings. Values are loaded
from the environment / .env file, matched to field names case-insensitively,
so no per-field env= mapping is needed.
"""
from pydantic_settings import BaseSettings
from typing import List, Optional


class Settings(BaseSettings):
    """Application settings."""

    # Basic settings
    app_name: str = "RAG AI System"
    app_version: str = "1.0.0"
    debug: bool = False
    secret_key: str = "change-this-in-production"

    # Database
    database_url: str = "sqlite:///./rag_ai.db"
    redis_url: str = "redis://localhost:6379/0"

    # Vector database
    chroma_persist_dir: str = "./chroma_data"
    chroma_collection_name: str = "rag_knowledge_base"

    # Embedding model
    embedding_model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    embedding_device: str = "cpu"

    # LLM
    llm_provider: str = "openai"
    openai_api_key: Optional[str] = None
    openai_base_url: str = "https://api.openai.com/v1"
    openai_model: str = "gpt-3.5-turbo"
    zhipu_api_key: Optional[str] = None
    zhipu_model: str = "glm-4"
    qwen_api_key: Optional[str] = None
    qwen_model: str = "qwen-turbo"

    # Document processing
    chunk_size: int = 500
    chunk_overlap: int = 50
    max_file_size: int = 10485760  # 10 MB
    allowed_extensions: List[str] = [".txt", ".md", ".pdf", ".docx"]

    # Retrieval
    top_k: int = 5
    similarity_threshold: float = 0.7

    # Celery
    celery_broker_url: str = "redis://localhost:6379/0"
    celery_result_backend: str = "redis://localhost:6379/1"

    class Config:
        env_file = ".env"
        case_sensitive = False


# Global settings instance
settings = Settings()
```
3.3 Data Model Design
app/models/document.py
```python
"""
Document data models.
Defines the table structures for documents and document chunks.
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, Enum, Index
from sqlalchemy.sql import func
import enum

# Share the declarative Base defined in core.database so that
# init_database()'s create_all() sees these tables
from app.core.database import Base


class DocumentStatus(str, enum.Enum):
    """Document processing status."""
    PENDING = "pending"        # waiting to be processed
    PROCESSING = "processing"  # being processed
    COMPLETED = "completed"    # finished
    FAILED = "failed"          # failed


class Document(Base):
    """Main document table."""
    __tablename__ = "documents"

    id = Column(Integer, primary_key=True, autoincrement=True)
    filename = Column(String(255), nullable=False, comment="Original file name")
    file_path = Column(String(500), nullable=False, comment="Storage path")
    file_size = Column(Integer, comment="File size in bytes")
    file_type = Column(String(50), comment="File type")

    # Document metadata
    title = Column(String(500), comment="Document title")
    author = Column(String(200), comment="Author")
    create_time = Column(DateTime, server_default=func.now(), comment="Created at")
    update_time = Column(DateTime, onupdate=func.now(), comment="Updated at")

    # Processing status
    status = Column(Enum(DocumentStatus), default=DocumentStatus.PENDING, comment="Processing status")
    error_message = Column(Text, comment="Error message")
    chunks_count = Column(Integer, default=0, comment="Number of chunks")

    # Statistics
    view_count = Column(Integer, default=0, comment="View count")

    # Indexes
    __table_args__ = (
        Index('idx_status', 'status'),
        Index('idx_create_time', 'create_time'),
    )


class DocumentChunk(Base):
    """Document chunk table."""
    __tablename__ = "document_chunks"

    id = Column(Integer, primary_key=True, autoincrement=True)
    document_id = Column(Integer, nullable=False, comment="Parent document ID")
    chunk_index = Column(Integer, nullable=False, comment="Chunk index")
    content = Column(Text, nullable=False, comment="Chunk text")

    # ID of the corresponding entry in the vector database
    vector_id = Column(String(100), comment="ID in the vector store")

    # Metadata
    create_time = Column(DateTime, server_default=func.now(), comment="Created at")

    # Indexes
    __table_args__ = (
        Index('idx_document_id', 'document_id'),
        Index('idx_vector_id', 'vector_id'),
    )
```
app/models/schemas.py
```python
"""
Pydantic model definitions.
Used for validating API request and response payloads.
"""
from pydantic import BaseModel, Field
from typing import List, Optional
from datetime import datetime


# ========== Document models ==========
class DocumentUploadResponse(BaseModel):
    """Response to a document upload."""
    id: int
    filename: str
    status: str
    message: str


class DocumentInfo(BaseModel):
    """Document summary."""
    id: int
    filename: str
    title: Optional[str]
    file_size: int
    status: str
    chunks_count: int
    create_time: datetime
    update_time: Optional[datetime]

    class Config:
        from_attributes = True


class DocumentDetail(BaseModel):
    """Full document details."""
    id: int
    filename: str
    title: Optional[str]
    author: Optional[str]
    file_size: int
    file_type: str
    status: str
    error_message: Optional[str]
    chunks_count: int
    create_time: datetime
    update_time: Optional[datetime]


class DocumentChunkInfo(BaseModel):
    """Document chunk details."""
    id: int
    chunk_index: int
    content: str
    document_id: int


# ========== Q&A models ==========
class ChatRequest(BaseModel):
    """Q&A request."""
    question: str = Field(..., min_length=1, max_length=2000, description="User question")
    conversation_id: Optional[int] = Field(None, description="Conversation ID for multi-turn dialogue")
    top_k: Optional[int] = Field(5, ge=1, le=20, description="Number of chunks to retrieve")
    temperature: Optional[float] = Field(0.7, ge=0, le=2, description="Generation temperature")

    class Config:
        json_schema_extra = {
            "example": {
                "question": "什么是RAG技术?",
                "conversation_id": None,
                "top_k": 5,
                "temperature": 0.7
            }
        }


class SourceDocument(BaseModel):
    """Answer source."""
    document_id: int
    document_name: str
    chunk_id: int
    content: str
    similarity: float


class ChatResponse(BaseModel):
    """Q&A response."""
    answer: str
    sources: List[SourceDocument]
    conversation_id: int
    message_id: int


class ConversationInfo(BaseModel):
    """Conversation summary."""
    id: int
    title: str
    created_at: datetime
    updated_at: datetime


class MessageInfo(BaseModel):
    """Message details."""
    id: int
    role: str
    content: str
    sources: Optional[str]
    created_at: datetime


# ========== Knowledge-base models ==========
class KnowledgeBaseStats(BaseModel):
    """Knowledge-base statistics."""
    total_documents: int
    total_chunks: int
    completed_documents: int
    failed_documents: int
    processing_documents: int


class SearchRequest(BaseModel):
    """Search request."""
    query: str = Field(..., min_length=1, description="Search query")
    top_k: int = Field(5, ge=1, le=20, description="Number of results")
    similarity_threshold: Optional[float] = Field(0.7, ge=0, le=1, description="Similarity threshold")


class SearchResult(BaseModel):
    """Search result."""
    content: str
    document_id: int
    document_name: str
    chunk_id: int
    similarity: float


# ========== System management models ==========
class SystemConfig(BaseModel):
    """System configuration."""
    llm_provider: str
    embedding_model: str
    top_k: int
    similarity_threshold: float


class APIKeyConfig(BaseModel):
    """API key configuration."""
    provider: str
    api_key: str
    model: Optional[str] = None
```
3.4 Database Initialization
app/core/database.py
```python
"""
Database connection management.
Provides SQLAlchemy session handling and initialization.
"""
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from typing import Generator
from app.config import settings
import logging

logger = logging.getLogger(__name__)

# Create the database engine; pooling depends on the backend
if settings.database_url.startswith("sqlite"):
    # SQLite
    engine = create_engine(
        settings.database_url,
        connect_args={"check_same_thread": False},
        echo=settings.debug
    )
else:
    # MySQL / PostgreSQL
    engine = create_engine(
        settings.database_url,
        pool_size=10,
        max_overflow=20,
        pool_pre_ping=True,
        echo=settings.debug
    )

# Session factory
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base shared by all models
Base = declarative_base()


def get_db() -> Generator[Session, None, None]:
    """
    Dependency: yield a database session and close it
    automatically when the request ends.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def init_database():
    """
    Initialize the database by creating all tables.
    """
    try:
        # Import all model modules so their tables register on Base
        from app.models import document, conversation
        Base.metadata.create_all(bind=engine)
        logger.info("Database initialized")
    except Exception as e:
        logger.error(f"Database initialization failed: {e}")
        raise
```
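The essential property of `get_db()` is that the session is closed in `finally` even when the request handler raises. The following standalone sketch demonstrates the same lifecycle with a dummy session class (`DummySession` and `get_session` are illustrative stand-ins, not SQLAlchemy or FastAPI APIs):

```python
# Sketch of the yield-in-try/finally session pattern used by get_db().
from contextlib import contextmanager

class DummySession:
    """Stand-in for a SQLAlchemy session; tracks whether close() ran."""
    def __init__(self):
        self.closed = False
    def close(self):
        self.closed = True

@contextmanager
def get_session():
    db = DummySession()
    try:
        yield db          # handed to the "request handler"
    finally:
        db.close()        # always runs, even on error

session = None
try:
    with get_session() as db:
        session = db
        raise RuntimeError("handler failed")  # simulate a failing handler
except RuntimeError:
    pass
print(session.closed)  # True: the session was closed despite the exception
```

FastAPI drives the generator the same way via `Depends(get_db)`: the code after `yield` runs once the response is finished.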
3.5 Vector Database Wrapper
app/core/vector_store.py
```python
"""
Vector database wrapper.
Stores and retrieves document vectors with ChromaDB.
"""
import chromadb
from chromadb.config import Settings as ChromaSettings
from typing import List, Dict, Any, Optional
import logging
from app.config import settings

logger = logging.getLogger(__name__)


class VectorStore:
    """
    Vector database manager wrapping the common ChromaDB operations.
    """

    def __init__(self):
        """Initialize the ChromaDB client."""
        # chromadb >= 0.4 uses PersistentClient for on-disk storage;
        # the old chroma_db_impl="duckdb+parquet" setting no longer exists
        self.client = chromadb.PersistentClient(
            path=settings.chroma_persist_dir,
            settings=ChromaSettings(anonymized_telemetry=False)
        )
        # Get or create the collection
        self.collection = self.client.get_or_create_collection(
            name=settings.chroma_collection_name,
            metadata={"hnsw:space": "cosine"}  # use cosine similarity
        )
        logger.info(f"Vector store initialized, collection: {settings.chroma_collection_name}")

    def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        metadatas: List[Dict[str, Any]],
        documents: List[str]
    ) -> bool:
        """
        Add document vectors to the database.

        Args:
            ids: document IDs
            embeddings: vectors
            metadatas: metadata dicts
            documents: raw texts

        Returns:
            True on success
        """
        try:
            self.collection.add(
                ids=ids,
                embeddings=embeddings,
                metadatas=metadatas,
                documents=documents
            )
            logger.info(f"Added {len(ids)} document vectors")
            return True
        except Exception as e:
            logger.error(f"Failed to add document vectors: {e}")
            return False

    def search(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        where: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Vector similarity search.

        Args:
            query_embedding: query vector
            top_k: number of results
            where: metadata filter

        Returns:
            List of results, each with id, distance, metadata, document
        """
        try:
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=top_k,
                where=where
            )
            # Flatten the nested result structure
            formatted_results = []
            if results['ids'] and results['ids'][0]:
                for i in range(len(results['ids'][0])):
                    formatted_results.append({
                        'id': results['ids'][0][i],
                        'distance': results['distances'][0][i] if results['distances'] else 1.0,
                        'metadata': results['metadatas'][0][i] if results['metadatas'] else {},
                        'document': results['documents'][0][i] if results['documents'] else ''
                    })
            logger.info(f"Search done, {len(formatted_results)} results")
            return formatted_results
        except Exception as e:
            logger.error(f"Vector search failed: {e}")
            return []

    def delete_documents(self, ids: List[str]) -> bool:
        """
        Delete document vectors by ID.

        Args:
            ids: IDs to delete

        Returns:
            True on success
        """
        try:
            self.collection.delete(ids=ids)
            logger.info(f"Deleted {len(ids)} document vectors")
            return True
        except Exception as e:
            logger.error(f"Failed to delete document vectors: {e}")
            return False

    def delete_by_metadata(self, where: Dict[str, Any]) -> bool:
        """
        Delete documents matching a metadata filter.

        Args:
            where: filter, e.g. {"document_id": 1}

        Returns:
            True on success
        """
        try:
            # Look up matching entries first
            results = self.collection.get(where=where)
            if results['ids']:
                self.collection.delete(ids=results['ids'])
                logger.info(f"Deleted {len(results['ids'])} document vectors")
            return True
        except Exception as e:
            logger.error(f"Failed to delete document vectors by metadata: {e}")
            return False

    def get_collection_stats(self) -> Dict[str, Any]:
        """Collection statistics."""
        try:
            count = self.collection.count()
            return {
                'name': settings.chroma_collection_name,
                'document_count': count,
                'metadata': self.collection.metadata
            }
        except Exception as e:
            logger.error(f"Failed to get collection stats: {e}")
            return {}


# Global vector store instance
vector_store = VectorStore()
```
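For unit tests that shouldn't depend on ChromaDB, the `query()` result shape above can be reproduced with a brute-force in-memory search. This is an illustrative stand-in (the `query` function and `store` layout are not ChromaDB APIs): cosine *distance* is `1 - cosine similarity`, matching the `"hnsw:space": "cosine"` setting, and results come back in the same nested-list shape the wrapper flattens.

```python
# Brute-force in-memory stand-in for collection.query(), for testing.
import math

def cosine_distance(a, b):
    """Cosine distance = 1 - cosine similarity."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return 1.0 - (dot / (na * nb)) if na and nb else 1.0

def query(store, query_embedding, n_results=5):
    """store: list of dicts with 'id', 'embedding', 'document', 'metadata'."""
    hits = sorted(store, key=lambda e: cosine_distance(query_embedding, e["embedding"]))[:n_results]
    # Same nested-list shape ChromaDB returns (one inner list per query)
    return {
        "ids": [[h["id"] for h in hits]],
        "distances": [[cosine_distance(query_embedding, h["embedding"]) for h in hits]],
        "documents": [[h["document"] for h in hits]],
        "metadatas": [[h["metadata"] for h in hits]],
    }

store = [
    {"id": "a", "embedding": [1.0, 0.0], "document": "doc a", "metadata": {}},
    {"id": "b", "embedding": [0.0, 1.0], "document": "doc b", "metadata": {}},
]
res = query(store, [0.9, 0.1], n_results=1)
print(res["ids"][0])  # ['a']
```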
3.6 Document Processing Utilities
app/utils/file_parser.py
```python
"""
File parsing utilities.
Extracts text content from several document formats.
"""
import logging
import PyPDF2
from docx import Document as DocxDocument
import markdown
from bs4 import BeautifulSoup

logger = logging.getLogger(__name__)


class FileParser:
    """
    File parser: extracts text according to file type.
    """

    @staticmethod
    def parse_txt(file_path: str) -> str:
        """Parse a TXT file."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except UnicodeDecodeError:
            # Fall back to GBK for legacy Chinese encodings
            with open(file_path, 'r', encoding='gbk') as f:
                return f.read()

    @staticmethod
    def parse_markdown(file_path: str) -> str:
        """Parse a Markdown file and extract plain text."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                md_content = f.read()
            # Render Markdown to HTML
            html = markdown.markdown(md_content)
            # Strip tags to get plain text
            soup = BeautifulSoup(html, 'html.parser')
            return soup.get_text()
        except Exception as e:
            logger.error(f"Failed to parse Markdown file: {e}")
            raise

    @staticmethod
    def parse_pdf(file_path: str) -> str:
        """Parse a PDF file."""
        try:
            text_content = []
            with open(file_path, 'rb') as f:
                pdf_reader = PyPDF2.PdfReader(f)
                for page in pdf_reader.pages:
                    text = page.extract_text()
                    if text:
                        text_content.append(text)
            return '\n'.join(text_content)
        except Exception as e:
            logger.error(f"Failed to parse PDF file: {e}")
            raise

    @staticmethod
    def parse_docx(file_path: str) -> str:
        """Parse a Word document."""
        try:
            doc = DocxDocument(file_path)
            paragraphs = []
            for para in doc.paragraphs:
                if para.text.strip():
                    paragraphs.append(para.text)
            # Extract table contents as well
            for table in doc.tables:
                for row in table.rows:
                    row_text = [cell.text for cell in row.cells if cell.text.strip()]
                    if row_text:
                        paragraphs.append(' | '.join(row_text))
            return '\n'.join(paragraphs)
        except Exception as e:
            logger.error(f"Failed to parse Word document: {e}")
            raise

    @staticmethod
    def parse_file(file_path: str, file_type: str) -> str:
        """
        Dispatch to the right parser by file extension.

        Args:
            file_path: file path
            file_type: file type (extension)

        Returns:
            Extracted text content
        """
        file_type = file_type.lower()
        parsers = {
            '.txt': FileParser.parse_txt,
            '.md': FileParser.parse_markdown,
            '.pdf': FileParser.parse_pdf,
            '.docx': FileParser.parse_docx,
        }
        parser = parsers.get(file_type)
        if not parser:
            raise ValueError(f"Unsupported file type: {file_type}")
        try:
            content = parser(file_path)
            logger.info(f"Parsed file: {file_path}, content length: {len(content)}")
            return content
        except Exception as e:
            logger.error(f"Failed to parse file {file_path}: {e}")
            raise
```
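The UTF-8 → GBK fallback in `parse_txt` can be verified standalone: a file written in GBK raises `UnicodeDecodeError` under UTF-8, and the second attempt succeeds. This sketch uses a temp file purely to illustrate the pattern (`read_with_fallback` is a hypothetical helper mirroring `parse_txt`):

```python
# Standalone demonstration of the encoding-fallback pattern from parse_txt().
import tempfile, os

def read_with_fallback(path: str) -> str:
    """Try UTF-8 first; fall back to GBK on a decode error."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except UnicodeDecodeError:
        with open(path, "r", encoding="gbk") as f:
            return f.read()

fd, path = tempfile.mkstemp(suffix=".txt")
os.close(fd)
with open(path, "w", encoding="gbk") as f:
    f.write("向量检索")  # GBK bytes are invalid UTF-8, triggering the fallback
print(read_with_fallback(path))
os.remove(path)
```

In production you may prefer detecting the encoding (e.g. with `chardet`) rather than hard-coding a single fallback.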
app/utils/text_splitter.py
```python
"""
Text chunking utilities.
Splits long text into short chunks suitable for embedding.
"""
import re
import logging
from typing import List, Optional
from app.config import settings

logger = logging.getLogger(__name__)


class TextSplitter:
    """
    Text splitter supporting several chunking strategies.
    """

    def __init__(
        self,
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        separators: Optional[List[str]] = None
    ):
        """
        Args:
            chunk_size: chunk size in characters
            chunk_overlap: overlap between adjacent chunks
            separators: separators, ordered by priority
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.separators = separators or ["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]

    def split_by_separator(self, text: str, separator: str) -> List[str]:
        """Split text by a separator, keeping the separator."""
        if separator == "":
            # Split into individual characters
            return list(text)
        # Split with a capturing group so the separator is preserved
        pattern = re.escape(separator)
        parts = re.split(f'({pattern})', text)
        # Re-attach each separator to the preceding fragment
        chunks = []
        for i in range(0, len(parts), 2):
            if i + 1 < len(parts):
                chunks.append(parts[i] + parts[i + 1])
            else:
                chunks.append(parts[i])
        return [chunk for chunk in chunks if chunk.strip()]

    def merge_chunks(self, chunks: List[str], separator: str) -> str:
        """Join chunks back together."""
        return separator.join(chunks)

    def split_text(self, text: str) -> List[str]:
        """
        Split text into chunks.

        Args:
            text: raw text

        Returns:
            List of chunks
        """
        if not text:
            return []
        # Short texts are returned as a single chunk
        if len(text) <= self.chunk_size:
            return [text]
        chunks = []
        # Try each separator in priority order
        for separator in self.separators:
            # Start fresh for each separator so failed attempts don't
            # leave duplicate chunks behind
            chunks = []
            current_chunk = ""
            splits = self.split_by_separator(text, separator)
            # Greedily merge small fragments up to chunk_size
            for split in splits:
                if len(current_chunk) + len(split) <= self.chunk_size:
                    current_chunk += split
                else:
                    # Current chunk is full; flush it
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                    # A fragment longer than chunk_size is split recursively
                    if len(split) > self.chunk_size:
                        chunks.extend(self.split_text(split))
                        current_chunk = ""
                    else:
                        current_chunk = split
            # Flush the trailing chunk
            if current_chunk:
                chunks.append(current_chunk.strip())
            # Stop once this separator produced a real split
            if len(chunks) > 1:
                break
        # Fall back to fixed-size splitting
        if not chunks:
            chunks = self.fixed_size_split(text)
        # Add overlap between adjacent chunks
        if self.chunk_overlap > 0 and len(chunks) > 1:
            chunks = self.add_overlap(chunks)
        logger.info(f"Text split into {len(chunks)} chunks")
        return chunks

    def fixed_size_split(self, text: str) -> List[str]:
        """Fixed-size splitting."""
        chunks = []
        for i in range(0, len(text), self.chunk_size - self.chunk_overlap):
            chunk = text[i:i + self.chunk_size]
            if chunk:
                chunks.append(chunk)
        return chunks

    def add_overlap(self, chunks: List[str]) -> List[str]:
        """Prepend overlap from the previous chunk to each chunk."""
        overlapped_chunks = []
        for i, chunk in enumerate(chunks):
            if i == 0:
                overlapped_chunks.append(chunk)
            else:
                # Take the tail of the previous chunk as overlap
                prev_chunk = chunks[i - 1]
                overlap_text = prev_chunk[-self.chunk_overlap:] if len(prev_chunk) > self.chunk_overlap else prev_chunk
                overlapped_chunks.append(overlap_text + chunk)
        return overlapped_chunks

    def split_document(self, content: str, metadata: dict = None) -> List[dict]:
        """
        Split a document and attach metadata to each chunk.

        Args:
            content: document content
            metadata: document metadata

        Returns:
            List of chunks, each a dict with "text" and "metadata"
        """
        chunks = self.split_text(content)
        result = []
        for i, chunk in enumerate(chunks):
            chunk_metadata = {
                "chunk_index": i,
                "chunk_count": len(chunks),
                **metadata
            } if metadata else {"chunk_index": i}
            result.append({
                "text": chunk,
                "metadata": chunk_metadata
            })
        return result


# Global splitter instance using the configured sizes
text_splitter = TextSplitter(
    chunk_size=settings.chunk_size,
    chunk_overlap=settings.chunk_overlap
)
```
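The fixed-size fallback is easiest to reason about with small numbers: the window advances by `chunk_size - chunk_overlap`, so consecutive chunks share `chunk_overlap` characters. A minimal standalone sketch (the function mirrors `fixed_size_split` above; the tiny sizes are for readability only):

```python
# Standalone illustration of fixed-size chunking with overlap.
def fixed_size_split(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    step = chunk_size - chunk_overlap  # window stride
    return [text[i:i + chunk_size] for i in range(0, len(text), step) if text[i:i + chunk_size]]

chunks = fixed_size_split("abcdefghij", chunk_size=4, chunk_overlap=2)
print(chunks)  # ['abcd', 'cdef', 'efgh', 'ghij', 'ij']
```

Each chunk's last two characters reappear at the start of the next, which is what keeps a sentence split across a boundary retrievable from either side.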
3.7 Embedding Service
app/services/embedding_service.py
```python
"""
Embedding service.
Converts text to vectors using Sentence Transformers.
"""
import logging
from typing import List
import numpy as np
from sentence_transformers import SentenceTransformer
from app.config import settings

logger = logging.getLogger(__name__)


class EmbeddingService:
    """
    Text embedding service: maps text to high-dimensional vectors.
    """

    def __init__(self):
        """Load the embedding model."""
        self.model_name = settings.embedding_model_name
        self.device = settings.embedding_device
        try:
            self.model = SentenceTransformer(self.model_name, device=self.device)
            logger.info(f"Embedding model loaded: {self.model_name}, device: {self.device}")
            # Vector dimensionality
            self.dimension = self.model.get_sentence_embedding_dimension()
            logger.info(f"Embedding dimension: {self.dimension}")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            raise

    def embed_text(self, text: str) -> List[float]:
        """
        Embed a single text.

        Args:
            text: input text

        Returns:
            Vector as a list of floats
        """
        try:
            # encode() expects a list of inputs
            embedding = self.model.encode([text], convert_to_numpy=True)
            return embedding[0].tolist()
        except Exception as e:
            logger.error(f"Text embedding failed: {e}")
            raise

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a batch of texts.

        Args:
            texts: list of texts

        Returns:
            List of vectors
        """
        try:
            embeddings = self.model.encode(texts, convert_to_numpy=True)
            return embeddings.tolist()
        except Exception as e:
            logger.error(f"Batch embedding failed: {e}")
            raise

    def compute_similarity(self, text1: str, text2: str) -> float:
        """
        Cosine similarity between two texts, in [-1, 1].
        """
        try:
            vec1 = np.array(self.embed_text(text1))
            vec2 = np.array(self.embed_text(text2))
            cosine_sim = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
            return float(cosine_sim)
        except Exception as e:
            logger.error(f"Similarity computation failed: {e}")
            return 0.0

    def get_model_info(self) -> dict:
        """Model metadata."""
        return {
            "model_name": self.model_name,
            "dimension": self.dimension,
            "device": self.device,
            "max_sequence_length": self.model.get_max_seq_length()
        }


# Global embedding service instance
embedding_service = EmbeddingService()
```
3.8 LLM Service
app/services/llm_service.py
```python
"""
LLM invocation service.
Provides a unified chat interface over several LLM provider APIs.
"""
import logging
from typing import List, Dict, Any
from openai import OpenAI
from app.config import settings

logger = logging.getLogger(__name__)


class LLMService:
    """
    LLM service with a unified calling interface.
    """

    def __init__(self, provider: str = None):
        self.provider = provider or settings.llm_provider
        self._init_client()

    def _init_client(self):
        """Initialize the client for the configured provider."""
        if self.provider == "openai":
            if not settings.openai_api_key:
                raise ValueError("OpenAI API key is not configured")
            self.client = OpenAI(
                api_key=settings.openai_api_key,
                base_url=settings.openai_base_url
            )
            self.model = settings.openai_model
        elif self.provider == "zhipu":
            # Zhipu AI
            if not settings.zhipu_api_key:
                raise ValueError("Zhipu AI API key is not configured")
            from zhipuai import ZhipuAI
            self.client = ZhipuAI(api_key=settings.zhipu_api_key)
            self.model = settings.zhipu_model
        elif self.provider == "qwen":
            # Alibaba Qwen via its OpenAI-compatible endpoint
            if not settings.qwen_api_key:
                raise ValueError("Qwen API key is not configured")
            self.client = OpenAI(
                api_key=settings.qwen_api_key,
                base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
            )
            self.model = settings.qwen_model
        else:
            raise ValueError(f"Unsupported LLM provider: {self.provider}")
        logger.info(f"LLM client initialized, provider: {self.provider}, model: {self.model}")

    def chat(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 2000,
        **kwargs
    ) -> str:
        """
        Run a chat completion.

        All three clients expose an OpenAI-compatible chat.completions
        interface, so a single code path covers every provider.

        Args:
            messages: message list, e.g. [{"role": "user", "content": "..."}]
            temperature: sampling temperature
            max_tokens: maximum tokens to generate
            **kwargs: extra parameters

        Returns:
            The model's reply text
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"LLM call failed: {e}")
            raise

    def stream_chat(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        **kwargs
    ):
        """
        Streaming chat; returns a generator.

        Args:
            messages: message list
            temperature: sampling temperature
            **kwargs: extra parameters

        Yields:
            Generated text fragments
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                stream=True,
                **kwargs
            )
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        except Exception as e:
            logger.error(f"Streaming LLM call failed: {e}")
            yield f"Error: {str(e)}"

    def get_model_info(self) -> Dict[str, Any]:
        """Basic model info."""
        return {
            "provider": self.provider,
            "model": self.model,
            "available": True
        }


# Global LLM service instance
llm_service = LLMService()
```
3.9 RAG核心服务
app/services/rag_service.py
python
"""
RAG核心服务模块
整合检索和生成,实现检索增强生成
"""
import logging
import uuid
from typing import List, Dict, Any, Optional, Tuple
from sqlalchemy.orm import Session
from app.services.embedding_service import embedding_service
from app.services.llm_service import llm_service
from app.core.vector_store import vector_store
from app.models.document import Document, DocumentChunk, DocumentStatus
from app.models.conversation import Conversation, Message
from app.config import settings
logger = logging.getLogger(__name__)
class RAGService:
"""
检索增强生成服务
实现完整的RAG流程
"""
def __init__(self):
self.top_k = settings.top_k
self.similarity_threshold = settings.similarity_threshold
def retrieve(
self,
query: str,
top_k: int = None,
similarity_threshold: float = None,
filter_condition: Optional[Dict] = None
) -> List[Dict[str, Any]]:
"""
检索相关文档片段
Args:
query: 查询文本
top_k: 返回结果数量
similarity_threshold: 相似度阈值
filter_condition: 过滤条件
Returns:
检索结果列表
"""
top_k = top_k or self.top_k
threshold = similarity_threshold or self.similarity_threshold
try:
# 1. 将查询文本向量化
query_vector = embedding_service.embed_text(query)
# 2. 在向量数据库中检索
results = vector_store.search(
query_embedding=query_vector,
top_k=top_k,
where=filter_condition
)
# 3. 过滤低于阈值的结果
filtered_results = []
for result in results:
# 距离转相似度(余弦距离转余弦相似度)
similarity = 1 - result['distance']
if similarity >= threshold:
filtered_results.append({
**result,
'similarity': similarity
})
logger.info(f"检索完成,查询: {query[:50]}..., 找到 {len(filtered_results)} 个相关片段")
return filtered_results
except Exception as e:
logger.error(f"检索失败: {e}")
return []
def build_prompt(
self,
query: str,
context: List[Dict[str, Any]],
conversation_history: Optional[List[Dict[str, str]]] = None
) -> List[Dict[str, str]]:
"""
构建提示词
Args:
query: 用户问题
context: 检索到的上下文
conversation_history: 对话历史
Returns:
消息列表
"""
# 系统提示词
system_prompt = """你是一个专业的技术文档助手,你的职责是基于提供的文档内容回答用户的问题。
请遵循以下规则:
1. 只根据提供的文档内容回答问题,不要使用你自己的知识
2. 如果文档中没有相关信息,请明确告知用户"未找到相关信息"
3. 回答要准确、简洁、有条理,可以使用列表形式
4. 如果引用文档内容,请标注来源
5. 回答使用中文
提供的参考文档:
"""
# 添加上下文
context_text = ""
for i, ctx in enumerate(context, 1):
doc_name = ctx.get('metadata', {}).get('filename', '未知文档')
content = ctx.get('document', '')
context_text += f"\n[{i}] 来源:{doc_name}\n{content}\n"
system_prompt += context_text
messages = [{"role": "system", "content": system_prompt}]
# 添加对话历史
if conversation_history:
for msg in conversation_history:
messages.append(msg)
# 添加当前问题
messages.append({"role": "user", "content": query})
return messages
def generate_answer(
self,
query: str,
context: List[Dict[str, Any]],
conversation_history: Optional[List[Dict[str, str]]] = None,
temperature: float = 0.7
) -> str:
"""
生成答案
Args:
query: 用户问题
context: 检索到的上下文
conversation_history: 对话历史
temperature: 生成温度
Returns:
生成的答案
"""
try:
# 构建提示词
messages = self.build_prompt(query, context, conversation_history)
# 调用大模型
answer = llm_service.chat(
messages=messages,
temperature=temperature,
max_tokens=2000
)
logger.info(f"答案生成成功,长度: {len(answer)}")
return answer
except Exception as e:
logger.error(f"生成答案失败: {e}")
return f"生成答案时出错: {str(e)}"
def chat(
self,
question: str,
conversation_id: Optional[int] = None,
db: Session = None,
top_k: int = None,
temperature: float = 0.7
) -> Tuple[str, List[Dict], int, int]:
"""
完整的问答流程
Args:
question: 用户问题
conversation_id: 对话ID
db: 数据库会话
top_k: 检索数量
temperature: 生成温度
Returns:
(答案, 来源列表, 对话ID, 消息ID)
"""
# 1. 检索相关文档
retrieved_docs = self.retrieve(question, top_k)
# 2. 获取对话历史
conversation_history = []
if conversation_id and db:
# 获取最近的对话历史
messages = db.query(Message).filter(
Message.conversation_id == conversation_id
).order_by(Message.created_at.desc()).limit(10).all()
# 反转顺序
messages.reverse()
for msg in messages:
conversation_history.append({
"role": msg.role,
"content": msg.content
})
# 3. 生成答案
answer = self.generate_answer(
query=question,
context=retrieved_docs,
conversation_history=conversation_history,
temperature=temperature
)
# 4. 保存对话记录
if db:
# 如果没有对话ID,创建新对话
if not conversation_id:
conversation = Conversation()
conversation.title = question[:50] # 使用问题前50字作为标题
db.add(conversation)
db.flush()
conversation_id = conversation.id
# 保存用户消息
user_message = Message(
conversation_id=conversation_id,
role="user",
content=question
)
db.add(user_message)
db.flush()
# 保存助手消息
import json
sources = []
for doc in retrieved_docs:
sources.append({
"document_id": doc.get('metadata', {}).get('document_id'),
"document_name": doc.get('metadata', {}).get('filename', '未知'),
"content": doc.get('document', ''),
"similarity": doc.get('similarity', 0)
})
assistant_message = Message(
conversation_id=conversation_id,
role="assistant",
content=answer,
sources=json.dumps(sources, ensure_ascii=False)
)
db.add(assistant_message)
db.flush()
db.commit()
message_id = assistant_message.id
# 5. 格式化来源返回
formatted_sources = []
for doc in retrieved_docs:
formatted_sources.append({
"document_id": doc.get('metadata', {}).get('document_id'),
"document_name": doc.get('metadata', {}).get('filename', '未知'),
"chunk_id": doc.get('metadata', {}).get('chunk_id'),
"content": doc.get('document', ''),
"similarity": doc.get('similarity', 0)
})
        return answer, formatted_sources, conversation_id, (message_id if db else None)
# 创建全局RAG服务实例
rag_service = RAGService()
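至此RAG核心服务完成。其中 build_prompt 的"上下文注入"逻辑是纯函数式的,完全可以抽出来单独验证。下面是一个假设的简化示意(不依赖上文的类和 llm_service,仅演示拼装规则):

```python
from typing import Any, Dict, List, Optional

def build_prompt_sketch(
    query: str,
    context: List[Dict[str, Any]],
    history: Optional[List[Dict[str, str]]] = None,
) -> List[Dict[str, str]]:
    """简化版提示词拼装:系统提示词 + 编号引用的上下文 + 对话历史 + 当前问题"""
    system_prompt = "你是一个技术文档助手,只根据参考文档回答。\n提供的参考文档:\n"
    for i, ctx in enumerate(context, 1):
        doc_name = ctx.get("metadata", {}).get("filename", "未知文档")
        system_prompt += f"\n[{i}] 来源:{doc_name}\n{ctx.get('document', '')}\n"
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(history or [])
    messages.append({"role": "user", "content": query})
    return messages
```

把拼装逻辑抽成这样的纯函数后,不需要启动向量库和大模型就能对提示词结构做单元测试。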
3.10 异步任务处理
app/core/tasks.py
python
"""
Celery异步任务模块
处理文档上传后的耗时操作
"""
import uuid
import logging
from celery import Celery
from app.config import settings
from app.core.database import SessionLocal
from app.models.document import Document, DocumentChunk, DocumentStatus
from app.utils.file_parser import FileParser
from app.utils.text_splitter import text_splitter
from app.services.embedding_service import embedding_service
from app.core.vector_store import vector_store
logger = logging.getLogger(__name__)
# 创建Celery应用
celery_app = Celery(
"rag_ai_tasks",
broker=settings.celery_broker_url,
backend=settings.celery_result_backend
)
# 配置Celery
celery_app.conf.update(
task_serializer='json',
accept_content=['json'],
result_serializer='json',
timezone='Asia/Shanghai',
enable_utc=True,
task_track_started=True,
task_time_limit=30 * 60, # 30分钟超时
task_soft_time_limit=25 * 60, # 25分钟软超时
)
@celery_app.task(bind=True, name="process_document")
def process_document(self, document_id: int):
"""
异步处理文档任务
Args:
document_id: 文档ID
"""
db = SessionLocal()
try:
# 1. 获取文档信息
document = db.query(Document).filter(Document.id == document_id).first()
if not document:
logger.error(f"文档不存在: {document_id}")
return {"error": "文档不存在"}
# 更新状态为处理中
document.status = DocumentStatus.PROCESSING
db.commit()
logger.info(f"开始处理文档: {document.filename}, ID: {document_id}")
# 2. 解析文件内容
content = FileParser.parse_file(document.file_path, document.file_type)
if not content:
raise ValueError("文件内容为空")
logger.info(f"文档解析成功,内容长度: {len(content)}")
# 3. 文本切片
chunks = text_splitter.split_document(
content,
metadata={
"document_id": document.id,
"filename": document.filename,
"file_type": document.file_type
}
)
if not chunks:
raise ValueError("文本切片失败,未生成任何切片")
logger.info(f"文本切片完成,共 {len(chunks)} 个切片")
# 4. 生成向量并存储
texts = [chunk["text"] for chunk in chunks]
embeddings = embedding_service.embed_texts(texts)
# 准备向量数据库数据
vector_ids = []
metadatas = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
vector_id = f"doc_{document.id}_chunk_{i}_{uuid.uuid4().hex[:8]}"
vector_ids.append(vector_id)
# 添加chunk_id到元数据
metadata = {
**chunk["metadata"],
"chunk_id": i,
"vector_id": vector_id
}
metadatas.append(metadata)
# 保存到数据库
chunk_record = DocumentChunk(
document_id=document.id,
chunk_index=i,
content=chunk["text"],
vector_id=vector_id
)
db.add(chunk_record)
# 存储到向量数据库
success = vector_store.add_documents(
ids=vector_ids,
embeddings=embeddings,
metadatas=metadatas,
documents=texts
)
if not success:
raise ValueError("向量存储失败")
# 5. 更新文档状态
document.status = DocumentStatus.COMPLETED
document.chunks_count = len(chunks)
db.commit()
logger.info(f"文档处理完成: {document.filename}, 切片数: {len(chunks)}")
return {
"success": True,
"document_id": document_id,
"chunks_count": len(chunks)
}
except Exception as e:
logger.error(f"文档处理失败: {e}")
# 更新文档状态为失败
try:
document = db.query(Document).filter(Document.id == document_id).first()
if document:
document.status = DocumentStatus.FAILED
document.error_message = str(e)
db.commit()
except Exception as db_error:
logger.error(f"更新文档状态失败: {db_error}")
return {
"success": False,
"document_id": document_id,
"error": str(e)
}
finally:
db.close()
@celery_app.task(name="delete_document")
def delete_document(document_id: int):
"""
异步删除文档任务
删除文档相关的所有数据,包括向量数据库中的记录
Args:
document_id: 文档ID
"""
db = SessionLocal()
try:
# 1. 获取文档的切片
chunks = db.query(DocumentChunk).filter(
DocumentChunk.document_id == document_id
).all()
# 2. 从向量数据库中删除
vector_ids = [chunk.vector_id for chunk in chunks if chunk.vector_id]
if vector_ids:
vector_store.delete_documents(vector_ids)
# 3. 删除数据库记录
db.query(DocumentChunk).filter(
DocumentChunk.document_id == document_id
).delete()
db.query(Document).filter(Document.id == document_id).delete()
db.commit()
logger.info(f"文档删除成功: {document_id}")
return {
"success": True,
"document_id": document_id
}
except Exception as e:
logger.error(f"文档删除失败: {e}")
db.rollback()
return {
"success": False,
"document_id": document_id,
"error": str(e)
}
finally:
db.close()
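process_document 中"生成vector_id并补充chunk元数据"这一步是纯逻辑,可以抽成独立函数,便于脱离Celery和数据库做单元测试。以下是一个假设的简化写法(函数名为示意,非上文已有代码):

```python
import uuid
from typing import Any, Dict, List, Tuple

def build_vector_records(
    document_id: int,
    chunks: List[Dict[str, Any]],
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """为每个切片生成全局唯一的vector_id,并在元数据中补充chunk_id"""
    vector_ids: List[str] = []
    metadatas: List[Dict[str, Any]] = []
    for i, chunk in enumerate(chunks):
        # 与任务代码一致的命名规则:文档ID + 切片序号 + 随机后缀
        vector_id = f"doc_{document_id}_chunk_{i}_{uuid.uuid4().hex[:8]}"
        vector_ids.append(vector_id)
        metadatas.append({
            **chunk.get("metadata", {}),
            "chunk_id": i,
            "vector_id": vector_id,
        })
    return vector_ids, metadatas
```

随机后缀保证了同一文档重新处理时不会与旧向量ID冲突,前缀则让ID本身可读、可按文档排查。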
3.11 API接口实现
app/api/v1/documents.py
python
"""
文档管理API接口
提供文档的上传、查询、删除等功能
"""
import os
import shutil
import uuid
from typing import List
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.document import Document, DocumentStatus
from app.models.schemas import (
DocumentInfo, DocumentDetail, DocumentUploadResponse,
DocumentChunkInfo
)
from app.core.tasks import process_document, delete_document
from app.config import settings
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/documents", tags=["文档管理"])
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
file: UploadFile = File(...),
db: Session = Depends(get_db)
):
"""
上传文档
"""
# 1. 验证文件类型
filename = file.filename
ext = os.path.splitext(filename)[1].lower()
if ext not in settings.allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"不支持的文件类型,仅支持: {', '.join(settings.allowed_extensions)}"
)
    # 2. 验证文件大小
    content = await file.read()
    file_size = len(content)
if file_size > settings.max_file_size:
raise HTTPException(
status_code=400,
detail=f"文件大小超过限制,最大: {settings.max_file_size / 1024 / 1024}MB"
)
# 3. 保存文件
upload_dir = "uploads"
os.makedirs(upload_dir, exist_ok=True)
# 生成唯一文件名
    unique_filename = f"{uuid.uuid4().hex}_{filename}"
file_path = os.path.join(upload_dir, unique_filename)
with open(file_path, "wb") as f:
f.write(content)
# 4. 创建文档记录
document = Document(
filename=filename,
file_path=file_path,
file_size=file_size,
file_type=ext,
title=filename, # 默认标题为文件名
status=DocumentStatus.PENDING
)
db.add(document)
db.commit()
db.refresh(document)
# 5. 触发异步处理任务
process_document.delay(document.id)
    logger.info(f"文档上传成功: {filename}, ID: {document.id}")
return DocumentUploadResponse(
id=document.id,
filename=filename,
status=document.status.value,
message="文档已上传,正在后台处理中"
)
@router.get("/", response_model=List[DocumentInfo])
async def get_documents(
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
status: str = None,
db: Session = Depends(get_db)
):
"""
获取文档列表
"""
query = db.query(Document)
if status:
query = query.filter(Document.status == status)
documents = query.order_by(Document.create_time.desc()).offset(skip).limit(limit).all()
return [
DocumentInfo(
id=doc.id,
filename=doc.filename,
title=doc.title,
file_size=doc.file_size,
status=doc.status.value,
chunks_count=doc.chunks_count,
create_time=doc.create_time,
update_time=doc.update_time
)
for doc in documents
]
@router.get("/{document_id}", response_model=DocumentDetail)
async def get_document(
document_id: int,
db: Session = Depends(get_db)
):
"""
获取文档详情
"""
document = db.query(Document).filter(Document.id == document_id).first()
if not document:
raise HTTPException(status_code=404, detail="文档不存在")
# 增加查看次数
document.view_count += 1
db.commit()
return DocumentDetail(
id=document.id,
filename=document.filename,
title=document.title,
author=document.author,
file_size=document.file_size,
file_type=document.file_type,
status=document.status.value,
error_message=document.error_message,
chunks_count=document.chunks_count,
create_time=document.create_time,
update_time=document.update_time
)
@router.get("/{document_id}/chunks", response_model=List[DocumentChunkInfo])
async def get_document_chunks(
document_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db)
):
"""
获取文档的切片列表
"""
from app.models.document import DocumentChunk
document = db.query(Document).filter(Document.id == document_id).first()
if not document:
raise HTTPException(status_code=404, detail="文档不存在")
chunks = db.query(DocumentChunk).filter(
DocumentChunk.document_id == document_id
).order_by(DocumentChunk.chunk_index).offset(skip).limit(limit).all()
return [
DocumentChunkInfo(
id=chunk.id,
chunk_index=chunk.chunk_index,
content=chunk.content,
document_id=chunk.document_id
)
for chunk in chunks
]
@router.delete("/{document_id}")
async def delete_document_api(
document_id: int,
db: Session = Depends(get_db)
):
"""
删除文档
"""
document = db.query(Document).filter(Document.id == document_id).first()
if not document:
raise HTTPException(status_code=404, detail="文档不存在")
# 触发异步删除任务
delete_document.delay(document_id)
logger.info(f"文档删除任务已触发: {document_id}")
return {
"success": True,
"message": "文档删除任务已触发"
}
app/api/v1/chat.py
python
"""
智能问答API接口
提供问答、对话管理等接口
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.conversation import Conversation, Message
from app.models.schemas import (
ChatRequest, ChatResponse, SourceDocument,
ConversationInfo, MessageInfo
)
from app.services.rag_service import rag_service
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/chat", tags=["智能问答"])
@router.post("/ask", response_model=ChatResponse)
async def ask_question(
request: ChatRequest,
db: Session = Depends(get_db)
):
"""
问答接口
"""
try:
# 调用RAG服务
answer, sources, conversation_id, message_id = rag_service.chat(
question=request.question,
conversation_id=request.conversation_id,
db=db,
top_k=request.top_k,
temperature=request.temperature
)
# 格式化来源
formatted_sources = [
SourceDocument(
document_id=src["document_id"],
document_name=src["document_name"],
chunk_id=src["chunk_id"],
content=src["content"][:500], # 限制来源内容长度
similarity=src["similarity"]
)
for src in sources
]
return ChatResponse(
answer=answer,
sources=formatted_sources,
conversation_id=conversation_id,
message_id=message_id
)
except Exception as e:
logger.error(f"问答失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/conversations", response_model=List[ConversationInfo])
async def get_conversations(
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db)
):
"""
获取对话列表
"""
conversations = db.query(Conversation).order_by(
Conversation.updated_at.desc()
).offset(skip).limit(limit).all()
return [
ConversationInfo(
id=conv.id,
title=conv.title,
created_at=conv.created_at,
updated_at=conv.updated_at
)
for conv in conversations
]
@router.get("/conversations/{conversation_id}", response_model=List[MessageInfo])
async def get_conversation_messages(
conversation_id: int,
db: Session = Depends(get_db)
):
"""
获取对话的消息历史
"""
conversation = db.query(Conversation).filter(
Conversation.id == conversation_id
).first()
if not conversation:
raise HTTPException(status_code=404, detail="对话不存在")
messages = db.query(Message).filter(
Message.conversation_id == conversation_id
).order_by(Message.created_at).all()
return [
MessageInfo(
id=msg.id,
role=msg.role,
content=msg.content,
sources=msg.sources,
created_at=msg.created_at
)
for msg in messages
]
@router.delete("/conversations/{conversation_id}")
async def delete_conversation(
conversation_id: int,
db: Session = Depends(get_db)
):
"""
删除对话
"""
conversation = db.query(Conversation).filter(
Conversation.id == conversation_id
).first()
if not conversation:
raise HTTPException(status_code=404, detail="对话不存在")
db.delete(conversation)
db.commit()
return {"success": True, "message": "对话已删除"}
app/api/v1/knowledge.py
python
"""
知识库管理API接口
提供知识库统计、搜索等功能
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models.document import Document, DocumentStatus
from app.models.schemas import KnowledgeBaseStats, SearchRequest, SearchResult
from app.services.rag_service import rag_service
from app.core.vector_store import vector_store
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/knowledge", tags=["知识库管理"])
@router.get("/stats", response_model=KnowledgeBaseStats)
async def get_knowledge_base_stats(
db: Session = Depends(get_db)
):
"""
获取知识库统计信息
"""
total_documents = db.query(Document).count()
completed_documents = db.query(Document).filter(
Document.status == DocumentStatus.COMPLETED
).count()
failed_documents = db.query(Document).filter(
Document.status == DocumentStatus.FAILED
).count()
processing_documents = db.query(Document).filter(
Document.status == DocumentStatus.PROCESSING
).count()
# 获取向量数据库统计
vector_stats = vector_store.get_collection_stats()
return KnowledgeBaseStats(
total_documents=total_documents,
total_chunks=vector_stats.get('document_count', 0),
completed_documents=completed_documents,
failed_documents=failed_documents,
processing_documents=processing_documents
)
@router.post("/search", response_model=List[SearchResult])
async def search_knowledge(
request: SearchRequest,
db: Session = Depends(get_db)
):
"""
知识库搜索
"""
try:
results = rag_service.retrieve(
query=request.query,
top_k=request.top_k,
similarity_threshold=request.similarity_threshold
)
search_results = []
for result in results:
search_results.append(
SearchResult(
content=result['document'],
document_id=result.get('metadata', {}).get('document_id'),
document_name=result.get('metadata', {}).get('filename', '未知'),
chunk_id=result.get('metadata', {}).get('chunk_id'),
similarity=result.get('similarity', 0)
)
)
return search_results
except Exception as e:
logger.error(f"知识库搜索失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
app/main.py
python
"""
FastAPI应用入口
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import logging
from app.config import settings
from app.core.database import init_database
from app.api.v1 import documents, chat, knowledge
# 配置日志
logging.basicConfig(
level=logging.INFO if not settings.debug else logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# 创建FastAPI应用
app = FastAPI(
title=settings.app_name,
version=settings.app_version,
description="基于RAG架构的智能问答系统API",
docs_url="/docs" if settings.debug else None,
redoc_url="/redoc" if settings.debug else None
)
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 生产环境应配置具体域名
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 注册路由
app.include_router(documents.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(knowledge.router, prefix="/api/v1")
# 挂载静态文件
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.on_event("startup")
async def startup_event():
"""应用启动时的初始化"""
logger.info("正在启动应用...")
# 初始化数据库
init_database()
logger.info("应用启动完成")
@app.on_event("shutdown")
async def shutdown_event():
"""应用关闭时的清理"""
logger.info("应用正在关闭...")
@app.get("/")
async def root():
"""根路径"""
return {
"name": settings.app_name,
"version": settings.app_version,
"status": "running"
}
@app.get("/health")
async def health_check():
"""健康检查"""
return {"status": "healthy"}
四、客户端编码
4.1 项目结构
text
rag-ai-system-frontend/
├── public/
│ └── index.html
├── src/
│ ├── assets/ # 静态资源
│ ├── components/ # 组件
│ │ ├── chat/
│ │ │ ├── ChatWindow.vue # 聊天窗口
│ │ │ ├── MessageItem.vue # 消息项
│ │ │ └── SourcePanel.vue # 来源面板
│ │ ├── document/
│ │ │ ├── DocumentList.vue # 文档列表
│ │ │ ├── DocumentUpload.vue # 文档上传
│ │ │ └── DocumentDetail.vue # 文档详情
│ │ └── common/
│ │ ├── Loading.vue # 加载组件
│ │ └── MarkdownRenderer.vue # Markdown渲染
│ ├── views/
│ │ ├── ChatView.vue # 问答页面
│ │ ├── DocumentsView.vue # 文档管理页面
│ │ └── KnowledgeView.vue # 知识库管理页面
│ ├── router/
│ │ └── index.js # 路由配置
│ ├── stores/
│ │ ├── chat.js # 对话状态管理
│ │ ├── document.js # 文档状态管理
│ │ └── knowledge.js # 知识库状态管理
│ ├── api/
│ │ ├── chat.js # 问答API
│ │ ├── document.js # 文档API
│ │ └── knowledge.js # 知识库API
│ ├── utils/
│ │ ├── request.js # Axios封装
│ │ └── markdown.js # Markdown工具
│ ├── App.vue
│ └── main.js
├── package.json
├── vite.config.js
└── .env.development
4.2 环境配置
.env.development
bash
VITE_API_BASE_URL=http://localhost:8000/api/v1
VITE_WS_URL=ws://localhost:8000/ws
package.json
json
{
"name": "rag-ai-system-frontend",
"version": "1.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"preview": "vite preview"
},
"dependencies": {
"vue": "^3.3.8",
"vue-router": "^4.2.5",
"pinia": "^2.1.7",
"axios": "^1.6.2",
"element-plus": "^2.4.3",
"@element-plus/icons-vue": "^2.1.0",
"markdown-it": "^13.0.2",
"highlight.js": "^11.9.0",
"dayjs": "^1.11.10"
},
"devDependencies": {
"@vitejs/plugin-vue": "^4.5.0",
"vite": "^5.0.0"
}
}
vite.config.js
javascript
import { defineConfig } from 'vite'
import vue from '@vitejs/plugin-vue'
import path from 'path'
export default defineConfig({
plugins: [vue()],
resolve: {
alias: {
'@': path.resolve(__dirname, './src')
}
},
server: {
port: 3000,
proxy: {
'/api': {
target: 'http://localhost:8000',
changeOrigin: true
}
}
}
})
4.3 主入口文件
src/main.js
javascript
import { createApp } from 'vue'
import { createPinia } from 'pinia'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import * as ElementPlusIconsVue from '@element-plus/icons-vue'
import App from './App.vue'
import router from './router'
const app = createApp(App)
// 注册Element Plus图标
for (const [key, component] of Object.entries(ElementPlusIconsVue)) {
app.component(key, component)
}
app.use(createPinia())
app.use(router)
app.use(ElementPlus)
app.mount('#app')
src/App.vue
vue
<template>
<div id="app">
<el-container>
<el-header>
<div class="header-content">
<h1>
<el-icon><ChatDotRound /></el-icon>
RAG智能问答系统
</h1>
<div class="nav">
<el-menu mode="horizontal" :router="true">
<el-menu-item index="/chat">
<el-icon><ChatLineRound /></el-icon>
智能问答
</el-menu-item>
<el-menu-item index="/documents">
<el-icon><Document /></el-icon>
文档管理
</el-menu-item>
<el-menu-item index="/knowledge">
<el-icon><DataAnalysis /></el-icon>
知识库管理
</el-menu-item>
</el-menu>
</div>
</div>
</el-header>
<el-main>
<router-view />
</el-main>
</el-container>
</div>
</template>
<script setup>
import { ChatDotRound, ChatLineRound, Document, DataAnalysis } from '@element-plus/icons-vue'
</script>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
#app {
height: 100vh;
display: flex;
flex-direction: column;
}
.el-header {
background-color: #409EFF;
color: white;
padding: 0;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
.header-content {
display: flex;
align-items: center;
justify-content: space-between;
height: 100%;
padding: 0 20px;
}
.header-content h1 {
font-size: 20px;
display: flex;
align-items: center;
gap: 10px;
}
.el-menu--horizontal {
background-color: transparent;
border-bottom: none;
}
.el-menu--horizontal .el-menu-item {
color: white;
}
.el-menu--horizontal .el-menu-item.is-active {
color: white;
border-bottom-color: white;
}
.el-main {
background-color: #f5f7fa;
padding: 20px;
overflow-y: auto;
}
</style>
4.4 API封装
src/utils/request.js
javascript
import axios from 'axios'
import { ElMessage } from 'element-plus'
// 创建axios实例
const request = axios.create({
baseURL: import.meta.env.VITE_API_BASE_URL,
timeout: 60000,
headers: {
'Content-Type': 'application/json'
}
})
// 请求拦截器
request.interceptors.request.use(
config => {
// 可以在这里添加token等认证信息
return config
},
error => {
return Promise.reject(error)
}
)
// 响应拦截器
request.interceptors.response.use(
response => {
return response.data
},
error => {
const message = error.response?.data?.detail || error.message || '请求失败'
ElMessage.error(message)
return Promise.reject(error)
}
)
export default request
src/api/chat.js
javascript
import request from '@/utils/request'
/**
* 问答接口
* @param {Object} data - 请求数据
* @returns {Promise}
*/
export const askQuestion = (data) => {
return request.post('/chat/ask', data)
}
/**
* 获取对话列表
* @param {Object} params - 分页参数
* @returns {Promise}
*/
export const getConversations = (params) => {
return request.get('/chat/conversations', { params })
}
/**
* 获取对话消息
* @param {number} conversationId - 对话ID
* @returns {Promise}
*/
export const getConversationMessages = (conversationId) => {
return request.get(`/chat/conversations/${conversationId}`)
}
/**
* 删除对话
* @param {number} conversationId - 对话ID
* @returns {Promise}
*/
export const deleteConversation = (conversationId) => {
return request.delete(`/chat/conversations/${conversationId}`)
}
src/api/document.js
javascript
import request from '@/utils/request'
/**
* 上传文档
* @param {FormData} formData - 表单数据
* @returns {Promise}
*/
export const uploadDocument = (formData) => {
return request.post('/documents/upload', formData, {
headers: {
'Content-Type': 'multipart/form-data'
}
})
}
/**
* 获取文档列表
* @param {Object} params - 查询参数
* @returns {Promise}
*/
export const getDocuments = (params) => {
return request.get('/documents/', { params })
}
/**
* 获取文档详情
* @param {number} id - 文档ID
* @returns {Promise}
*/
export const getDocumentDetail = (id) => {
return request.get(`/documents/${id}`)
}
/**
* 获取文档切片
* @param {number} id - 文档ID
* @param {Object} params - 分页参数
* @returns {Promise}
*/
export const getDocumentChunks = (id, params) => {
return request.get(`/documents/${id}/chunks`, { params })
}
/**
* 删除文档
* @param {number} id - 文档ID
* @returns {Promise}
*/
export const deleteDocument = (id) => {
return request.delete(`/documents/${id}`)
}
src/api/knowledge.js
javascript
import request from '@/utils/request'
/**
* 获取知识库统计
* @returns {Promise}
*/
export const getKnowledgeStats = () => {
return request.get('/knowledge/stats')
}
/**
* 搜索知识库
* @param {Object} data - 搜索参数
* @returns {Promise}
*/
export const searchKnowledge = (data) => {
return request.post('/knowledge/search', data)
}
4.5 状态管理
src/stores/chat.js
javascript
import { defineStore } from 'pinia'
import { ref } from 'vue'
import { askQuestion, getConversations, getConversationMessages } from '@/api/chat'
export const useChatStore = defineStore('chat', () => {
// 状态
const conversations = ref([])
const currentConversation = ref(null)
const messages = ref([])
const loading = ref(false)
// 发送消息
const sendMessage = async (question, conversationId = null, temperature = 0.7) => {
loading.value = true
try {
const res = await askQuestion({
question,
conversation_id: conversationId,
temperature
})
// 添加用户消息
messages.value.push({
id: Date.now(),
role: 'user',
content: question,
created_at: new Date().toISOString()
})
// 添加助手消息
messages.value.push({
id: res.message_id,
role: 'assistant',
content: res.answer,
sources: res.sources,
created_at: new Date().toISOString()
})
// 更新当前对话ID
if (!currentConversation.value || currentConversation.value.id !== res.conversation_id) {
currentConversation.value = {
id: res.conversation_id,
title: question.slice(0, 50)
}
// 刷新对话列表
await fetchConversations()
}
return res
} finally {
loading.value = false
}
}
// 获取对话列表
const fetchConversations = async () => {
try {
const res = await getConversations({ skip: 0, limit: 50 })
conversations.value = res
} catch (error) {
console.error('获取对话列表失败:', error)
}
}
// 加载对话消息
const loadConversation = async (conversationId) => {
try {
const res = await getConversationMessages(conversationId)
messages.value = res
currentConversation.value = conversations.value.find(c => c.id === conversationId)
} catch (error) {
console.error('加载对话失败:', error)
}
}
// 清空当前对话
const clearCurrentConversation = () => {
currentConversation.value = null
messages.value = []
}
return {
conversations,
currentConversation,
messages,
loading,
sendMessage,
fetchConversations,
loadConversation,
clearCurrentConversation
}
})
4.6 核心组件实现
src/components/chat/ChatWindow.vue
vue
<template>
<div class="chat-window">
<div class="messages-container" ref="messagesContainer">
<div v-if="messages.length === 0" class="empty-state">
<el-icon :size="64"><ChatDotRound /></el-icon>
<h3>开始对话</h3>
<p>输入您的问题,我将从知识库中为您寻找答案</p>
</div>
<MessageItem
v-for="message in messages"
:key="message.id"
:message="message"
/>
<div v-if="loading" class="loading-message">
<el-skeleton animated>
<template #template>
<el-skeleton-item variant="text" style="width: 80%" />
<el-skeleton-item variant="text" style="width: 60%" />
</template>
</el-skeleton>
</div>
</div>
<div class="input-area">
<el-input
v-model="inputText"
type="textarea"
:rows="3"
placeholder="请输入您的问题..."
:disabled="loading"
@keydown.ctrl.enter="send"
/>
<div class="input-actions">
<div class="settings">
<el-slider
v-model="temperature"
:min="0"
:max="2"
:step="0.1"
:format-tooltip="formatTemperature"
style="width: 200px"
/>
<span class="label">温度: {{ temperature }}</span>
</div>
<el-button
type="primary"
:loading="loading"
@click="send"
>
发送
</el-button>
</div>
</div>
</div>
</template>
<script setup>
import { ref, nextTick, watch } from 'vue'
import { storeToRefs } from 'pinia'
import { useChatStore } from '@/stores/chat'
import MessageItem from './MessageItem.vue'
import { ChatDotRound } from '@element-plus/icons-vue'
const chatStore = useChatStore()
// 注意:直接从store解构(如 const loading = chatStore.loading)会丢失响应性,
// 必须使用storeToRefs保持messages和loading与store同步
const { messages, loading } = storeToRefs(chatStore)
const inputText = ref('')
const temperature = ref(0.7)
const messagesContainer = ref(null)
// 发送消息
const send = async () => {
if (!inputText.value.trim() || loading.value) return
const question = inputText.value
inputText.value = ''
await chatStore.sendMessage(
question,
chatStore.currentConversation?.id,
temperature.value
)
// 滚动到底部
await nextTick()
scrollToBottom()
}
// 滚动到底部
const scrollToBottom = () => {
if (messagesContainer.value) {
messagesContainer.value.scrollTop = messagesContainer.value.scrollHeight
}
}
// 格式化温度显示
const formatTemperature = (val) => {
return `随机性: ${val}`
}
// 监听消息变化,自动滚动
watch(messages, () => {
nextTick(() => {
scrollToBottom()
})
}, { deep: true })
</script>
<style scoped>
.chat-window {
display: flex;
flex-direction: column;
height: 100%;
background: white;
border-radius: 8px;
overflow: hidden;
}
.messages-container {
flex: 1;
overflow-y: auto;
padding: 20px;
}
.empty-state {
text-align: center;
padding: 60px 20px;
color: #909399;
}
.empty-state .el-icon {
font-size: 64px;
margin-bottom: 20px;
}
.empty-state h3 {
margin-bottom: 10px;
color: #606266;
}
.input-area {
border-top: 1px solid #e4e7ed;
padding: 20px;
background: #f5f7fa;
}
.input-actions {
display: flex;
justify-content: space-between;
align-items: center;
margin-top: 12px;
}
.settings {
display: flex;
align-items: center;
gap: 12px;
}
.settings .label {
font-size: 12px;
color: #909399;
}
.loading-message {
padding: 12px;
background: #f5f7fa;
border-radius: 8px;
margin: 10px 0;
max-width: 70%;
}
</style>
src/components/chat/MessageItem.vue
vue
<template>
<div class="message-item" :class="message.role">
<div class="message-avatar">
<el-avatar :size="40">
<el-icon v-if="message.role === 'user'"><User /></el-icon>
<el-icon v-else><Service /></el-icon>
</el-avatar>
</div>
<div class="message-content">
<div class="message-header">
<span class="role-name">{{ message.role === 'user' ? '我' : 'AI助手' }}</span>
<span class="time">{{ formatTime(message.created_at) }}</span>
</div>
<div class="message-body">
<MarkdownRenderer :content="message.content" />
<!-- 来源信息 -->
<div v-if="message.sources && message.sources.length" class="sources">
<el-divider content-position="left">
<el-icon><Link /></el-icon>
参考来源
</el-divider>
<SourcePanel :sources="message.sources" />
</div>
</div>
</div>
</div>
</template>
<script setup>
import { User, Service, Link } from '@element-plus/icons-vue'
import MarkdownRenderer from '@/components/common/MarkdownRenderer.vue'
import SourcePanel from './SourcePanel.vue'
import dayjs from 'dayjs'
const props = defineProps({
message: {
type: Object,
required: true
}
})
const formatTime = (time) => {
if (!time) return ''
return dayjs(time).format('HH:mm:ss')
}
</script>
<style scoped>
.message-item {
display: flex;
margin-bottom: 20px;
animation: fadeIn 0.3s ease;
}
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.message-item.user {
flex-direction: row-reverse;
}
.message-item.user .message-content {
background: #409EFF;
color: white;
margin-right: 12px;
margin-left: 60px;
}
.message-item.assistant .message-content {
background: #f5f7fa;
margin-left: 12px;
margin-right: 60px;
}
.message-avatar {
flex-shrink: 0;
}
.message-content {
flex: 1;
border-radius: 12px;
padding: 12px 16px;
word-wrap: break-word;
}
.message-header {
display: flex;
justify-content: space-between;
margin-bottom: 8px;
font-size: 12px;
}
.message-item.user .message-header {
color: rgba(255, 255, 255, 0.8);
}
.message-item.assistant .message-header {
color: #909399;
}
.role-name {
font-weight: bold;
}
.message-body {
line-height: 1.6;
}
.sources {
margin-top: 12px;
font-size: 12px;
}
.message-item.user .sources {
color: rgba(255, 255, 255, 0.8);
}
</style>
src/components/chat/SourcePanel.vue
vue
<template>
<div class="source-panel">
<div
v-for="(source, index) in sources"
:key="index"
class="source-item"
>
<div class="source-header">
<el-icon><Document /></el-icon>
<span class="source-name">{{ source.document_name }}</span>
<el-tag size="small" type="info">
相似度: {{ (source.similarity * 100).toFixed(1) }}%
</el-tag>
</div>
<div class="source-content">
{{ truncateContent(source.content, 200) }}
</div>
<div class="source-actions">
<el-button
size="small"
text
@click="showFullContent(source)"
>
查看详情
</el-button>
</div>
</div>
<!-- 详情对话框 -->
<el-dialog
v-model="dialogVisible"
title="参考来源详情"
width="60%"
>
<div class="dialog-content">
<div class="dialog-meta">
<p><strong>文档名称:</strong> {{ currentSource?.document_name }}</p>
<p><strong>相似度:</strong> {{ currentSource ? (currentSource.similarity * 100).toFixed(1) + '%' : '' }}</p>
</div>
<el-divider />
<div class="dialog-text">
{{ currentSource?.content }}
</div>
</div>
</el-dialog>
</div>
</template>
<script setup>
import { ref } from 'vue'
import { Document } from '@element-plus/icons-vue'
const props = defineProps({
sources: {
type: Array,
default: () => []
}
})
const dialogVisible = ref(false)
const currentSource = ref(null)
const truncateContent = (content, maxLength) => {
if (!content) return ''
if (content.length <= maxLength) return content
return content.slice(0, maxLength) + '...'
}
const showFullContent = (source) => {
currentSource.value = source
dialogVisible.value = true
}
</script>
<style scoped>
.source-panel {
margin-top: 8px;
}
.source-item {
background: rgba(255, 255, 255, 0.5);
border-radius: 8px;
padding: 12px;
margin-bottom: 8px;
border-left: 3px solid #409EFF;
}
.source-header {
display: flex;
align-items: center;
gap: 8px;
margin-bottom: 8px;
flex-wrap: wrap;
}
.source-name {
font-weight: bold;
font-size: 13px;
color: #409EFF;
}
.source-content {
font-size: 12px;
color: #606266;
line-height: 1.5;
margin-bottom: 8px;
word-break: break-all;
}
.source-actions {
text-align: right;
}
.dialog-content {
max-height: 500px;
overflow-y: auto;
}
.dialog-meta {
margin-bottom: 16px;
}
.dialog-meta p {
margin: 8px 0;
}
.dialog-text {
white-space: pre-wrap;
line-height: 1.6;
font-size: 14px;
}
</style>
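The similarity percentage this panel renders comes from the backend's retrieval step. As a rough, dependency-light sketch of where that 0-1 score can come from: with a cosine metric, ChromaDB returns a distance rather than a similarity, so one plausible conversion is `1 - distance`. The function names and vectors below are illustrative, not the project's actual code.

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity in [-1, 1]; for typical text embeddings it lands in [0, 1]."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def distance_to_similarity(cosine_distance: float) -> float:
    """ChromaDB's cosine distance is defined as 1 - cosine similarity."""
    return 1.0 - cosine_distance

# Toy 3-dimensional "embeddings"; real ones have hundreds of dimensions.
query_vec = np.array([0.1, 0.9, 0.2])
doc_vec = np.array([0.2, 0.8, 0.1])
sim = cosine_similarity(query_vec, doc_vec)  # close to 1.0 for similar vectors
```

Whatever the metric, the value handed to the frontend should already be normalized to 0-1 so that `(source.similarity * 100).toFixed(1)` renders a sensible percentage.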
5. Summary
Over the course of this article we built a fully functional, enterprise-grade RAG question-answering system from scratch. The project covered every stage of front-end and back-end development and showed how to turn theory into a usable product. Let's review the main outcomes and takeaways.
5.1 Review of Key Technical Points
1. Core RAG architecture
Document-processing pipeline: the complete flow from upload, parsing, and chunking to vectorized storage
Retrieval module: semantic search based on vector similarity, with similarity-threshold filtering
Generation module: carefully designed prompts that feed retrieval results to the LLM as context
2. Back-end practice
FastAPI in action: async handling, dependency injection, auto-generated API docs
Vector-database integration: a ChromaDB wrapper and its operations
Async task processing: Celery for long-running jobs, improving responsiveness
Multi-model support: one interface for OpenAI, Zhipu AI, Qwen, and other LLMs
3. Front-end practice
Vue 3 Composition API: reactive state management
Element Plus: quickly building a polished UI
Markdown rendering: rich-text display
Streaming interaction: a smoother user experience
4. Engineering practice
Modular design: a clear code structure
Configuration management: environment variables kept separate for easy deployment
Error handling: thorough exception capture and logging
Data modeling: a sound relational-database design
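The retrieval and generation modules recapped above can be sketched end-to-end in a few lines. The retrieval stub and the prompt wording below are illustrative assumptions, not the project's actual template; in the real system `retrieve` would embed the question and query ChromaDB, and the prompt would go to the configured LLM.

```python
def retrieve(question: str, top_k: int = 3) -> list[dict]:
    # Stub: a real implementation embeds `question` and queries the vector store,
    # returning chunks above the similarity threshold.
    return [
        {"document_name": "api.md", "content": "POST /api/v1/chat sends a question.", "similarity": 0.91},
        {"document_name": "faq.md", "content": "Responses include source citations.", "similarity": 0.84},
    ][:top_k]

def build_prompt(question: str, chunks: list[dict]) -> str:
    # Concatenate retrieved chunks into a context block, then append the question.
    context = "\n\n".join(
        f"[Source: {c['document_name']}]\n{c['content']}" for c in chunks
    )
    return (
        "Answer the question using only the context below. "
        "If the context is insufficient, say so.\n\n"
        f"Context:\n{context}\n\nQuestion: {question}"
    )

prompt = build_prompt("How do I send a question?", retrieve("How do I send a question?"))
```

The prompt string would then be passed to the LLM service; the chunks themselves are returned alongside the answer to power the source panel.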
5.2 System Strengths
Fresh knowledge: because the knowledge base is updated dynamically, answers always draw on the latest documents
Explainability: answers cite their sources, which users can inspect
Extensibility: new document types, alternative Embedding models, and different LLMs are easy to plug in
User experience: multi-turn dialogue, conversation history, and streaming output
Flexible deployment: Docker containerization allows fast migration
5.3 Extending to Other Scenarios
The architecture is general-purpose and applies to many scenarios:
Enterprise knowledge base: consolidate internal documents, policies, and product manuals
Customer-service bot: build intelligent support on product FAQs and operating manuals
Technical-documentation assistant: help developers find API docs and solutions quickly
Legal-document retrieval: quickly locate relevant clauses in a corpus of legal texts
Medical Q&A: an assistant built on medical literature and clinical guidelines
5.4 Directions for Optimization
The current system is fully functional, but there is still room to improve:
Performance
Cache vector-search results
Accelerate Embedding computation on GPU
Improve concurrency of LLM calls
Retrieval quality
Introduce hybrid retrieval (vector search + keyword search)
Add a reranking (Rerank) stage
Add query rewriting and intent recognition
Features
Support more document formats (Excel, PPT, image OCR)
Automatic document summarization
User feedback and model fine-tuning
Security
User authentication and access control
Encrypted storage of API keys
Content-safety filtering
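As one concrete way to realize the "hybrid retrieval" item above, Reciprocal Rank Fusion (RRF) merges a vector-search ranking with a keyword (e.g. BM25) ranking without requiring the two score scales to be comparable. The document IDs and the conventional k=60 constant below are illustrative.

```python
def rrf_fuse(rankings: list[list[str]], k: int = 60) -> list[str]:
    """Merge several ranked ID lists with Reciprocal Rank Fusion.

    Each document scores sum(1 / (k + rank + 1)) over the rankings it
    appears in; documents ranked highly by multiple retrievers win.
    """
    scores: dict[str, float] = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
    return sorted(scores, key=scores.get, reverse=True)

vector_hits = ["doc3", "doc1", "doc7"]   # semantic-search order
keyword_hits = ["doc1", "doc9", "doc3"]  # keyword/BM25 order
fused = rrf_fuse([vector_hits, keyword_hits])
```

Here "doc1" and "doc3" rise to the top because both retrievers rank them well, which is exactly the behavior hybrid retrieval is after; a Rerank model could then re-score the fused top-k.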
5.5 Learning Suggestions
For readers who want to go deeper into RAG:
Understand the theory: study the Transformer architecture, the attention mechanism, and the principles of vector retrieval
Practice hands-on: modify the code, e.g. add a new document parser or swap in a different Embedding model
Follow the frontier: track recent RAG research such as Self-RAG and RAPTOR
Contribute to open source: improve through practice by contributing to open-source projects
5.6 Closing Remarks
RAG offers an effective way around the knowledge limitations of large language models. Through this project we not only acquired the skills to build a RAG system, but, more importantly, learned how to bring AI technology into real business scenarios. As large models keep advancing, the RAG architecture will keep evolving, and more innovative applications are sure to emerge.
I hope this article helps you build a complete mental model of RAG system development and gives you a foundation for your own innovations and extensions. There is no end to learning; let's keep exploring as the AI wave moves forward.
Thank you for reading! If you run into problems while following along, feel free to reach out and discuss.
Appendix
I. Installing the Back-End Dependencies
Full dependency list (requirements.txt)
text
# Web framework
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6
# Database and vector storage
sqlalchemy==2.0.23
chromadb==0.4.18
pymysql==1.1.0
# Async tasks
celery==5.3.4
redis==5.0.1
# AI
openai==1.3.0
sentence-transformers==2.2.2
torch==2.1.0
# Document parsing
PyPDF2==3.0.1
python-docx==1.1.0
markdown==3.5.1
beautifulsoup4==4.12.2
# Utilities
python-dotenv==1.0.0
pydantic==2.5.0
pydantic-settings==2.1.0
httpx==0.25.1
loguru==0.7.2
Installation commands
bash
# 1. Create a virtual environment (recommended)
python -m venv venv
# 2. Activate it
# Windows:
venv\Scripts\activate
# Linux/Mac:
source venv/bin/activate
# 3. Install the dependencies
pip install -r requirements.txt
# 4. Optional: GPU build of PyTorch
# CUDA 11.8:
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118
# CUDA 12.1:
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121
Notes on additional dependencies
If you use Zhipu AI, also install:
bash
pip install zhipuai
Environment configuration
Create a .env file:
bash
# Application
APP_NAME="RAG AI System"
APP_VERSION="1.0.0"
DEBUG=True
SECRET_KEY="your-secret-key-here"
# Database (SQLite for local development)
DATABASE_URL="sqlite:///./rag_ai.db"
REDIS_URL="redis://localhost:6379/0"
# Vector database
CHROMA_PERSIST_DIR="./chroma_data"
CHROMA_COLLECTION_NAME="rag_knowledge_base"
# Embedding model
EMBEDDING_MODEL_NAME="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
EMBEDDING_DEVICE="cpu"
# LLM (pick one provider)
LLM_PROVIDER="openai"
OPENAI_API_KEY="your-openai-api-key"
OPENAI_BASE_URL="https://api.openai.com/v1"
OPENAI_MODEL="gpt-3.5-turbo"
# Document processing
CHUNK_SIZE=500
CHUNK_OVERLAP=50
MAX_FILE_SIZE=10485760
ALLOWED_EXTENSIONS=".txt,.md,.pdf,.docx"
# Retrieval
TOP_K=5
SIMILARITY_THRESHOLD=0.7
# Celery
CELERY_BROKER_URL="redis://localhost:6379/0"
CELERY_RESULT_BACKEND="redis://localhost:6379/1"
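As a minimal, standard-library-only sketch of how these .env values might be consumed in code (the project itself uses pydantic-settings, which layers validation on top of the same idea; the field names below simply mirror the keys above):

```python
import os
from dataclasses import dataclass

@dataclass(frozen=True)
class Settings:
    # Defaults match the .env example above. os.getenv is evaluated once,
    # at class-definition time, so the variables must be set before import
    # (python-dotenv's load_dotenv() typically handles that).
    chunk_size: int = int(os.getenv("CHUNK_SIZE", "500"))
    chunk_overlap: int = int(os.getenv("CHUNK_OVERLAP", "50"))
    top_k: int = int(os.getenv("TOP_K", "5"))
    similarity_threshold: float = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
    database_url: str = os.getenv("DATABASE_URL", "sqlite:///./rag_ai.db")

settings = Settings()
```

With pydantic-settings the class would instead subclass `BaseSettings` and point `env_file` at `.env`, gaining type coercion and validation errors for malformed values.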
Starting the back end
bash
# 1. Start Redis (the Celery task-queue broker)
# With Docker:
docker run -d -p 6379:6379 redis
# or install redis-server locally
# 2. Start a Celery worker (in a new terminal)
celery -A app.core.tasks worker --loglevel=info
# 3. Start the FastAPI application
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
II. Installing the Front-End Dependencies
Full dependency list (package.json)
json
{
"name": "rag-ai-system-frontend",
"version": "1.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"preview": "vite preview"
},
"dependencies": {
"vue": "^3.3.8",
"vue-router": "^4.2.5",
"pinia": "^2.1.7",
"axios": "^1.6.2",
"element-plus": "^2.4.3",
"@element-plus/icons-vue": "^2.1.0",
"markdown-it": "^13.0.2",
"highlight.js": "^11.9.0",
"dayjs": "^1.11.10"
},
"devDependencies": {
"@vitejs/plugin-vue": "^4.5.0",
"vite": "^5.0.0"
}
}
Installation commands
bash
# 1. Enter the front-end project directory
cd rag-ai-system-frontend
# 2. Install the dependencies
npm install
# or with yarn
yarn install
# or with pnpm (recommended)
pnpm install
Environment configuration
Create a .env.development file:
bash
VITE_API_BASE_URL=http://localhost:8000/api/v1
VITE_WS_URL=ws://localhost:8000/ws
Create a .env.production file (for production):
bash
VITE_API_BASE_URL=https://your-domain.com/api/v1
VITE_WS_URL=wss://your-domain.com/ws
Starting the front end
bash
# Development mode
npm run dev
# Production build
npm run build
# Preview the production build
npm run preview
Dependency Versions
Key dependencies
| Dependency | Version | Notes |
|---|---|---|
| FastAPI | 0.104.1 | High-performance web framework with async support |
| SQLAlchemy | 2.0.23 | ORM for database access |
| ChromaDB | 0.4.18 | Vector database storing document embeddings |
| Celery | 5.3.4 | Asynchronous task queue |
| Sentence-Transformers | 2.2.2 | Text-embedding models |
| PyTorch | 2.1.0 | Deep-learning framework backing the embeddings |
| Element Plus | 2.4.3 | Vue 3 UI component library |
| Vite | 5.0.0 | Front-end build tool |
Version compatibility
Python: 3.9 - 3.11
Node.js: 18.x or later
Redis: 6.x or later
MySQL: 5.7+ / 8.0+ (optional; SQLite works as a substitute)
Troubleshooting
1. PyTorch installation fails
bash
# Use a mainland-China mirror
pip install torch==2.1.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# or install the CPU-only build (faster)
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
2. ChromaDB installation problems
bash
# If installation fails, install its dependency first
pip install numpy==1.24.3
pip install chromadb==0.4.18
3. Slow Node dependency installation
bash
# Use the npmmirror registry
npm config set registry https://registry.npmmirror.com
# or use cnpm
npm install -g cnpm --registry=https://registry.npmmirror.com
cnpm install
4. Redis connection problems
bash
# Check whether Redis is running
redis-cli ping
# should return PONG
# If Redis is not installed:
# Ubuntu/Debian:
sudo apt-get install redis-server
# macOS:
brew install redis
# Windows:
# download from https://github.com/microsoftarchive/redis/releases