AI项目架构设计与代码组织规范

AI 项目架构设计与代码组织规范


一、项目架构概述

1.1 架构设计原则

SOLID 原则

  • Single Responsibility:单一职责
  • Open/Closed:开闭原则
  • Liskov Substitution:里氏替换
  • Interface Segregation:接口隔离
  • Dependency Inversion:依赖倒置

其他原则

  • DRY (Don't Repeat Yourself):避免重复
  • KISS (Keep It Simple, Stupid):保持简单
  • YAGNI (You Ain't Gonna Need It):不要过度设计

1.2 架构模式

模式 说明 适用场景
分层架构 按功能分层 大多数项目
微服务架构 服务独立部署 大型应用
事件驱动 基于事件通信 异步处理
插件架构 可扩展组件 需要插件支持

1.3 AI 项目特有考量

考量 说明 示例
模型管理 模型版本、缓存、加载 模型热更新
数据处理 数据管道、预处理 数据清洗流程
API 调用 外部 API 调用、重试 调用 OpenAI API
缓存策略 结果缓存、模型缓存 Redis 缓存

二、项目结构规范

2.1 标准项目结构

复制代码
ai-project/
├── config/                    # 配置文件
│   ├── settings.py            # 应用配置
│   ├── logger.py              # 日志配置
│   └── database.py            # 数据库配置
│
├── core/                      # 核心模块
│   ├── __init__.py
│   ├── base.py                # 基类定义
│   └── exceptions.py          # 自定义异常
│
├── models/                    # 模型层
│   ├── __init__.py
│   ├── llm/                   # 大语言模型
│   │   ├── base.py
│   │   ├── openai.py
│   │   ├── claude.py
│   │   └── qwen.py
│   ├── embedding/             # 嵌入模型
│   │   ├── base.py
│   │   └── openai_embedding.py
│   └── classifier/            # 分类模型
│       └── base.py
│
├── services/                  # 服务层
│   ├── __init__.py
│   ├── prompt_service.py      # 提示词服务
│   ├── rag_service.py         # RAG 服务
│   ├── agent_service.py       # 智能体服务
│   └── data_service.py        # 数据服务
│
├── api/                       # API 层
│   ├── __init__.py
│   ├── routers/               # 路由定义
│   │   ├── chat.py
│   │   ├── agent.py
│   │   └── health.py
│   ├── schemas/               # 请求/响应模型
│   │   ├── chat.py
│   │   └── agent.py
│   └── dependencies.py        # 依赖注入
│
├── utils/                     # 工具模块
│   ├── __init__.py
│   ├── token_utils.py         # Token 工具
│   ├── text_utils.py          # 文本处理
│   └── cache_utils.py         # 缓存工具
│
├── database/                  # 数据库模块
│   ├── __init__.py
│   ├── models.py              # ORM 模型
│   ├── migrations/            # 数据库迁移
│   └── session.py             # 数据库会话
│
├── tests/                     # 测试模块
│   ├── __init__.py
│   ├── unit/                  # 单元测试
│   ├── integration/           # 集成测试
│   └── fixtures/              # 测试数据
│
├── docs/                      # 文档
│   ├── api.md                 # API 文档
│   ├── architecture.md        # 架构文档
│   └── deployment.md          # 部署文档
│
├── scripts/                   # 脚本
│   ├── setup.py               # 初始化脚本
│   └── migrate.py             # 迁移脚本
│
├── .env.example               # 环境变量示例
├── .gitignore
├── README.md
├── requirements.txt
└── main.py                    # 入口文件

2.2 目录职责说明

目录 职责 说明
config/ 配置管理 集中管理配置项
core/ 核心逻辑 基础类、异常定义
models/ 模型封装 LLM、Embedding 等模型
services/ 业务逻辑 核心业务实现
api/ API 接口 RESTful API 定义
utils/ 工具函数 通用工具类
database/ 数据持久化 ORM、数据库操作
tests/ 测试代码 单元测试、集成测试
docs/ 项目文档 API、架构、部署文档
scripts/ 辅助脚本 初始化、迁移脚本

三、配置管理

3.1 配置文件结构

python 复制代码
# config/settings.py
import os
from dotenv import load_dotenv

load_dotenv()

class Settings:
    # 应用配置
    APP_NAME = os.getenv("APP_NAME", "AI Project")
    APP_ENV = os.getenv("APP_ENV", "development")
    DEBUG = os.getenv("DEBUG", "false").lower() == "true"
    
    # 服务器配置
    HOST = os.getenv("HOST", "0.0.0.0")
    PORT = int(os.getenv("PORT", "8000"))
    
    # LLM 配置
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    CLAUDE_API_KEY = os.getenv("CLAUDE_API_KEY")
    QWEN_API_KEY = os.getenv("QWEN_API_KEY")
    DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "gpt-4o")
    
    # 数据库配置
    DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./app.db")
    
    # 缓存配置
    REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
    
    # 日志配置
    LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
    LOG_FILE = os.getenv("LOG_FILE", "app.log")

settings = Settings()

3.2 环境变量示例

env 复制代码
# .env.example
APP_NAME=AI Project
APP_ENV=development
DEBUG=true

HOST=0.0.0.0
PORT=8000

# LLM API Keys
OPENAI_API_KEY=your-openai-api-key
CLAUDE_API_KEY=your-claude-api-key
QWEN_API_KEY=your-qwen-api-key
DEFAULT_MODEL=gpt-4o

# Database
DATABASE_URL=postgresql://user:password@localhost/dbname

# Redis
REDIS_URL=redis://localhost:6379/0

# Logging
LOG_LEVEL=INFO
LOG_FILE=app.log

3.3 配置验证

python 复制代码
# config/settings.py
def validate_settings(settings):
    errors = []
    
    if not settings.OPENAI_API_KEY:
        errors.append("OPENAI_API_KEY is required")
    
    if not settings.DATABASE_URL:
        errors.append("DATABASE_URL is required")
    
    if errors:
        raise ValueError("\n".join(errors))
    
    return settings

四、模型层设计

4.1 模型抽象层

python 复制代码
# models/llm/base.py
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional

class BaseLLM(ABC):
    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        pass
    
    @abstractmethod
    def chat(self, messages: list, **kwargs) -> str:
        pass
    
    @abstractmethod
    def embeddings(self, text: str) -> list:
        pass
    
    def get_model_name(self) -> str:
        return self.__class__.__name__

4.2 OpenAI 模型实现

python 复制代码
# models/llm/openai.py
from .base import BaseLLM
from openai import OpenAI
from config.settings import settings

class OpenAIModel(BaseLLM):
    def __init__(self, model_name: str = "gpt-4o"):
        self.client = OpenAI(api_key=settings.OPENAI_API_KEY)
        self.model_name = model_name
    
    def generate(self, prompt: str, **kwargs) -> str:
        response = self.client.completions.create(
            model=self.model_name,
            prompt=prompt,
            **kwargs
        )
        return response.choices[0].text
    
    def chat(self, messages: list, **kwargs) -> str:
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            **kwargs
        )
        return response.choices[0].message.content
    
    def embeddings(self, text: str) -> list:
        response = self.client.embeddings.create(
            model="text-embedding-3-small",
            input=text
        )
        return response.data[0].embedding

