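AI Project Architecture Design and Code Organization Guidelines
1. Project Architecture Overview
1.1 Architecture Design Principles
SOLID principles:
- Single Responsibility: a class or module has one reason to change
- Open/Closed: open for extension, closed for modification
- Liskov Substitution: a subtype can stand in for its base type
- Interface Segregation: prefer small, focused interfaces
- Dependency Inversion: depend on abstractions, not on concrete implementations
Other principles:
- DRY (Don't Repeat Yourself): avoid duplication
- KISS (Keep It Simple, Stupid): keep things simple
- YAGNI (You Ain't Gonna Need It): do not over-engineer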
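As one illustration of Dependency Inversion from the list above, the sketch below has high-level code depend on a small protocol rather than on a concrete provider. The names `TextGenerator`, `EchoGenerator`, and `Summarizer` are hypothetical and exist only for this example.

```python
from typing import Protocol


class TextGenerator(Protocol):
    """Hypothetical abstraction: anything that can turn a prompt into text."""

    def generate(self, prompt: str) -> str: ...


class EchoGenerator:
    """Stand-in implementation used here instead of a real model client."""

    def generate(self, prompt: str) -> str:
        return f"echo: {prompt}"


class Summarizer:
    """High-level code depends on the protocol, not on a concrete class."""

    def __init__(self, generator: TextGenerator) -> None:
        self.generator = generator

    def summarize(self, text: str) -> str:
        return self.generator.generate(f"Summarize: {text}")


summary = Summarizer(EchoGenerator()).summarize("hello")
```

Because `Summarizer` only knows the protocol, a real provider or a test double can be swapped in without touching it.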
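1.2 Architecture Patterns
| Pattern | Description | When to use |
|---|---|---|
| Layered architecture | Code is split into layers by responsibility | Most projects |
| Microservices | Services are deployed independently | Large applications |
| Event-driven | Components communicate through events | Asynchronous processing |
| Plugin architecture | Extensible components | Projects that need plugin support |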
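1.3 AI-Specific Considerations
| Concern | Description | Example |
|---|---|---|
| Model management | Model versioning, caching, loading | Hot-swapping a model |
| Data processing | Data pipelines, preprocessing | A data-cleaning workflow |
| API calls | External API calls, retries | Calling the OpenAI API |
| Caching strategy | Result caching, model caching | Redis cache |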
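The "API calls" row mentions retries. A minimal sketch of retrying with exponential backoff follows; `call_with_retry` and its parameters are hypothetical names chosen for this example, and the injectable `sleep` argument is only there so the behavior can be checked without waiting.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_retry(
    func: Callable[[], T],
    max_attempts: int = 3,
    base_delay: float = 0.5,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call func, retrying on any exception with exponential backoff."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return func()
        except Exception:
            if attempt == max_attempts:
                raise  # out of attempts: surface the last error
            # Delay doubles each time: base, 2 * base, 4 * base, ...
            sleep(base_delay * 2 ** (attempt - 1))
    raise AssertionError("unreachable")
```

This version retries on every exception to stay short. In practice it is usually better to retry only errors that are likely to be transient, such as timeouts or rate limits.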
2. Project Structure Guidelines
2.1 Standard Project Structure
ai-project/
├── config/ # Configuration
│   ├── settings.py # Application settings
│   ├── logger.py # Logging configuration
│   └── database.py # Database configuration
│
├── core/ # Core module
│   ├── __init__.py
│   ├── base.py # Base class definitions
│   └── exceptions.py # Custom exceptions
│
├── models/ # Model layer
│   ├── __init__.py
│   ├── llm/ # Large language models
│   │   ├── __init__.py # LLMFactory (see 4.3)
│   │   ├── base.py
│   │   ├── openai.py
│   │   ├── claude.py
│   │   └── qwen.py
│   ├── embedding/ # Embedding models
│   │   ├── base.py
│   │   └── openai_embedding.py
│   └── classifier/ # Classification models
│       └── base.py
│
├── services/ # Service layer
│   ├── __init__.py
│   ├── prompt_service.py # Prompt service
│   ├── rag_service.py # RAG service
│   ├── agent_service.py # Agent service
│   ├── data_service.py # Data service
│   └── monitor_service.py # Monitoring metrics (see 9.2)
│
├── api/ # API layer
│   ├── __init__.py
│   ├── routers/ # Route definitions
│   │   ├── chat.py
│   │   ├── agent.py
│   │   └── health.py
│   ├── schemas/ # Request/response models
│   │   ├── chat.py
│   │   └── agent.py
│   └── dependencies.py # Dependency injection
│
├── utils/ # Utilities
│   ├── __init__.py
│   ├── token_utils.py # Token helpers
│   ├── text_utils.py # Text processing
│   └── cache_utils.py # Cache helpers
│
├── database/ # Database module
│   ├── __init__.py
│   ├── models.py # ORM models
│   ├── migrations/ # Database migrations
│   └── session.py # Database sessions
│
├── tests/ # Tests
│   ├── __init__.py
│   ├── unit/ # Unit tests
│   ├── integration/ # Integration tests
│   └── fixtures/ # Test data
│
├── docs/ # Documentation
│   ├── api.md # API documentation
│   ├── architecture.md # Architecture documentation
│   └── deployment.md # Deployment documentation
│
├── scripts/ # Scripts
│   ├── setup.py # Initialization script
│   └── migrate.py # Migration script
│
├── .env.example # Example environment variables
├── .gitignore
├── README.md
├── requirements.txt
└── main.py # Entry point
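2.2 Directory Responsibilities
| Directory | Responsibility | Notes |
|---|---|---|
| config/ | Configuration management | All configuration items in one place |
| core/ | Core logic | Base classes, exception definitions |
| models/ | Model wrappers | LLM, embedding, and other models |
| services/ | Business logic | Core business implementation |
| api/ | API endpoints | RESTful API definitions |
| utils/ | Helper functions | General-purpose utilities |
| database/ | Persistence | ORM, database operations |
| tests/ | Test code | Unit tests, integration tests |
| docs/ | Project documentation | API, architecture, and deployment docs |
| scripts/ | Helper scripts | Initialization and migration scripts |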
3. Configuration Management
3.1 Configuration File Structure
python
# config/settings.py
import os

from dotenv import load_dotenv

# Load variables from a local .env file into the process environment
load_dotenv()


class Settings:
    # Application
    APP_NAME = os.getenv("APP_NAME", "AI Project")
    APP_ENV = os.getenv("APP_ENV", "development")
    DEBUG = os.getenv("DEBUG", "false").lower() == "true"
    # Server
    HOST = os.getenv("HOST", "0.0.0.0")
    PORT = int(os.getenv("PORT", "8000"))
    # LLM
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    CLAUDE_API_KEY = os.getenv("CLAUDE_API_KEY")
    QWEN_API_KEY = os.getenv("QWEN_API_KEY")
    DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "gpt-4o")
    # Database
    DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./app.db")
    # Cache
    REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
    # Logging
    LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
    LOG_FILE = os.getenv("LOG_FILE", "app.log")


settings = Settings()
3.2 Example Environment Variables
env
# .env.example
APP_NAME=AI Project
APP_ENV=development
DEBUG=true
HOST=0.0.0.0
PORT=8000
# LLM API Keys
OPENAI_API_KEY=your-openai-api-key
CLAUDE_API_KEY=your-claude-api-key
QWEN_API_KEY=your-qwen-api-key
DEFAULT_MODEL=gpt-4o
# Database
DATABASE_URL=postgresql://user:password@localhost/dbname
# Redis
REDIS_URL=redis://localhost:6379/0
# Logging
LOG_LEVEL=INFO
LOG_FILE=app.log
3.3 Configuration Validation
python
# config/settings.py
def validate_settings(settings):
    """Collect every missing required setting and report them together."""
    errors = []
    if not settings.OPENAI_API_KEY:
        errors.append("OPENAI_API_KEY is required")
    if not settings.DATABASE_URL:
        errors.append("DATABASE_URL is required")
    if errors:
        raise ValueError("\n".join(errors))
    return settings
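4. Model Layer Design
4.1 Model Abstraction Layer
python
# models/llm/base.py
from abc import ABC, abstractmethod
from typing import List


class BaseLLM(ABC):
    """Common interface that every LLM provider wrapper implements."""

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """Return a completion for a single prompt string."""

    @abstractmethod
    def chat(self, messages: list, **kwargs) -> str:
        """Return the assistant reply for a list of chat messages."""

    @abstractmethod
    def embeddings(self, text: str) -> List[float]:
        """Return the embedding vector for the text."""

    def get_model_name(self) -> str:
        # Defaults to the wrapper class name, e.g. "OpenAIModel"
        return self.__class__.__name__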
4.2 OpenAI Model Implementation
python
# models/llm/openai.py
from openai import OpenAI

from config.settings import settings
from .base import BaseLLM


class OpenAIModel(BaseLLM):
    def __init__(self, model_name: str = "gpt-4o"):
        self.client = OpenAI(api_key=settings.OPENAI_API_KEY)
        self.model_name = model_name

    def generate(self, prompt: str, **kwargs) -> str:
        # Chat models such as gpt-4o are not served by the legacy
        # completions endpoint, so wrap the prompt as a single user message.
        return self.chat([{"role": "user", "content": prompt}], **kwargs)

    def chat(self, messages: list, **kwargs) -> str:
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            **kwargs
        )
        # content can be None (for example on tool calls); normalize to str
        return response.choices[0].message.content or ""

    def embeddings(self, text: str) -> list:
        response = self.client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding
4.3 Model Factory
python
# models/llm/__init__.py
from .base import BaseLLM
from .openai import OpenAIModel
from .claude import ClaudeModel
from .qwen import QwenModel


class LLMFactory:
    # Maps a short provider name to its wrapper class
    _models = {
        "openai": OpenAIModel,
        "claude": ClaudeModel,
        "qwen": QwenModel
    }

    @classmethod
    def create(cls, model_type: str, **kwargs) -> BaseLLM:
        if model_type not in cls._models:
            raise ValueError(f"Unknown model type: {model_type}")
        return cls._models[model_type](**kwargs)


# Usage example (belongs in calling code, not in this module)
llm = LLMFactory.create("openai", model_name="gpt-4o")
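The factory above has to be edited whenever a provider is added. One alternative, sketched here under the assumption that each provider class registers itself, is a decorator-based registry. `register_llm`, `create_llm`, and `EchoLLM` are hypothetical names and are not part of the module shown above.

```python
from typing import Callable, Dict, Type

_REGISTRY: Dict[str, Type] = {}


def register_llm(name: str) -> Callable[[Type], Type]:
    """Class decorator that records a provider class under a short name."""
    def decorator(cls: Type) -> Type:
        if name in _REGISTRY:
            raise ValueError(f"Model type already registered: {name}")
        _REGISTRY[name] = cls
        return cls
    return decorator


def create_llm(name: str, **kwargs):
    """Instantiate a registered provider, mirroring LLMFactory.create."""
    if name not in _REGISTRY:
        raise ValueError(f"Unknown model type: {name}")
    return _REGISTRY[name](**kwargs)


@register_llm("echo")
class EchoLLM:
    """Toy provider used only to exercise the registry."""

    def __init__(self, prefix: str = "echo") -> None:
        self.prefix = prefix

    def generate(self, prompt: str) -> str:
        return f"{self.prefix}: {prompt}"


llm = create_llm("echo", prefix="test")
```

The trade-off is that a provider module must be imported somewhere before its name can be resolved, whereas the explicit dictionary keeps the full list visible in one place.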
5. Service Layer Design
5.1 Prompt Service
python
# services/prompt_service.py
class PromptService:
    def __init__(self):
        # Template name -> str.format-style template text
        self.templates = {}

    def register_template(self, name: str, template: str):
        self.templates[name] = template

    def render(self, name: str, **kwargs) -> str:
        if name not in self.templates:
            raise ValueError(f"Template {name} not found")
        return self.templates[name].format(**kwargs)

    def create_system_prompt(self, role: str, task: str, **kwargs) -> str:
        template = """
You are a professional {role}.
Task: {task}
{additional_info}
Please give a professional, accurate answer.
"""
        additional_info = "\n".join(f"{k}: {v}" for k, v in kwargs.items())
        return template.format(role=role, task=task, additional_info=additional_info)


# Usage example
# Placeholders use single braces: str.format treats "{{" as a literal "{"
prompt_service = PromptService()
prompt_service.register_template("customer_service", """
You are a professional customer service assistant.
User question: {user_input}
Answer based on the following knowledge:
{knowledge}
""")
prompt = prompt_service.render("customer_service", user_input="What is the price?", knowledge="The product costs 100 yuan")
5.2 RAG Service
python
# services/rag_service.py
from typing import List

from database.models import Document
from models.llm.base import BaseLLM


class RAGService:
    def __init__(self, llm: BaseLLM):
        self.llm = llm

    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        # Embed the query through the BaseLLM interface from section 4.1
        query_embedding = self.llm.embeddings(query)
        # Vector search. Document.search is assumed to be a classmethod that
        # returns the top_k nearest documents; section 7.1 does not define it.
        documents = Document.search(query_embedding, top_k=top_k)
        return [doc.content for doc in documents]

    def generate(self, query: str, knowledge: List[str]) -> str:
        # Build the prompt
        knowledge_text = "\n\n".join(knowledge)
        prompt = f"""
User question: {query}
Relevant knowledge:
{knowledge_text}
Answer the user's question based on the knowledge above.
"""
        # Generate the answer
        return self.llm.generate(prompt)

    def query(self, query: str, top_k: int = 3) -> str:
        # Retrieve knowledge, then generate the answer from it
        knowledge = self.retrieve(query, top_k)
        return self.generate(query, knowledge)
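`Document.search` is used by the RAG service but is not defined anywhere in this guide. As a rough idea of what such a search computes, here is a brute-force, in-memory sketch that ranks texts by cosine similarity. `cosine_similarity` and `top_k` are hypothetical helpers, and the two-dimensional vectors are made-up values; a production system would more likely rely on a vector index or database.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors (0.0 if either is zero)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(
    query: Sequence[float],
    corpus: List[Tuple[str, Sequence[float]]],
    k: int = 3,
) -> List[str]:
    """Return the texts of the k corpus entries most similar to the query."""
    scored = sorted(
        corpus,
        key=lambda item: cosine_similarity(query, item[1]),
        reverse=True,
    )
    return [text for text, _ in scored[:k]]


# Made-up (text, embedding) pairs for illustration only
corpus = [
    ("pricing", [1.0, 0.0]),
    ("shipping", [0.0, 1.0]),
    ("discounts", [0.9, 0.1]),
]
results = top_k([1.0, 0.0], corpus, k=2)
```

With these toy vectors the query points in the same direction as "pricing", so that entry ranks first and "discounts" second.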
5.3 Agent Service
python
# services/agent_service.py
from typing import Optional

from models.llm.base import BaseLLM
from services.prompt_service import PromptService
from services.rag_service import RAGService


class AgentService:
    def __init__(self, llm: BaseLLM):
        self.llm = llm
        self.prompt_service = PromptService()
        self.rag_service = RAGService(llm)

    def run(self, task: str, context: Optional[dict] = None) -> str:
        # Copy the context so the caller's dict is never mutated
        context = dict(context or {})
        # use_rag is a control flag, not prompt content, so remove it here
        use_rag = context.pop("use_rag", True)
        # Pull in knowledge-base results when requested
        if use_rag:
            knowledge = self.rag_service.retrieve(task)
            context["knowledge"] = "\n\n".join(knowledge)
        # Build the prompt
        prompt = self.prompt_service.create_system_prompt(
            role="AI assistant",
            task=task,
            **context
        )
        # Generate the answer
        return self.llm.generate(prompt)
6. API Layer Design
6.1 Route Definitions
python
# api/routers/chat.py
from fastapi import APIRouter, Depends, HTTPException

from api.dependencies import get_agent_service
from api.schemas.chat import ChatRequest, ChatResponse
from services.agent_service import AgentService

router = APIRouter(prefix="/chat", tags=["chat"])


def build_context(request: ChatRequest) -> dict:
    # Forward the use_rag flag, which AgentService.run reads from the context
    return {**(request.context or {}), "use_rag": request.use_rag}


# Plain "def" endpoints run in FastAPI's thread pool, so the blocking
# agent.run call does not stall the event loop.
@router.post("/", response_model=ChatResponse)
def chat(request: ChatRequest, agent: AgentService = Depends(get_agent_service)):
    try:
        response = agent.run(request.message, build_context(request))
        return ChatResponse(response=response)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/stream")
def chat_stream(request: ChatRequest, agent: AgentService = Depends(get_agent_service)):
    # Placeholder: returns the whole answer at once. Real streaming needs a
    # streaming LLM call and a StreamingResponse.
    response = agent.run(request.message, build_context(request))
    return {"response": response}
6.2 Request/Response Models
python
# api/schemas/chat.py
from typing import Any, Dict, Optional

from pydantic import BaseModel


class ChatRequest(BaseModel):
    message: str
    context: Optional[Dict[str, Any]] = None
    use_rag: bool = True


class ChatResponse(BaseModel):
    response: str
    context: Optional[Dict[str, Any]] = None
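6.3 Dependency Injection
python
# api/dependencies.py
from fastapi import Depends

from database.session import SessionLocal
from models.llm import LLMFactory
from services.agent_service import AgentService


def get_llm():
    return LLMFactory.create("openai")


def get_agent_service(llm=Depends(get_llm)):
    return AgentService(llm)


def get_db_session():
    # Create the session directly: get_session() in database/session.py is a
    # generator, so calling it here would not return a session object.
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()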
7. Database Design
7.1 ORM Models
python
# database/models.py
from datetime import datetime

from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text
from sqlalchemy.orm import declarative_base  # its location since SQLAlchemy 1.4

Base = declarative_base()


class Document(Base):
    __tablename__ = "documents"
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String, index=True)
    content = Column(Text)
    embedding = Column(Text)  # embedding vector serialized as text (e.g. JSON)
    source = Column(String)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class Conversation(Base):
    __tablename__ = "conversations"
    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(String, index=True)
    messages = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class ChatMessage(Base):
    __tablename__ = "chat_messages"
    id = Column(Integer, primary_key=True, index=True)
    # Each message belongs to one conversation
    conversation_id = Column(Integer, ForeignKey("conversations.id"), index=True)
    role = Column(String)
    content = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)
7.2 Database Sessions
python
# database/session.py
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from config.settings import settings

engine = create_engine(settings.DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def get_session():
    """Yield a session and always close it afterwards."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
8. Utilities
8.1 Token Utilities
python
# utils/token_utils.py
import tiktoken


def _get_encoder(model: str):
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: fall back to a general-purpose encoding
        return tiktoken.get_encoding("cl100k_base")


def count_tokens(text: str, model: str = "gpt-4o") -> int:
    """Count the tokens in a text."""
    return len(_get_encoder(model).encode(text))


def truncate_text(text: str, max_tokens: int = 4096, model: str = "gpt-4o") -> str:
    """Truncate a text to at most max_tokens tokens."""
    encoder = _get_encoder(model)
    tokens = encoder.encode(text)
    if len(tokens) <= max_tokens:
        return text
    return encoder.decode(tokens[:max_tokens])


def estimate_cost(text: str, model: str = "gpt-4o") -> float:
    """Estimate the input-side cost of an API call.

    Prices are per 1,000,000 tokens and are example values only;
    check the provider's current pricing before relying on them.
    """
    token_count = count_tokens(text, model)
    costs = {
        "gpt-4o": {"input": 5, "output": 15},
        "gpt-3.5-turbo": {"input": 0.5, "output": 1.5}
    }
    if model not in costs:
        return 0.0
    return (token_count * costs[model]["input"]) / 1000000
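`estimate_cost` above prices input tokens only, although its table also lists output prices. The arithmetic for both directions can be sketched as follows. `estimate_total_cost` is a hypothetical helper, and the prices passed in are the example values from the table, not authoritative figures.

```python
def estimate_total_cost(
    input_tokens: int,
    output_tokens: int,
    input_price_per_million: float,
    output_price_per_million: float,
) -> float:
    """Cost = tokens / 1,000,000 * price per million, summed over both directions."""
    if input_tokens < 0 or output_tokens < 0:
        raise ValueError("token counts must be non-negative")
    input_cost = input_tokens / 1_000_000 * input_price_per_million
    output_cost = output_tokens / 1_000_000 * output_price_per_million
    return input_cost + output_cost


# 2,000 input tokens and 500 output tokens at the example prices 5 and 15
cost = estimate_total_cost(2_000, 500, 5.0, 15.0)
```

Here the input side contributes 2,000 / 1,000,000 × 5 = 0.01 and the output side 500 / 1,000,000 × 15 = 0.0075, for a total of about 0.0175.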
8.2 Cache Utilities
python
# utils/cache_utils.py
import json
from typing import Any, Optional

import redis

from config.settings import settings


class CacheService:
    def __init__(self):
        self.redis = redis.Redis.from_url(settings.REDIS_URL)

    def set(self, key: str, value: Any, ttl: int = 3600):
        """Store a value; ttl is in seconds."""
        # Always serialize as JSON so get() returns the same type that was stored
        self.redis.set(key, json.dumps(value), ex=ttl)

    def get(self, key: str) -> Optional[Any]:
        """Fetch a value, or None when the key is missing."""
        value = self.redis.get(key)
        if value is None:
            return None
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            # Value was written by something else as plain text
            return value.decode("utf-8")

    def delete(self, key: str):
        """Delete a key."""
        self.redis.delete(key)

    def exists(self, key: str) -> bool:
        """Check whether a key exists."""
        return self.redis.exists(key) > 0


# Usage example
cache = CacheService()
cache.set("user:123", {"name": "Zhang San", "email": "zhangsan@example.com"})
user = cache.get("user:123")
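Section 1.3 lists result caching as a concern for AI projects. A minimal cache-aside sketch for model answers is shown below. A plain dict stands in for Redis so the example runs on its own, and `cache_key`, `cached_generate`, and `fake_generate` are hypothetical names.

```python
import hashlib
from typing import Callable, Dict, List


def cache_key(model: str, prompt: str) -> str:
    """Stable key derived from the model name and the exact prompt text."""
    digest = hashlib.sha256(f"{model}\n{prompt}".encode("utf-8")).hexdigest()
    return f"llm:{model}:{digest}"


def cached_generate(
    store: Dict[str, str],
    model: str,
    prompt: str,
    generate: Callable[[str], str],
) -> str:
    """Return a cached answer when present; otherwise generate and store it."""
    key = cache_key(model, prompt)
    if key in store:
        return store[key]
    answer = generate(prompt)
    store[key] = answer
    return answer


calls: List[str] = []


def fake_generate(prompt: str) -> str:
    # Stand-in for a model call; records each invocation
    calls.append(prompt)
    return prompt.upper()


store: Dict[str, str] = {}
first = cached_generate(store, "demo-model", "hello", fake_generate)
second = cached_generate(store, "demo-model", "hello", fake_generate)
```

The second call is served from the store, so the stand-in model function runs only once. Hashing the prompt keeps keys short regardless of prompt length.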
9. Logging and Monitoring
9.1 Logging Configuration
python
# config/logger.py
import logging
import sys

from config.settings import settings


def setup_logger():
    logger = logging.getLogger("ai-project")
    logger.setLevel(settings.LOG_LEVEL)
    # Avoid adding the handlers twice
    if logger.handlers:
        return logger
    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(settings.LOG_LEVEL)
    # File handler
    file_handler = logging.FileHandler(settings.LOG_FILE)
    file_handler.setLevel(settings.LOG_LEVEL)
    # Formatter
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    console_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    return logger


logger = setup_logger()
9.2 Monitoring Metrics
python
# services/monitor_service.py
import functools
from time import perf_counter

from prometheus_client import Counter, Histogram

# Request count
REQUEST_COUNT = Counter("ai_requests_total", "Total requests", ["endpoint", "method"])
# Request latency
REQUEST_LATENCY = Histogram("ai_request_latency_seconds", "Request latency", ["endpoint"])
# Token usage
TOKEN_USAGE = Counter("ai_tokens_used", "Tokens used", ["model", "type"])


def monitor_request(endpoint: str, method: str = "POST"):
    """Decorator that counts and times calls to a synchronous function."""
    def decorator(func):
        @functools.wraps(func)  # keep the wrapped function's name and signature
        def wrapper(*args, **kwargs):
            REQUEST_COUNT.labels(endpoint=endpoint, method=method).inc()
            start_time = perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                REQUEST_LATENCY.labels(endpoint=endpoint).observe(perf_counter() - start_time)
        return wrapper
    return decorator


def record_token_usage(model: str, tokens: int, token_type: str = "input"):
    """Record token usage."""
    TOKEN_USAGE.labels(model=model, type=token_type).inc(tokens)
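The `monitor_request` decorator wraps a synchronous function. Applied to an `async def` handler it would stop the timer as soon as the coroutine object is created, before any of the work runs. A sketch of a timing decorator that handles both cases follows. `timed` and its `record` callback are hypothetical; the callback could, for instance, be the `observe` method of a labelled histogram.

```python
import asyncio
import functools
import inspect
import time
from typing import Callable, List


def timed(record: Callable[[float], None]):
    """Decorator factory: report elapsed seconds for sync and async callables."""
    def decorator(func):
        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                start = time.perf_counter()
                try:
                    # Awaiting inside the timer measures the real work
                    return await func(*args, **kwargs)
                finally:
                    record(time.perf_counter() - start)
            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                record(time.perf_counter() - start)
        return sync_wrapper
    return decorator


durations: List[float] = []


@timed(durations.append)
async def handler(x: int) -> int:
    await asyncio.sleep(0)
    return x * 2


@timed(durations.append)
def sync_handler(x: int) -> int:
    return x + 1


async_result = asyncio.run(handler(21))
sync_result = sync_handler(1)
```

Each call appends one duration to the list, and `functools.wraps` keeps the original function names intact.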
10. Testing Guidelines
10.1 Test Structure
tests/
├── unit/
│ ├── test_prompt_service.py
│ ├── test_rag_service.py
│ └── test_token_utils.py
├── integration/
│ ├── test_chat_api.py
│ └── test_database.py
└── fixtures/
├── test_data.py
└── mock_llm.py
10.2 Unit Test Example
python
# tests/unit/test_prompt_service.py
import pytest

from services.prompt_service import PromptService


class TestPromptService:
    def setup_method(self):
        self.service = PromptService()

    def test_register_template(self):
        self.service.register_template("test", "Hello {name}")
        assert "test" in self.service.templates

    def test_render_template(self):
        # Single braces: str.format would render "{{name}}" as a literal "{name}"
        self.service.register_template("test", "Hello {name}")
        result = self.service.render("test", name="World")
        assert result == "Hello World"

    def test_render_missing_template(self):
        with pytest.raises(ValueError):
            self.service.render("missing")

    def test_create_system_prompt(self):
        result = self.service.create_system_prompt(
            role="tester",
            task="test the feature",
            additional="extra info"
        )
        assert "tester" in result
        assert "test the feature" in result
10.3 Integration Test Example
python
# tests/integration/test_chat_api.py
from fastapi.testclient import TestClient

from main import app

# Note: unless the dependencies are overridden, these requests reach the
# real LLM provider and database.
client = TestClient(app)


class TestChatAPI:
    def test_chat_endpoint(self):
        response = client.post(
            "/chat/",
            json={"message": "Hello"}
        )
        assert response.status_code == 200
        assert "response" in response.json()

    def test_chat_with_context(self):
        response = client.post(
            "/chat/",
            json={
                "message": "What is the price?",
                "context": {"product": "phone"}
            }
        )
        assert response.status_code == 200
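The test layout in 10.1 lists a `fixtures/mock_llm.py` file without showing it. One possible shape for such a test double is sketched below. `FakeLLM` is a hypothetical class that mirrors the method names of `BaseLLM` from section 4.1 without importing it, so the example stays self-contained.

```python
from typing import List


class FakeLLM:
    """Test double with the same method names as BaseLLM; it never calls a real API."""

    def __init__(self, reply: str = "fake reply") -> None:
        self.reply = reply
        self.prompts: List[str] = []  # every prompt received, for assertions

    def generate(self, prompt: str, **kwargs) -> str:
        self.prompts.append(prompt)
        return self.reply

    def chat(self, messages: list, **kwargs) -> str:
        # Record only the latest message, which is usually the user turn
        self.prompts.append(messages[-1]["content"])
        return self.reply

    def embeddings(self, text: str) -> list:
        # Fixed-size dummy vector; the values carry no meaning
        return [float(len(text)), 0.0, 0.0]


fake = FakeLLM(reply="The price is 100 yuan")
answer = fake.generate("What is the price?")
```

In a FastAPI test, a double like this could be supplied through `app.dependency_overrides` so that the endpoint never contacts a real provider.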
11. Deployment Guidelines
11.1 Docker Configuration
dockerfile
# Dockerfile
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8000
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
yaml
# docker-compose.yml
version: '3.8'
services:
  api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - APP_ENV=production
      - DATABASE_URL=postgresql://user:password@db:5432/app
      - REDIS_URL=redis://redis:6379
    depends_on:
      - db
      - redis
  db:
    image: postgres:15
    volumes:
      - postgres_data:/var/lib/postgresql/data
    environment:
      - POSTGRES_USER=user
      - POSTGRES_PASSWORD=password
      - POSTGRES_DB=app
  redis:
    image: redis:7-alpine
    volumes:
      - redis_data:/data
volumes:
  postgres_data:
  redis_data:
11.2 CI/CD Configuration
yaml
# .github/workflows/deploy.yml
name: Deploy
on:
  push:
    branches: [main]
jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: pip install -r requirements.txt
      - name: Run tests
        run: pytest
  deploy:
    needs: test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Deploy to production
        run: ./scripts/deploy.sh
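12. Code Style Guidelines
12.1 Python Code Style
PEP 8 conventions:
- Indent with 4 spaces
- Keep lines within 120 characters (a project-level choice; PEP 8 itself defaults to 79)
- Use snake_case for variables and functions
- Use PascalCase for classes
- Group imports in order: standard library, third-party, local
Code example:
python
from typing import Optional


# Good example
class UserService:
    def get_user_by_id(self, user_id: int) -> Optional[User]:
        """Fetch a user by ID."""
        return self.db.query(User).filter(User.id == user_id).first()


# Bad example
class userService:  # class names should use PascalCase
    def getUserById(self, userId):  # method names should use snake_case
        return self.db.query(User).filter(User.id == userId).first()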
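12.2 Docstring Conventions
Google style:
python
def generate_report(data: dict, format: str = "pdf") -> str:
    """Generate a data analysis report.

    Args:
        data: Dictionary of analysis data.
        format: Output format; one of pdf, html, markdown.

    Returns:
        The report content as a string.

    Raises:
        ValueError: If the format is not supported.
    """
    pass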
12.3 Git Commit Conventions
Conventional Commits:
feat: add user login
fix: fix login page styling
docs: update API documentation
refactor: refactor user service code
test: add unit tests
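A commit header in this format can be checked mechanically, for example in a commit hook. The sketch below is a simplified validator, not a full implementation of the Conventional Commits specification. `ALLOWED_TYPES`, `COMMIT_PATTERN`, and `is_conventional` are hypothetical names, and the inclusion of `chore` in the type list is an assumption.

```python
import re

# Types from the examples above, plus "chore" (an assumption for this sketch)
ALLOWED_TYPES = ("feat", "fix", "docs", "refactor", "test", "chore")

COMMIT_PATTERN = re.compile(
    r"^(?P<type>" + "|".join(ALLOWED_TYPES) + r")"
    r"(\((?P<scope>[^()\s]+)\))?"  # optional scope, e.g. feat(api)
    r"(?P<breaking>!)?"            # optional breaking-change marker
    r": (?P<subject>\S.*)$"        # colon, one space, non-empty subject
)


def is_conventional(header: str) -> bool:
    """Check only the first line (header) of a commit message."""
    return COMMIT_PATTERN.match(header) is not None
```

Headers such as `feat: add user login` or `fix(api)!: handle empty body` pass, while `update stuff` or a header with no space after the colon is rejected.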
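13. Summary
Key Points
- Project structure: layered architecture with clear responsibilities
- Configuration management: environment variables, managed in one place
- Model layer: abstract base class, factory pattern
- Service layer: business logic, reuse through composition
- API layer: RESTful, dependency injection
- Database: ORM models, session management
- Testing: unit tests plus integration tests
- Deployment: Docker, CI/CD
Architecture Diagram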
┌────────────────────────────────────────────────┐
│                   API Layer                    │
│           (FastAPI, Routes, Schemas)           │
├────────────────────────────────────────────────┤
│                 Service Layer                  │
│   (AgentService, RAGService, PromptService)    │
├────────────────────────────────────────────────┤
│                  Model Layer                   │
│          (LLM, Embedding, Classifier)          │
├────────────────────────────────────────────────┤
│                 Database Layer                 │
│         (PostgreSQL, Redis, Vector DB)         │
└────────────────────────────────────────────────┘
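Suggested Next Steps
- Adapt the structure to the project's needs
- Manage services through dependency injection
- Add full test coverage
- Set up monitoring and logging
- Implement a CI/CD pipeline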