后端接口实现以及接口调用的类代码一览
- [1. 后端接口代码](#1. 后端接口代码)
- [2. 代码结构概述](#2. 代码结构概述)
- [3. 主要功能模块](#3. 主要功能模块)
-
- [1. 跨域支持](#1. 跨域支持)
- [2. 用户登录接口(/login)](#2. 用户登录接口(/login))
- [3. 用户注册接口(/register)](#3. 用户注册接口(/register))
- 4.用户相关接口依赖的类
- 5.聊天接口(/chat)
- 6.聊天接口依赖的类
- [4. 连接方式](#4. 连接方式)
1. 后端接口代码
python
# app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict
import uvicorn
from user_database import UserDatabase
from ModelResponse import ModelResponse
app = FastAPI()
# 允许跨域访问(适配 Gradio 调用 FastAPI)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
user_db = UserDatabase()
response = ModelResponse()
class LoginRequest(BaseModel):
username: str
password: str
class RegisterRequest(BaseModel):
username: str
password: str
class ChatRequest(BaseModel):
user_input: str
chat_history: List[dict]
class ChatResponse(BaseModel):
status: str
response: str
@app.post("/login")
async def login(credentials: LoginRequest):
"""登录接口,验证用户名和密码。"""
if user_db.verify_user(credentials.username, credentials.password):
return {
"status": "success",
"message": "Login successful"
}
else:
raise HTTPException(status_code=401, detail="Invalid username or password")
@app.post("/register")
async def register(user: RegisterRequest):
"""注册接口,添加用户。"""
if user_db.add_user(user.username, user.password):
return {
"status": "success",
"message": f"User '{user.username}' registered successfully."
}
else:
raise HTTPException(status_code=400, detail="User already exists.")
@app.post("/chat")
async def chat(request: ChatRequest) -> ChatResponse:
try:
user_input = request.user_input
chat_history = request.chat_history
bot_response = response.ask(user_input, chat_history)
updated_chat_history = chat_history + [[user_input, bot_response["answer"]]]
return ChatResponse(
status="success",
response=bot_response["answer"],
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
2. 代码结构概述
app.py代码是一个基于 FastAPI 的后端服务,旨在为前端提供接口,支持用户登录、注册以及与聊天机器人进行交互的功能。以下是代码的详细功能介绍:
- 代码结构概述
FastAPI:一个现代、高性能的 Python Web 框架,用于构建 API。
CORS 中间件:允许跨域请求,方便前端(如 Gradio、React 等)调用后端 API。
用户数据库:通过 UserDatabase 类管理用户的登录和注册。
聊天机器人:通过 ModelResponse 类实现基于 LLM(大语言模型)的问答功能。
API 接口:提供了 /login、/register 和 /chat 三个接口,分别用于用户登录、注册和聊天交互。
3. 主要功能模块
1. 跨域支持
python
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
可以允许前端从任何域名访问后端API
2. 用户登录接口(/login)
python
@app.post("/login")
async def login(credentials: LoginRequest):
if user_db.verify_user(credentials.username, credentials.password):
return {
"status": "success",
"message": "Login successful"
}
else:
raise HTTPException(status_code=401, detail="Invalid username or password")
用于验证用户的用户名和密码。通过UserDatabase类中的verify_user方法实现。如果验证成功,返回 {"status": "success", "message": "Login successful"}。这是前端期望的数据格式,一定要注意前端期望后端返回什么样的数据类型!!接口接受和发送的数据格式最好在工程定框架的时候就定死,不要轻易改动。如果验证失败,返回 401 状态码和错误信息 "Invalid username or password"
3. 用户注册接口(/register)
python
@app.post("/register")
async def register(user: RegisterRequest):
if user_db.add_user(user.username, user.password):
return {
"status": "success",
"message": f"User '{user.username}' registered successfully."
}
else:
raise HTTPException(status_code=400, detail="User already exists.")
功能是注册新用户。通过UserDatabase类中的add_user方法实现。如果注册成功,返回 {"status": "success", "message": "User registered successfully."}。如果用户名已存在,返回 400 状态码和错误信息 "User already exists."
4.用户相关接口依赖的类
下面是实现接口中方法的类。
python
# user_database.py
import sqlite3
import yaml
with open("config.yaml", "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
class UserDatabase:
def __init__(self):
"""
初始化 UserDatabase,使用本地 SQLite 数据库。
"""
self.db_path = config["database"]["user_db_path"]
self._init_db()
def _init_db(self):
"""初始化用户信息库"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
password TEXT NOT NULL
)
''')
conn.commit()
print("User database initialized.")
def add_user(self, username, password):
"""添加用户到用户信息库"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
# 检查用户是否已存在
cursor.execute('SELECT username FROM user WHERE username = ?', (username,))
if cursor.fetchone():
print(f"User '{username}' already exists.")
return False # 用户已存在,返回 False
# 插入新用户
cursor.execute('''
INSERT INTO user (username, password) VALUES (?, ?)
''', (username, password))
conn.commit()
print(f"User '{username}' added to the database.")
return True # 用户添加成功,返回 True
except sqlite3.Error as e:
print(f"Database error: {e}")
return False # 数据库操作失败,返回 False
def get_user_by_id(self, user_id):
"""根据用户 ID 查询用户信息"""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
cursor.execute('''
SELECT * FROM user WHERE id = ?
''', (user_id,))
row = cursor.fetchone() # 只调用一次 fetchone
return dict(row) if row else None
def get_all_users(self):
"""获取所有用户信息"""
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
cursor.execute('SELECT * FROM user')
return [dict(row) for row in cursor.fetchall()]
def verify_user(self, username, password):
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute('SELECT * FROM user WHERE username = ? AND password = ?', (username, password))
return cursor.fetchone() is not None
5.聊天接口(/chat)
python
@app.post("/chat")
async def chat(request: ChatRequest) -> ChatResponse:
try:
user_input = request.user_input
chat_history = request.chat_history
bot_response = response.ask(user_input, chat_history)
updated_chat_history = chat_history + [[user_input, bot_response["answer"]]]
return ChatResponse(
status="success",
response=bot_response["answer"],
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
功能是接收用户输入和聊天历史,调用response类中的ask方法生成回答。如果成功,返回 {"status": "success", "response": "生成的回答"}。如果发生错误,返回 500 状态码和错误信息。
6.聊天接口依赖的类
python
import sqlite3
import numpy as np
import faiss
import requests
from typing import List, Dict
import yaml
def build_final_prompt(query: str, chat_history: List[Dict[str, str]],
relevant_docs: List[Dict[str, str]]) -> str:
"""构建最终的提示(final_prompt),包含对话历史和相关文档内容。"""
# 生成对话历史上下文
full_prompt = []
for line in chat_history:
role = line.get("role")
content = line.get("content")
if role == "user":
current_prompt = f"用户:{content}"
elif role == "assistant":
current_prompt = f"助手:{content}"
else:
raise Exception(f"无法支持的角色类型, {role}")
full_prompt.append(current_prompt)
full_prompt = "\n".join(full_prompt)
doc_context = "\n".join([doc["text"] for doc in relevant_docs[:3]])
# 构建最终提示
final_prompt = (
f"你是一个助手,帮助用户解答问题。\n"
f"背景资料:\n{doc_context}\n\n"
f"对话历史:\n{full_prompt}\n\n"
f"用户的问题:\n{query}"
)
return final_prompt
def build_faiss_index(embeddings: np.ndarray):
"""使用从数据库加载的嵌入向量构建 FAISS 索引。"""
# 获取嵌入向量的维度
dimension = embeddings.shape[1]
# 创建 FAISS 索引
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
return index
class ModelResponse:
def __init__(self, config_path: str = "config.yaml"):
"""初始化 Response 类,加载配置文件并初始化向量库。"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
self.config = yaml.safe_load(f)
self.chat_history = []
self.history_db_path = self.config["database"]["history_db_path"]
# 初始化 API URL
self.OLLAMA_API_URL_EMBED = self.config["ollama"]["api_url_embedding"]
self.OLLAMA_API_URL_GENER = self.config["ollama"]["api_url_generate"]
# 初始化向量库
self.vector_db_path = self.config["database"]["vector_db_path"]
self.documents, self.embeddings = self.load_embeddings_from_db(self.vector_db_path)
self.index = build_faiss_index(self.embeddings)
self.PROMPT_TEMPLATE = self.config.get("prompt_template", "")
print(f"已从数据库加载 {len(self.documents)} 个文档的嵌入向量。FAISS 索引构建完成!")
def generate_embedding(self, text: str) -> List[float]:
"""调用 Ollama 的 API 生成文本嵌入。"""
data = {
"model": self.config["ollama"]["embedding_model"], # 使用配置文件中的嵌入模型
"prompt": text,
"options": {"embedding_only": True} # 只生成嵌入
}
response = requests.post(f"{self.OLLAMA_API_URL_EMBED}/embeddings", json=data)
return response.json().get("embedding", [])
@staticmethod
def normalize_embeddings(embeddings: np.ndarray) -> np.ndarray:
"""对嵌入向量进行归一化。"""
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
return embeddings / norms
def load_embeddings_from_db(self, db_path: str):
"""从 SQLite 数据库中加载文档和对应的嵌入向量。"""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# 查询数据库中的嵌入向量
cursor.execute('SELECT id, pdf_file_name, document_text, embedding FROM document_embeddings')
rows = cursor.fetchall()
# 将字节流转换回嵌入向量
documents = []
embeddings = []
for row in rows:
doc_id, pdf_file_name, doc_text, embedding_bytes = row
embedding = np.frombuffer(embedding_bytes, dtype=np.float32)
documents.append({"id": doc_id, "pdf_file_name": pdf_file_name, "text": doc_text})
embeddings.append(embedding)
embeddings = self.normalize_embeddings(np.array(embeddings))
conn.close()
return documents, embeddings
def retrieve_documents(self, query: str, k: int = 5, threshold: float = 1.0) -> List[dict]:
"""根据查询检索最相关的文档,并根据阈值过滤结果。"""
# 生成查询嵌入
query_embedding = np.array([self.generate_embedding(query)], dtype=np.float32)
query_embedding = self.normalize_embeddings(query_embedding)
distances, indices = self.index.search(query_embedding, k)
relevant_docs = []
for i, idx in enumerate(indices[0]):
if distances[0][i] <= threshold:
relevant_docs.append({
"text": self.documents[idx]["text"],
"score": float(distances[0][i])
})
relevant_docs.sort(key=lambda x: x["score"])
return relevant_docs
def generate_answer(self, final_prompt: str) -> str:
"""调用 Ollama 的 API 生成答案,并载入历史对话。"""
# 定义模型参数
data = {
"model": self.config["ollama"]["generation_model"],
"prompt": final_prompt,
"stream": False,
"temperature": self.config["ollama"]["temperature"]
}
response = requests.post(f"{self.OLLAMA_API_URL_GENER}/generate", json=data)
if response.status_code == 200:
return response.json().get("response", "")
else:
return f"API 请求失败,状态码:{response.status_code}"
def ask(self, query: str, chat_history: List[dict] = None) -> Dict[str, str]:
"""接收用户的问题,检索相关文档并生成答案。"""
if chat_history is None:
chat_history = []
# 检索相关文档
relevant_docs = self.retrieve_documents(query, k=self.config["retrieval"]["k"],
threshold=self.config["retrieval"]["threshold"])
if not relevant_docs:
return {
"status": "error",
"response": "未找到相关文档。",
}
# 构建最终提示
final_prompt = build_final_prompt(query, chat_history, relevant_docs)
# 生成答案
answer = self.generate_answer(final_prompt)
return {
"query": query,
"answer": answer,
}
4. 连接方式
如果你没有前端界面的代码也没有关系,你可以通过uvicorn 启动 FastAPI 服务,直接访问URL,然后在终端来查看接口的数据。
bash
uvicorn app:app --host 0.0.0.0 --port 8000
服务启动后,终端会显示类似以下信息:
bash
INFO: Started server process [12345]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
你可以通过FastAPI 自动生成了交互式 API 文档(Swagger UI)来访问接口 。
- 先访问 Swagger UI:http://localhost:8000/docs
- 找到 /register、/login 、/chat接口,点击 Try it out
- 输入 JSON 数据(如 {"username": "testuser", "password": "testpassword"}),然后点击 Execute
- 查看响应结果
bash
访问根路径:http://localhost:8000/
访问用户登录。 http://localhost:8000/login
访问用户注册。http://localhost:8000/register
访问聊天交互。http://localhost:8000/chat
你也可以通过在终端使用cURL命令发送 HTTP 请求:
bash
curl -X POST "http://localhost:8000/register" -H "Content-Type: application/json" -d '{"username": "user1", "password": "12345"}'
bash
curl -X POST "http://localhost:8000/login" -H "Content-Type: application/json" -d '{"username": "user1", "password": "12345"}'
bash
curl -X POST "http://localhost:8000/chat" -H "Content-Type: application/json" -d '{"user_input": "什么是人工智能?", "chat_history": []}'