4.3 模型工厂

python 复制代码
# models/llm/__init__.py
from .base import BaseLLM
from .openai import OpenAIModel
from .claude import ClaudeModel
from .qwen import QwenModel

class LLMFactory:
    _models = {
        "openai": OpenAIModel,
        "claude": ClaudeModel,
        "qwen": QwenModel
    }
    
    @classmethod
    def create(cls, model_type: str, **kwargs) -> BaseLLM:
        if model_type not in cls._models:
            raise ValueError(f"Unknown model type: {model_type}")
        
        return cls._models[model_type](**kwargs)

# 使用示例
llm = LLMFactory.create("openai", model_name="gpt-4o")

五、服务层设计

5.1 提示词服务

python 复制代码
# services/prompt_service.py
from typing import Dict, Any

class PromptService:
    def __init__(self):
        self.templates = {}
    
    def register_template(self, name: str, template: str):
        self.templates[name] = template
    
    def render(self, name: str, **kwargs) -> str:
        if name not in self.templates:
            raise ValueError(f"Template {name} not found")
        
        return self.templates[name].format(**kwargs)
    
    def create_system_prompt(self, role: str, task: str, **kwargs) -> str:
        template = """
        你是一位专业的{role}。
        
        任务:{task}
        
        {additional_info}
        
        请提供专业、准确的回答。
        """
        
        additional_info = "\n".join([f"{k}: {v}" for k, v in kwargs.items()])
        return template.format(role=role, task=task, additional_info=additional_info)

# 使用示例
prompt_service = PromptService()
prompt_service.register_template("customer_service", """
你是一位专业的客服助手。

用户问题:{{user_input}}

请根据以下知识回答:
{{knowledge}}
""")

