使用FastAPI为知识库问答系统前端提供后端功能接口

后端接口实现以及接口调用的类代码一览

  • [1. 后端接口代码](#1. 后端接口代码)
  • [2. 代码结构概述](#2. 代码结构概述)
  • [3. 主要功能模块](#3. 主要功能模块)
  • [4. 连接方式](#4. 连接方式)

1. 后端接口代码

python 复制代码
# app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict
import uvicorn
from user_database import UserDatabase
from ModelResponse import ModelResponse

app = FastAPI()

# 允许跨域访问(适配 Gradio 调用 FastAPI)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

user_db = UserDatabase()
response = ModelResponse()


class LoginRequest(BaseModel):
    username: str
    password: str


class RegisterRequest(BaseModel):
    username: str
    password: str


class ChatRequest(BaseModel):
    user_input: str
    chat_history: List[dict]


class ChatResponse(BaseModel):
    status: str
    response: str


@app.post("/login")
async def login(credentials: LoginRequest):
    """登录接口,验证用户名和密码。"""
    if user_db.verify_user(credentials.username, credentials.password):
        return {
            "status": "success",
            "message": "Login successful"
        }
    else:
        raise HTTPException(status_code=401, detail="Invalid username or password")


@app.post("/register")
async def register(user: RegisterRequest):
    """注册接口,添加用户。"""
    if user_db.add_user(user.username, user.password):
        return {
            "status": "success",
            "message": f"User '{user.username}' registered successfully."
        }
    else:
        raise HTTPException(status_code=400, detail="User already exists.")


@app.post("/chat")
async def chat(request: ChatRequest) -> ChatResponse:
    try:
        user_input = request.user_input
        chat_history = request.chat_history

        bot_response = response.ask(user_input, chat_history)

        updated_chat_history = chat_history + [[user_input, bot_response["answer"]]]

        return ChatResponse(
            status="success",
            response=bot_response["answer"],
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

2. 代码结构概述

app.py代码是一个基于 FastAPI 的后端服务,旨在为前端提供接口,支持用户登录、注册以及与聊天机器人进行交互的功能。以下是代码的详细功能介绍:

  1. 代码结构概述
    FastAPI:一个现代、高性能的 Python Web 框架,用于构建 API。
    CORS 中间件:允许跨域请求,方便前端(如 Gradio、React 等)调用后端 API。
    用户数据库:通过 UserDatabase 类管理用户的登录和注册。
    聊天机器人:通过 ModelResponse 类实现基于 LLM(大语言模型)的问答功能。
    API 接口:提供了 /login、/register 和 /chat 三个接口,分别用于用户登录、注册和聊天交互。

3. 主要功能模块

1. 跨域支持

python 复制代码
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

可以允许前端从任何域名访问后端API

2. 用户登录接口(/login)

python 复制代码
@app.post("/login")
async def login(credentials: LoginRequest):
    if user_db.verify_user(credentials.username, credentials.password):
        return {
            "status": "success",
            "message": "Login successful"
        }
    else:
        raise HTTPException(status_code=401, detail="Invalid username or password")

用于验证用户的用户名和密码。通过UserDatabase类中的verify_user方法实现。如果验证成功,返回 {"status": "success", "message": "Login successful"}。这是前端期望的数据格式,一定要注意前端期望后端返回什么样的数据类型!!接口接受和发送的数据格式最好在工程定框架的时候就定死,不要轻易改动。如果验证失败,返回 401 状态码和错误信息 "Invalid username or password"

3. 用户注册接口(/register)

python 复制代码
@app.post("/register")
async def register(user: RegisterRequest):
    if user_db.add_user(user.username, user.password):
        return {
            "status": "success",
            "message": f"User '{user.username}' registered successfully."
        }
    else:
        raise HTTPException(status_code=400, detail="User already exists.")

功能是注册新用户。通过UserDatabase类中的add_user方法实现。如果注册成功,返回 {"status": "success", "message": "User registered successfully."}。如果用户名已存在,返回 400 状态码和错误信息 "User already exists."

4.用户相关接口依赖的类

下面是实现接口中方法的类。

python 复制代码
# user_database.py
import sqlite3
import yaml

with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)


class UserDatabase:
    def __init__(self):
        """
        初始化 UserDatabase,使用本地 SQLite 数据库。
        """
        self.db_path = config["database"]["user_db_path"]
        self._init_db()

    def _init_db(self):
        """初始化用户信息库"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS user (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT UNIQUE NOT NULL,
                    password TEXT NOT NULL
                )
            ''')
            conn.commit()
            print("User database initialized.")

    def add_user(self, username, password):
        """添加用户到用户信息库"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                # 检查用户是否已存在
                cursor.execute('SELECT username FROM user WHERE username = ?', (username,))
                if cursor.fetchone():
                    print(f"User '{username}' already exists.")
                    return False  # 用户已存在,返回 False
                # 插入新用户
                cursor.execute('''
                    INSERT INTO user (username, password) VALUES (?, ?)
                ''', (username, password))
                conn.commit()
                print(f"User '{username}' added to the database.")
                return True  # 用户添加成功,返回 True
        except sqlite3.Error as e:
            print(f"Database error: {e}")
            return False  # 数据库操作失败,返回 False

    def get_user_by_id(self, user_id):
        """根据用户 ID 查询用户信息"""
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            cursor.execute('''
                SELECT * FROM user WHERE id = ?
            ''', (user_id,))
            row = cursor.fetchone()  # 只调用一次 fetchone
            return dict(row) if row else None

    def get_all_users(self):
        """获取所有用户信息"""
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            cursor.execute('SELECT * FROM user')
            return [dict(row) for row in cursor.fetchall()]

    def verify_user(self, username, password):
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute('SELECT * FROM user WHERE username = ? AND password = ?', (username, password))
            return cursor.fetchone() is not None

5.聊天接口(/chat)

python 复制代码
@app.post("/chat")
async def chat(request: ChatRequest) -> ChatResponse:
    try:
        user_input = request.user_input
        chat_history = request.chat_history

        bot_response = response.ask(user_input, chat_history)

        updated_chat_history = chat_history + [[user_input, bot_response["answer"]]]

        return ChatResponse(
            status="success",
            response=bot_response["answer"],
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

功能是接收用户输入和聊天历史,调用response类中的ask方法生成回答。如果成功,返回 {"status": "success", "response": "生成的回答"}。如果发生错误,返回 500 状态码和错误信息。

6.聊天接口依赖的类

python 复制代码
import sqlite3
import numpy as np
import faiss
import requests
from typing import List, Dict
import yaml


def build_final_prompt(query: str, chat_history: List[Dict[str, str]],
                       relevant_docs: List[Dict[str, str]]) -> str:
    """构建最终的提示(final_prompt),包含对话历史和相关文档内容。"""
    # 生成对话历史上下文
    full_prompt = []
    for line in chat_history:
        role = line.get("role")
        content = line.get("content")
        if role == "user":
            current_prompt = f"用户:{content}"
        elif role == "assistant":
            current_prompt = f"助手:{content}"
        else:
            raise Exception(f"无法支持的角色类型, {role}")
        full_prompt.append(current_prompt)
    full_prompt = "\n".join(full_prompt)

    doc_context = "\n".join([doc["text"] for doc in relevant_docs[:3]])

    # 构建最终提示
    final_prompt = (
        f"你是一个助手,帮助用户解答问题。\n"
        f"背景资料:\n{doc_context}\n\n"
        f"对话历史:\n{full_prompt}\n\n"
        f"用户的问题:\n{query}"
    )
    return final_prompt


def build_faiss_index(embeddings: np.ndarray):
    """使用从数据库加载的嵌入向量构建 FAISS 索引。"""
    # 获取嵌入向量的维度
    dimension = embeddings.shape[1]

    # 创建 FAISS 索引
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index


class ModelResponse:
    def __init__(self, config_path: str = "config.yaml"):
        """初始化 Response 类,加载配置文件并初始化向量库。"""
        # 读取配置文件
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = yaml.safe_load(f)

        self.chat_history = []
        self.history_db_path = self.config["database"]["history_db_path"]
        # 初始化 API URL
        self.OLLAMA_API_URL_EMBED = self.config["ollama"]["api_url_embedding"]
        self.OLLAMA_API_URL_GENER = self.config["ollama"]["api_url_generate"]

        # 初始化向量库
        self.vector_db_path = self.config["database"]["vector_db_path"]
        self.documents, self.embeddings = self.load_embeddings_from_db(self.vector_db_path)
        self.index = build_faiss_index(self.embeddings)
        self.PROMPT_TEMPLATE = self.config.get("prompt_template", "")

        print(f"已从数据库加载 {len(self.documents)} 个文档的嵌入向量。FAISS 索引构建完成!")

    def generate_embedding(self, text: str) -> List[float]:
        """调用 Ollama 的 API 生成文本嵌入。"""
        data = {
            "model": self.config["ollama"]["embedding_model"],  # 使用配置文件中的嵌入模型
            "prompt": text,
            "options": {"embedding_only": True}  # 只生成嵌入
        }
        response = requests.post(f"{self.OLLAMA_API_URL_EMBED}/embeddings", json=data)
        return response.json().get("embedding", [])

    @staticmethod
    def normalize_embeddings(embeddings: np.ndarray) -> np.ndarray:
        """对嵌入向量进行归一化。"""
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        return embeddings / norms

    def load_embeddings_from_db(self, db_path: str):
        """从 SQLite 数据库中加载文档和对应的嵌入向量。"""
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        # 查询数据库中的嵌入向量
        cursor.execute('SELECT id, pdf_file_name, document_text, embedding FROM document_embeddings')
        rows = cursor.fetchall()

        # 将字节流转换回嵌入向量
        documents = []
        embeddings = []
        for row in rows:
            doc_id, pdf_file_name, doc_text, embedding_bytes = row
            embedding = np.frombuffer(embedding_bytes, dtype=np.float32)
            documents.append({"id": doc_id, "pdf_file_name": pdf_file_name, "text": doc_text})
            embeddings.append(embedding)
        embeddings = self.normalize_embeddings(np.array(embeddings))

        conn.close()
        return documents, embeddings

    def retrieve_documents(self, query: str, k: int = 5, threshold: float = 1.0) -> List[dict]:
        """根据查询检索最相关的文档,并根据阈值过滤结果。"""
        # 生成查询嵌入
        query_embedding = np.array([self.generate_embedding(query)], dtype=np.float32)
        query_embedding = self.normalize_embeddings(query_embedding)
        distances, indices = self.index.search(query_embedding, k)

        relevant_docs = []
        for i, idx in enumerate(indices[0]):
            if distances[0][i] <= threshold:
                relevant_docs.append({
                    "text": self.documents[idx]["text"],
                    "score": float(distances[0][i])
                })
        relevant_docs.sort(key=lambda x: x["score"])
        return relevant_docs

    def generate_answer(self, final_prompt: str) -> str:
        """调用 Ollama 的 API 生成答案,并载入历史对话。"""
        # 定义模型参数
        data = {
            "model": self.config["ollama"]["generation_model"],
            "prompt": final_prompt,
            "stream": False,
            "temperature": self.config["ollama"]["temperature"]
        }
        response = requests.post(f"{self.OLLAMA_API_URL_GENER}/generate", json=data)
        if response.status_code == 200:
            return response.json().get("response", "")
        else:
            return f"API 请求失败,状态码:{response.status_code}"

    def ask(self, query: str, chat_history: List[dict] = None) -> Dict[str, str]:
        """接收用户的问题,检索相关文档并生成答案。"""
        if chat_history is None:
            chat_history = []

        # 检索相关文档
        relevant_docs = self.retrieve_documents(query, k=self.config["retrieval"]["k"],
                                                threshold=self.config["retrieval"]["threshold"])

        if not relevant_docs:
            return {
                "status": "error",
                "response": "未找到相关文档。",
            }

        # 构建最终提示
        final_prompt = build_final_prompt(query, chat_history, relevant_docs)

        # 生成答案
        answer = self.generate_answer(final_prompt)

        return {
            "query": query,
            "answer": answer,
        }

4. 连接方式

如果你没有前端界面的代码也没有关系,你可以通过uvicorn 启动 FastAPI 服务,直接访问URL,然后在终端来查看接口的数据。

bash 复制代码
uvicorn app:app --host 0.0.0.0 --port 8000

服务启动后,终端会显示类似以下信息:

bash 复制代码
INFO:     Started server process [12345]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)

你可以通过FastAPI 自动生成了交互式 API 文档(Swagger UI)来访问接口 。

  1. 先访问 Swagger UI:http://localhost:8000/docs
  2. 找到 /register、/login 、/chat接口,点击 Try it out
  3. 输入 JSON 数据(如 {"username": "testuser", "password": "testpassword"}),然后点击 Execute
  4. 查看响应结果
bash 复制代码
访问根路径:http://localhost:8000/
访问用户登录。 http://localhost:8000/login
访问用户注册。http://localhost:8000/register
访问聊天交互。http://localhost:8000/chat

你也可以通过在终端使用cURL命令发送 HTTP 请求:

bash 复制代码
curl -X POST "http://localhost:8000/register" -H "Content-Type: application/json" -d '{"username": "user1", "password": "12345"}'
bash 复制代码
curl -X POST "http://localhost:8000/login" -H "Content-Type: application/json" -d '{"username": "user1", "password": "12345"}'
bash 复制代码
curl -X POST "http://localhost:8000/chat" -H "Content-Type: application/json" -d '{"user_input": "什么是人工智能?", "chat_history": []}'
相关推荐
欣然~6 分钟前
基于蒙特卡洛方法的网格世界求解
开发语言·python·信息可视化
棉猴17 分钟前
Pygame实现记忆拼图游戏14
python·游戏·pygame·游戏编程·python游戏编程
小爬虫程序猿38 分钟前
如何解析返回的商品信息?
爬虫·python
搏博1 小时前
本地基于Ollama部署的DeepSeek详细接口文档说明
人工智能·python·深度学习·神经网络
小白的高手之路2 小时前
Pytorch中的torch.utils.data.Dataset 类
pytorch·python·深度学习
舊時王謝堂前燕3 小时前
macOS使用brew切换Python版本【超详细图解】
python·macos
yukai080084 小时前
【最后203篇系列】021 Q201再计划
python
battlestar4 小时前
Siemens Smart 200 PLC 通讯(基于python-)
前端·网络·python
人类群星闪耀时4 小时前
回溯法经典练习:组合总和的深度解析与实战
开发语言·python
时光呢4 小时前
JAVA泛型的作用
java·windows·python