prompt = prompt_service.render("customer_service", user_input="价格是多少?", knowledge="产品价格为 100 元")

5.2 RAG 服务

python 复制代码
# services/rag_service.py
from models.llm.base import BaseLLM
from utils.embedding_utils import get_embedding
from database.models import Document
from typing import List

class RAGService:
    def __init__(self, llm: BaseLLM):
        self.llm = llm
    
    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        # 获取查询向量
        query_embedding = get_embedding(query)
        
        # 向量检索
        documents = Document.search(query_embedding, top_k=top_k)
        
        return [doc.content for doc in documents]
    
    def generate(self, query: str, knowledge: List[str]) -> str:
        # 构建 Prompt
        knowledge_text = "\n\n".join(knowledge)
        
        prompt = f"""
        用户问题:{query}
        
        相关知识:
        {knowledge_text}
        
        请基于以上知识回答用户问题。
        """
        
        # 生成回答
        return self.llm.generate(prompt)
    
    def query(self, query: str, top_k: int = 3) -> str:
        # 检索知识
        knowledge = self.retrieve(query, top_k)
        
        # 生成回答
        return self.generate(query, knowledge)

5.3 Agent 服务

python 复制代码
# services/agent_service.py
from models.llm.base import BaseLLM
from services.prompt_service import PromptService
from services.rag_service import RAGService

class AgentService:
    def __init__(self, llm: BaseLLM):
        self.llm = llm
        self.prompt_service = PromptService()
        self.rag_service = RAGService(llm)
    
    def run(self, task: str, context: dict = None) -> str:
        # 获取上下文
        context = context or {}
        
        # 判断是否需要知识库
        if context.get("use_rag", True):
            knowledge = self.rag_service.retrieve(task)
            context["knowledge"] = "\n\n".join(knowledge)
        
        # 构建 Prompt
        prompt = self.prompt_service.create_system_prompt(
            role="智能助手",
            task=task,
            **context
        )
        
        # 生成回答
        return self.llm.generate(prompt)

六、API 层设计

6.1 路由定义

python 复制代码
# api/routers/chat.py
from fastapi import APIRouter, Depends, HTTPException
from api.schemas.chat import ChatRequest, ChatResponse
from services.agent_service import AgentService
from models.llm import LLMFactory

router = APIRouter(prefix="/chat", tags=["chat"])

def get_agent_service() -> AgentService:
    llm = LLMFactory.create("openai")
    return AgentService(llm)

@router.post("/", response_model=ChatResponse)
async def chat(request: ChatRequest, agent: AgentService = Depends(get_agent_service)):
    try:
        response = agent.run(request.message, request.context)
        return ChatResponse(response=response)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/stream")
async def chat_stream(request: ChatRequest, agent: AgentService = Depends(get_agent_service)):
    # 流式响应
    response = agent.run(request.message, request.context)
    return {"response": response}

6.2 请求/响应模型

python 复制代码
# api/schemas/chat.py
from pydantic import BaseModel
from typing import Optional, Dict

class ChatRequest(BaseModel):
    message: str
    context: Optional[Dict] = None
    use_rag: bool = True

class ChatResponse(BaseModel):
    response: str
    context: Optional[Dict] = None

6.3 依赖注入

python 复制代码
# api/dependencies.py
from fastapi import Depends
from models.llm import LLMFactory
from services.agent_service import AgentService

def get_llm():
    return LLMFactory.create("openai")

def get_agent_service(llm = Depends(get_llm)):
    return AgentService(llm)

def get_db_session():
    session = get_session()
    try:
        yield session
    finally:
        session.close()

七、数据库设计

7.1 ORM 模型

python 复制代码
# database/models.py
from sqlalchemy import Column, Integer, String, Text, DateTime
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime

Base = declarative_base()

class Document(Base):
    __tablename__ = "documents"
    
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String, index=True)
    content = Column(Text)
    embedding = Column(Text)  # 存储向量
    source = Column(String)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

class Conversation(Base):
    __tablename__ = "conversations"
    
    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(String, index=True)
    messages = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

class ChatMessage(Base):
    __tablename__ = "chat_messages"
    
    id = Column(Integer, primary_key=True, index=True)
    conversation_id = Column(Integer, index=True)
    role = Column(String)
    content = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)

7.2 数据库会话

python 复制代码
# database/session.py
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from config.settings import settings

engine = create_engine(settings.DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

def get_session():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

八、工具模块

8.1 Token 工具

python 复制代码
# utils/token_utils.py
import tiktoken

def count_tokens(text: str, model: str = "gpt-4o") -> int:
    """计算文本的 Token 数量"""
    encoder = tiktoken.encoding_for_model(model)
    return len(encoder.encode(text))

def truncate_text(text: str, max_tokens: int = 4096, model: str = "gpt-4o") -> str:
    """截断文本到指定 Token 数量"""
    encoder = tiktoken.encoding_for_model(model)
    tokens = encoder.encode(text)
    
    if len(tokens) <= max_tokens:
        return text
    
    truncated_tokens = tokens[:max_tokens]
    return encoder.decode(truncated_tokens)

def estimate_cost(text: str, model: str = "gpt-4o") -> float:
    """估算 API 调用成本"""
    token_count = count_tokens(text, model)
    
    costs = {
        "gpt-4o": {"input": 5, "output": 15},
        "gpt-3.5-turbo": {"input": 0.5, "output": 1.5}
    }
    
    if model not in costs:
        return 0.0
    
    return (token_count * costs[model]["input"]) / 1000000

8.2 缓存工具

python 复制代码
# utils/cache_utils.py
import redis
from config.settings import settings
import json
from typing import Optional

class CacheService:
    def __init__(self):
        self.redis = redis.Redis.from_url(settings.REDIS_URL)
    
    def set(self, key: str, value: any, ttl: int = 3600):
        """设置缓存"""
        if isinstance(value, (dict, list)):
            value = json.dumps(value)
        self.redis.set(key, value, ex=ttl)
    
    def get(self, key: str) -> Optional[any]:
        """获取缓存"""
        value = self.redis.get(key)
        if value:
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return value.decode("utf-8")
        return None
    
    def delete(self, key: str):
        """删除缓存"""
        self.redis.delete(key)
    
    def exists(self, key: str) -> bool:
        """检查缓存是否存在"""
        return self.redis.exists(key) > 0

# 使用示例
cache = CacheService()
cache.set("user:123", {"name": "张三", "email": "zhangsan@example.com"})
user = cache.get("user:123")

九、日志与监控

9.1 日志配置

python 复制代码
# config/logger.py
import logging
from config.settings import settings
import sys

def setup_logger():
    logger = logging.getLogger("ai-project")
    logger.setLevel(settings.LOG_LEVEL)
    
    # 避免重复添加处理器
    if logger.handlers:
        return logger
    
    # 控制台处理器
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(settings.LOG_LEVEL)
    
    # 文件处理器
    file_handler = logging.FileHandler(settings.LOG_FILE)
    file_handler.setLevel(settings.LOG_LEVEL)
    
    # 格式化器
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    
    console_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)
    
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    
    return logger

logger = setup_logger()

9.2 监控指标

python 复制代码
# services/monitor_service.py
from prometheus_client import Counter, Histogram, Summary
from time import time

# 请求计数
REQUEST_COUNT = Counter("ai_requests_total", "Total requests", ["endpoint", "method"])

# 请求耗时
REQUEST_LATENCY = Histogram("ai_request_latency_seconds", "Request latency", ["endpoint"])

# Token 使用量
TOKEN_USAGE = Counter("ai_tokens_used", "Tokens used", ["model", "type"])

def monitor_request(endpoint: str, method: str = "POST"):
    """装饰器:监控请求"""
    def decorator(func):
        def wrapper(*args, **kwargs):
            REQUEST_COUNT.labels(endpoint=endpoint, method=method).inc()
            
            start_time = time()
            try:
                return func(*args, **kwargs)
            finally:
                REQUEST_LATENCY.labels(endpoint=endpoint).observe(time() - start_time)
        
        return wrapper
    return decorator

def record_token_usage(model: str, tokens: int, token_type: str = "input"):
    """记录 Token 使用量"""
    TOKEN_USAGE.labels(model=model, type=token_type).inc(tokens)

十、测试规范

10.1 测试结构

复制代码
tests/
├── unit/
│   ├── test_prompt_service.py
│   ├── test_rag_service.py
│   └── test_token_utils.py
├── integration/
│   ├── test_chat_api.py
│   └── test_database.py
└── fixtures/
    ├── test_data.py
    └── mock_llm.py

10.2 单元测试示例

python 复制代码
# tests/unit/test_prompt_service.py
import pytest
from services.prompt_service import PromptService

class TestPromptService:
    def setup_method(self):
        self.service = PromptService()
    
    def test_register_template(self):
        self.service.register_template("test", "Hello {{name}}")
        assert "test" in self.service.templates
    
    def test_render_template(self):
        self.service.register_template("test", "Hello {{name}}")
        result = self.service.render("test", name="World")
        assert result == "Hello World"
    
    def test_render_missing_template(self):
        with pytest.raises(ValueError):
            self.service.render("missing")
    
    def test_create_system_prompt(self):
        result = self.service.create_system_prompt(
            role="测试员",
            task="测试功能",
            additional="额外信息"
        )
        assert "测试员" in result
        assert "测试功能" in result

10.3 集成测试示例

python 复制代码
# tests/integration/test_chat_api.py
from fastapi.testclient import TestClient
from main import app

client = TestClient(app)

class TestChatAPI:
    def test_chat_endpoint(self):
        response = client.post(
            "/chat/",
            json={"message": "你好"}
        )
        assert response.status_code == 200
        assert "response" in response.json()
    
    def test_chat_with_context(self):
        response = client.post(
            "/chat/",
            json={
                "message": "价格是多少?",
                "context": {"product": "手机"}
            }
        )
        assert response.status_code == 200

十一、部署规范

11.1 Docker 配置

dockerfile 复制代码
# Dockerfile
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 8000

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
yaml 复制代码
# docker-compose.yml
version: '3.8'

services:
  api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - APP_ENV=production
      - DATABASE_URL=postgresql://user:password@db:5432/app
      - REDIS_URL=redis://redis:6379
    depends_on:
      - db
      - redis
  
  db:
    image: postgres:15
    volumes:
      - postgres_data:/var/lib/postgresql/data
    environment:
      - POSTGRES_USER=user
      - POSTGRES_PASSWORD=password
      - POSTGRES_DB=app
  
  redis:
    image: redis:7-alpine
    volumes:
      - redis_data:/data

volumes:
  postgres_data:
  redis_data:

11.2 CI/CD 配置

yaml 复制代码
# .github/workflows/deploy.yml
name: Deploy

on:
  push:
    branches: [main]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: pip install -r requirements.txt
      - name: Run tests
        run: pytest

  deploy:
    needs: test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Deploy to production
        run: ./scripts/deploy.sh

十二、代码风格规范

12.1 Python 代码风格

PEP 8 规范

  • 使用 4 空格缩进
  • 行长度不超过 120 字符
  • 使用 snake_case 命名变量和函数
  • 使用 PascalCase 命名类
  • 导入按顺序排列

代码示例

python 复制代码
# 好的示例
class UserService:
    def get_user_by_id(self, user_id: int) -> Optional[User]:
        """根据用户ID获取用户"""
        return self.db.query(User).filter(User.id == user_id).first()

# 不好的示例
class userService:  # 类名应使用 PascalCase
    def getUserById(self, userId):  # 方法名应使用 snake_case
        return self.db.query(User).filter(User.id == userId).first()

12.2 文档字符串规范

Google 风格

python 复制代码
def generate_report(data: dict, format: str = "pdf") -> str:
    """生成数据分析报告
    
    Args:
        data: 分析数据字典
        format: 输出格式,支持 pdf、html、markdown
    
    Returns:
        报告内容字符串
    
    Raises:
        ValueError: 当格式不支持时抛出
    """
    pass

12.3 Git 提交规范

Conventional Commits

复制代码
feat: 添加用户登录功能
fix: 修复登录页面样式问题
docs: 更新 API 文档
refactor: 重构用户服务代码
test: 添加单元测试

十三、总结

核心要点

  1. 项目结构:分层架构,职责清晰
  2. 配置管理:环境变量,集中管理
  3. 模型层:抽象基类,工厂模式
  4. 服务层:业务逻辑,组合复用
  5. API 层:RESTful,依赖注入
  6. 数据库:ORM 模型,会话管理
  7. 测试:单元测试 + 集成测试
  8. 部署:Docker,CI/CD

架构图

复制代码
┌─────────────────────────────────────────────────────────────┐
│                      API Layer                              │
│              (FastAPI, Routes, Schemas)                     │
├─────────────────────────────────────────────────────────────┤
│                      Service Layer                          │
│         (AgentService, RAGService, PromptService)           │
├─────────────────────────────────────────────────────────────┤
│                      Model Layer                            │
│              (LLM, Embedding, Classifier)                   │
├─────────────────────────────────────────────────────────────┤
│                      Database Layer                         │
│              (PostgreSQL, Redis, Vector DB)                 │
└─────────────────────────────────────────────────────────────┘

下一步建议

  1. 根据项目需求调整结构
  2. 使用依赖注入管理服务
  3. 添加完整的测试覆盖
  4. 配置监控和日志
  5. 实现 CI/CD 流程