项目十一:Saga 模式分布式旅行预订系统——核心服务实现与 Saga 编排器

以下是完整的可执行代码:

1. 项目结构

```text
travel-saga/
├── docker-compose.yml
├── init-db.sql
├── shared/
│   ├── __init__.py
│   ├── models.py
│   ├── events.py
│   ├── saga_state.py
│   ├── message_bus.py
│   └── utils.py
├── services/
│   ├── flight_service/
│   ├── hotel_service/
│   ├── car_service/
│   ├── payment_service/
│   └── saga_orchestrator/
└── tests/
    └── chaos_test.py
```

2. 共享基础设施层

shared/models.py

Python

复制代码
"""
共享数据模型与领域对象
"""
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum, auto
from typing import Optional, Dict, Any, List
import json
import uuid


class SagaStatus(Enum):
    """Execution status of a saga as a whole."""
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    COMPENSATING = "COMPENSATING"
    COMPENSATED = "COMPENSATED"
    FAILED = "FAILED"
    SUSPENDED = "SUSPENDED"  # suspended (hanging) state; requires manual intervention
    TIMEOUT = "TIMEOUT"


class StepStatus(Enum):
    """Execution status of a single saga step."""
    PENDING = "PENDING"
    TRYING = "TRYING"           # TCC: Try phase
    TRIED = "TRIED"             # TCC: Try succeeded
    CONFIRMING = "CONFIRMING"   # TCC: Confirm phase
    CONFIRMED = "CONFIRMED"     # TCC: Confirm succeeded
    CANCELLING = "CANCELLING"   # TCC: Cancel phase
    CANCELLED = "CANCELLED"     # TCC: Cancel succeeded
    FAILED = "FAILED"
    COMPENSATING = "COMPENSATING"
    COMPENSATED = "COMPENSATED"
    TIMEOUT = "TIMEOUT"


class ServiceType(Enum):
    """Kinds of participant service a saga step can target."""
    FLIGHT = "FLIGHT"
    HOTEL = "HOTEL"
    CAR = "CAR"
    PAYMENT = "PAYMENT"


@dataclass
class SagaContext:
    """Context object carried through an entire saga transaction."""
    saga_id: str
    user_id: str
    correlation_id: str
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict; both timestamps are rendered as ISO-8601 strings."""
        serialized: Dict[str, Any] = {
            name: getattr(self, name)
            for name in ("saga_id", "user_id", "correlation_id")
        }
        for stamp_name in ("created_at", "updated_at"):
            serialized[stamp_name] = getattr(self, stamp_name).isoformat()
        # The metadata dict is handed out as-is (not copied).
        serialized["metadata"] = self.metadata
        return serialized


@dataclass
class FlightReservation:
    """A flight seat reservation."""
    flight_id: str
    seat_number: str
    passenger_name: str
    passenger_id: str
    departure_time: datetime
    arrival_time: datetime
    from_city: str
    to_city: str
    price: float
    lock_token: Optional[str] = None
    status: str = "PENDING"

    def to_dict(self):
        """Return every field as a dict, with the two flight times as ISO-8601 strings."""
        # Re-assigning existing keys keeps the field order produced by asdict().
        return {
            **asdict(self),
            'departure_time': self.departure_time.isoformat(),
            'arrival_time': self.arrival_time.isoformat(),
        }


@dataclass
class HotelReservation:
    """A hotel room reservation."""
    hotel_id: str
    room_type: str
    check_in: datetime
    check_out: datetime
    guest_name: str
    guest_count: int
    price_per_night: float
    total_nights: int
    lock_token: Optional[str] = None
    status: str = "PENDING"

    def to_dict(self):
        """Return every field as a dict, with check-in/check-out as ISO-8601 strings."""
        payload = asdict(self)
        payload.update(
            check_in=self.check_in.isoformat(),
            check_out=self.check_out.isoformat(),
        )
        return payload


@dataclass
class CarReservation:
    """A rental car reservation."""
    car_id: str
    car_model: str
    pickup_location: str
    dropoff_location: str
    pickup_time: datetime
    dropoff_time: datetime
    driver_name: str
    driver_license: str
    daily_rate: float
    total_days: int
    lock_token: Optional[str] = None
    status: str = "PENDING"

    def to_dict(self):
        """Return every field as a dict, with pickup/dropoff times as ISO-8601 strings."""
        as_plain = asdict(self)
        for time_field in ('pickup_time', 'dropoff_time'):
            as_plain[time_field] = getattr(self, time_field).isoformat()
        return as_plain


@dataclass
class PaymentAuthorization:
    """A payment pre-authorization."""
    payment_id: str
    amount: float
    currency: str = "CNY"
    card_token: str = ""
    auth_code: Optional[str] = None
    status: str = "PENDING"
    captured_amount: float = 0.0
    refunded_amount: float = 0.0

    def to_dict(self):
        """Return all fields as a plain dict (every field is already JSON-friendly)."""
        field_values = asdict(self)
        return field_values


@dataclass
class TravelBooking:
    """A complete travel booking request."""
    booking_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    user_id: str = ""
    flight: Optional[FlightReservation] = None
    hotel: Optional[HotelReservation] = None
    car: Optional[CarReservation] = None
    payment: Optional[PaymentAuthorization] = None
    total_amount: float = 0.0
    status: SagaStatus = SagaStatus.PENDING
    created_at: datetime = field(default_factory=datetime.utcnow)

    def to_dict(self):
        """Serialize the booking; absent parts become None, the status its enum value."""
        serialized = {"booking_id": self.booking_id, "user_id": self.user_id}
        # Each optional sub-reservation serializes itself when present.
        for part_name in ("flight", "hotel", "car", "payment"):
            part = getattr(self, part_name)
            serialized[part_name] = part.to_dict() if part else None
        serialized["total_amount"] = self.total_amount
        serialized["status"] = self.status.value
        serialized["created_at"] = self.created_at.isoformat()
        return serialized


@dataclass
class SagaStep:
    """Definition and runtime state of one saga step."""
    step_id: str
    step_number: int
    service_type: ServiceType
    operation: str  # "TRY", "CONFIRM", "CANCEL"
    status: StepStatus = StepStatus.PENDING
    request_payload: Dict[str, Any] = field(default_factory=dict)
    response_payload: Dict[str, Any] = field(default_factory=dict)
    error_message: Optional[str] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    retry_count: int = 0
    max_retries: int = 3
    timeout_seconds: int = 30
    compensation_payload: Optional[Dict[str, Any]] = None
    dependencies: List[str] = field(default_factory=list)  # IDs of the steps this one depends on

    def to_dict(self):
        """Serialize the step; enums become their values, timestamps ISO strings or None."""
        snapshot = {
            name: getattr(self, name)
            for name in (
                "step_id", "step_number", "service_type", "operation", "status",
                "request_payload", "response_payload", "error_message",
                "started_at", "completed_at", "retry_count", "max_retries",
                "timeout_seconds", "compensation_payload", "dependencies",
            )
        }
        # Overwrite the non-primitive entries in place (key order is unchanged).
        snapshot["service_type"] = self.service_type.value
        snapshot["status"] = self.status.value
        for stamp_name in ("started_at", "completed_at"):
            moment = getattr(self, stamp_name)
            snapshot[stamp_name] = moment.isoformat() if moment else None
        return snapshot


@dataclass
class SagaExecution:
    """One execution instance of a saga."""
    execution_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    saga_id: str = ""
    status: SagaStatus = SagaStatus.PENDING
    steps: List[SagaStep] = field(default_factory=list)
    current_step_index: int = 0
    context: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None
    error_message: Optional[str] = None
    compensation_triggered_at: Optional[datetime] = None
    suspended_at: Optional[datetime] = None
    suspend_reason: Optional[str] = None

    def to_dict(self):
        """Serialize the execution, including every step, into JSON-friendly values."""
        snapshot = {
            name: getattr(self, name)
            for name in (
                "execution_id", "saga_id", "status", "steps", "current_step_index",
                "context", "created_at", "updated_at", "completed_at",
                "error_message", "compensation_triggered_at", "suspended_at",
                "suspend_reason",
            )
        }
        # Overwrite the non-primitive entries in place (key order is unchanged).
        snapshot["status"] = self.status.value
        snapshot["steps"] = [step.to_dict() for step in self.steps]
        snapshot["created_at"] = self.created_at.isoformat()
        snapshot["updated_at"] = self.updated_at.isoformat()
        for stamp_name in ("completed_at", "compensation_triggered_at", "suspended_at"):
            moment = getattr(self, stamp_name)
            snapshot[stamp_name] = moment.isoformat() if moment else None
        return snapshot


class IdempotencyKey:
    """Record used for idempotency control of an operation."""

    def __init__(self, key: str, operation: str, created_at: Optional[datetime] = None):
        """Create an unprocessed key; ``created_at`` falls back to the current UTC time."""
        self.key = key
        self.operation = operation
        if created_at:
            self.created_at = created_at
        else:
            self.created_at = datetime.utcnow()
        self.processed = False
        self.result: Optional[Dict[str, Any]] = None

    def to_dict(self):
        """Return the record as a dict with ``created_at`` as an ISO-8601 string."""
        serialized = {
            name: getattr(self, name)
            for name in ("key", "operation", "created_at", "processed", "result")
        }
        serialized["created_at"] = self.created_at.isoformat()
        return serialized

shared/events.py

Python

复制代码
"""
事件定义 - 命令事件与领域事件
"""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Any, Optional
import uuid


class EventType:
    """String constants naming the command types (CMD_*) and domain events (EVT_*)."""
    # Flight service commands
    CMD_LOCK_SEAT = "CMD_LOCK_SEAT"
    CMD_CONFIRM_SEAT = "CMD_CONFIRM_SEAT"
    CMD_CANCEL_SEAT = "CMD_CANCEL_SEAT"
    CMD_RELEASE_SEAT = "CMD_RELEASE_SEAT"
    
    # Hotel service commands
    CMD_LOCK_ROOM = "CMD_LOCK_ROOM"
    CMD_CONFIRM_ROOM = "CMD_CONFIRM_ROOM"
    CMD_CANCEL_ROOM = "CMD_CANCEL_ROOM"
    CMD_RELEASE_ROOM = "CMD_RELEASE_ROOM"
    
    # Car rental service commands
    CMD_LOCK_CAR = "CMD_LOCK_CAR"
    CMD_CONFIRM_CAR = "CMD_CONFIRM_CAR"
    CMD_CANCEL_CAR = "CMD_CANCEL_CAR"
    CMD_RELEASE_CAR = "CMD_RELEASE_CAR"
    
    # Payment service commands
    CMD_PREAUTH_PAYMENT = "CMD_PREAUTH_PAYMENT"
    CMD_CAPTURE_PAYMENT = "CMD_CAPTURE_PAYMENT"
    CMD_REFUND_PAYMENT = "CMD_REFUND_PAYMENT"
    CMD_CANCEL_PREAUTH = "CMD_CANCEL_PREAUTH"
    
    # Saga orchestrator commands
    CMD_START_SAGA = "CMD_START_SAGA"
    CMD_COMPENSATE_SAGA = "CMD_COMPENSATE_SAGA"
    CMD_RETRY_STEP = "CMD_RETRY_STEP"
    CMD_SUSPEND_SAGA = "CMD_SUSPEND_SAGA"
    CMD_RESUME_SAGA = "CMD_RESUME_SAGA"
    
    # Domain events
    EVT_SEAT_LOCKED = "EVT_SEAT_LOCKED"
    EVT_SEAT_CONFIRMED = "EVT_SEAT_CONFIRMED"
    EVT_SEAT_CANCELLED = "EVT_SEAT_CANCELLED"
    EVT_SEAT_LOCK_FAILED = "EVT_SEAT_LOCK_FAILED"
    
    EVT_ROOM_LOCKED = "EVT_ROOM_LOCKED"
    EVT_ROOM_CONFIRMED = "EVT_ROOM_CONFIRMED"
    EVT_ROOM_CANCELLED = "EVT_ROOM_CANCELLED"
    EVT_ROOM_LOCK_FAILED = "EVT_ROOM_LOCK_FAILED"
    
    EVT_CAR_LOCKED = "EVT_CAR_LOCKED"
    EVT_CAR_CONFIRMED = "EVT_CAR_CONFIRMED"
    EVT_CAR_CANCELLED = "EVT_CAR_CANCELLED"
    EVT_CAR_LOCK_FAILED = "EVT_CAR_LOCK_FAILED"
    
    EVT_PAYMENT_PREAUTHED = "EVT_PAYMENT_PREAUTHED"
    EVT_PAYMENT_CAPTURED = "EVT_PAYMENT_CAPTURED"
    EVT_PAYMENT_REFUNDED = "EVT_PAYMENT_REFUNDED"
    EVT_PAYMENT_PREAUTH_FAILED = "EVT_PAYMENT_PREAUTH_FAILED"
    
    EVT_SAGA_STARTED = "EVT_SAGA_STARTED"
    EVT_SAGA_STEP_COMPLETED = "EVT_SAGA_STEP_COMPLETED"
    EVT_SAGA_STEP_FAILED = "EVT_SAGA_STEP_FAILED"
    EVT_SAGA_COMPLETED = "EVT_SAGA_COMPLETED"
    EVT_SAGA_COMPENSATION_STARTED = "EVT_SAGA_COMPENSATION_STARTED"
    EVT_SAGA_COMPENSATED = "EVT_SAGA_COMPENSATED"
    EVT_SAGA_FAILED = "EVT_SAGA_FAILED"
    EVT_SAGA_SUSPENDED = "EVT_SAGA_SUSPENDED"


@dataclass
class DomainEvent:
    """Base record for domain events."""
    event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    event_type: str = ""
    saga_id: str = ""
    step_id: str = ""
    correlation_id: str = ""
    timestamp: datetime = field(default_factory=datetime.utcnow)
    payload: Dict[str, Any] = field(default_factory=dict)
    partition_key: str = ""  # used for message partitioning to guarantee ordering

    def to_dict(self):
        """Serialize the event to a dict; the timestamp becomes an ISO-8601 string."""
        serialized = {
            name: getattr(self, name)
            for name in (
                "event_id", "event_type", "saga_id", "step_id",
                "correlation_id", "timestamp", "payload", "partition_key",
            )
        }
        serialized["timestamp"] = self.timestamp.isoformat()
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'DomainEvent':
        """
        Build an event from a dict shaped like the output of ``to_dict``.

        ``event_type`` is required (KeyError if missing). A missing
        ``event_id`` gets a fresh UUID, a missing ``timestamp`` the current
        UTC time, and the remaining keys their empty defaults.
        """
        kwargs: Dict[str, Any] = {
            "event_id": data.get("event_id", str(uuid.uuid4())),
            "event_type": data["event_type"],
        }
        for text_key in ("saga_id", "step_id", "correlation_id"):
            kwargs[text_key] = data.get(text_key, "")
        if "timestamp" in data:
            kwargs["timestamp"] = datetime.fromisoformat(data["timestamp"])
        else:
            kwargs["timestamp"] = datetime.utcnow()
        kwargs["payload"] = data.get("payload", {})
        kwargs["partition_key"] = data.get("partition_key", "")
        return cls(**kwargs)


@dataclass
class CommandMessage:
    """A command message addressed to a service."""
    command_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    command_type: str = ""
    saga_id: str = ""
    step_id: str = ""
    correlation_id: str = ""
    idempotency_key: str = ""
    target_service: str = ""
    payload: Dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.utcnow)
    ttl_seconds: int = 300  # message time-to-live
    priority: int = 5  # 1-10; a smaller number means a higher priority

    def to_dict(self):
        """Serialize the command to a dict; the timestamp becomes an ISO-8601 string."""
        serialized = {
            name: getattr(self, name)
            for name in (
                "command_id", "command_type", "saga_id", "step_id",
                "correlation_id", "idempotency_key", "target_service",
                "payload", "timestamp", "ttl_seconds", "priority",
            )
        }
        serialized["timestamp"] = self.timestamp.isoformat()
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CommandMessage':
        """
        Build a command from a dict shaped like the output of ``to_dict``.

        ``command_type`` is required (KeyError if missing). A missing
        ``command_id`` gets a fresh UUID, a missing ``timestamp`` the current
        UTC time, and the remaining keys their field defaults.
        """
        kwargs: Dict[str, Any] = {
            "command_id": data.get("command_id", str(uuid.uuid4())),
            "command_type": data["command_type"],
        }
        for text_key in ("saga_id", "step_id", "correlation_id",
                         "idempotency_key", "target_service"):
            kwargs[text_key] = data.get(text_key, "")
        kwargs["payload"] = data.get("payload", {})
        if "timestamp" in data:
            kwargs["timestamp"] = datetime.fromisoformat(data["timestamp"])
        else:
            kwargs["timestamp"] = datetime.utcnow()
        kwargs["ttl_seconds"] = data.get("ttl_seconds", 300)
        kwargs["priority"] = data.get("priority", 5)
        return cls(**kwargs)


@dataclass
class ResponseMessage:
    """A response message reporting the outcome of a command."""
    response_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    command_id: str = ""
    saga_id: str = ""
    step_id: str = ""
    correlation_id: str = ""
    success: bool = False
    result: Dict[str, Any] = field(default_factory=dict)
    error_code: Optional[str] = None
    error_message: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.utcnow)
    processing_time_ms: int = 0

    def to_dict(self):
        """Serialize the response to a dict; the timestamp becomes an ISO-8601 string."""
        return {
            "response_id": self.response_id,
            "command_id": self.command_id,
            "saga_id": self.saga_id,
            "step_id": self.step_id,
            "correlation_id": self.correlation_id,
            "success": self.success,
            "result": self.result,
            "error_code": self.error_code,
            "error_message": self.error_message,
            "timestamp": self.timestamp.isoformat(),
            "processing_time_ms": self.processing_time_ms
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ResponseMessage':
        """
        Build a response from a dict shaped like the output of ``to_dict``.

        Mirrors ``from_dict`` on the sibling message classes so that a
        serialized response can be reconstructed. Every key is optional:
        a missing ``response_id`` gets a fresh UUID, a missing ``timestamp``
        the current UTC time, and the remaining keys their field defaults.
        """
        return cls(
            response_id=data.get("response_id", str(uuid.uuid4())),
            command_id=data.get("command_id", ""),
            saga_id=data.get("saga_id", ""),
            step_id=data.get("step_id", ""),
            correlation_id=data.get("correlation_id", ""),
            success=data.get("success", False),
            result=data.get("result", {}),
            error_code=data.get("error_code"),
            error_message=data.get("error_message"),
            timestamp=datetime.fromisoformat(data["timestamp"]) if "timestamp" in data else datetime.utcnow(),
            processing_time_ms=data.get("processing_time_ms", 0)
        )

shared/saga_state.py

Python

复制代码
"""
Saga状态持久化 - 数据库记录与恢复
"""
import json
import logging
from datetime import datetime
from typing import Optional, Dict, Any, List
import sqlite3
import threading
from contextlib import contextmanager

logger = logging.getLogger(__name__)


class SagaStateStore:
    """
    SQLite-backed store for saga execution logs and recovery data.

    Persists saga executions and steps, idempotency keys, an audit event
    log, suspended-transaction alerts and monitoring metrics. Each thread
    lazily opens its own connection, which is kept in thread-local storage.
    Public methods catch every exception, log it, and report failure through
    their return value (False / None / empty container) instead of raising.
    """

    # Table definitions. SQLite does not accept MySQL-style inline
    # ``INDEX name (col)`` clauses inside CREATE TABLE, so the indexes are
    # created separately from _INDEX_SPECS.
    _TABLE_DDL = (
        # Saga execution records
        """
        CREATE TABLE IF NOT EXISTS saga_executions (
            execution_id TEXT PRIMARY KEY,
            saga_id TEXT NOT NULL,
            status TEXT NOT NULL,
            steps_json TEXT NOT NULL,
            current_step_index INTEGER DEFAULT 0,
            context_json TEXT DEFAULT '{}',
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL,
            completed_at TEXT,
            error_message TEXT,
            compensation_triggered_at TEXT,
            suspended_at TEXT,
            suspend_reason TEXT,
            user_id TEXT,
            total_amount REAL DEFAULT 0
        )
        """,
        # Saga step execution records
        """
        CREATE TABLE IF NOT EXISTS saga_steps (
            step_id TEXT PRIMARY KEY,
            execution_id TEXT NOT NULL,
            saga_id TEXT NOT NULL,
            step_number INTEGER NOT NULL,
            service_type TEXT NOT NULL,
            operation TEXT NOT NULL,
            status TEXT NOT NULL,
            request_payload TEXT,
            response_payload TEXT,
            error_message TEXT,
            started_at TEXT,
            completed_at TEXT,
            retry_count INTEGER DEFAULT 0,
            max_retries INTEGER DEFAULT 3,
            timeout_seconds INTEGER DEFAULT 30,
            compensation_payload TEXT,
            dependencies TEXT,
            FOREIGN KEY (execution_id) REFERENCES saga_executions(execution_id)
        )
        """,
        # Idempotency control
        """
        CREATE TABLE IF NOT EXISTS idempotency_keys (
            key TEXT PRIMARY KEY,
            operation TEXT NOT NULL,
            created_at TEXT NOT NULL,
            processed INTEGER DEFAULT 0,
            result_json TEXT,
            saga_id TEXT
        )
        """,
        # Event log (for audit and recovery)
        """
        CREATE TABLE IF NOT EXISTS event_log (
            event_id TEXT PRIMARY KEY,
            event_type TEXT NOT NULL,
            saga_id TEXT NOT NULL,
            step_id TEXT,
            execution_id TEXT,
            payload TEXT,
            timestamp TEXT NOT NULL,
            partition_key TEXT
        )
        """,
        # Suspended-transaction alerts
        """
        CREATE TABLE IF NOT EXISTS suspended_transactions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            saga_id TEXT NOT NULL,
            execution_id TEXT NOT NULL,
            suspended_at TEXT NOT NULL,
            reason TEXT NOT NULL,
            alert_sent INTEGER DEFAULT 0,
            resolved INTEGER DEFAULT 0,
            resolved_at TEXT,
            resolved_by TEXT,
            resolution_notes TEXT
        )
        """,
        # Monitoring metrics
        """
        CREATE TABLE IF NOT EXISTS saga_metrics (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            saga_id TEXT,
            event_type TEXT NOT NULL,
            service_type TEXT,
            step_number INTEGER,
            success INTEGER,
            processing_time_ms INTEGER,
            timestamp TEXT NOT NULL
        )
        """,
    )

    # (index name, table, column). Index names share one namespace per
    # database in SQLite, so each name carries its table name.
    _INDEX_SPECS = (
        ("idx_saga_executions_saga_id", "saga_executions", "saga_id"),
        ("idx_saga_executions_status", "saga_executions", "status"),
        ("idx_saga_executions_user_id", "saga_executions", "user_id"),
        ("idx_saga_executions_created_at", "saga_executions", "created_at"),
        ("idx_saga_steps_execution_id", "saga_steps", "execution_id"),
        ("idx_saga_steps_saga_id", "saga_steps", "saga_id"),
        ("idx_idempotency_keys_operation", "idempotency_keys", "operation"),
        ("idx_idempotency_keys_created_at", "idempotency_keys", "created_at"),
        ("idx_event_log_saga_id", "event_log", "saga_id"),
        ("idx_event_log_event_type", "event_log", "event_type"),
        ("idx_event_log_timestamp", "event_log", "timestamp"),
        ("idx_suspended_transactions_saga_id", "suspended_transactions", "saga_id"),
        ("idx_suspended_transactions_alert_sent", "suspended_transactions", "alert_sent"),
        ("idx_suspended_transactions_suspended_at", "suspended_transactions", "suspended_at"),
        ("idx_saga_metrics_timestamp", "saga_metrics", "timestamp"),
        ("idx_saga_metrics_event_type", "saga_metrics", "event_type"),
    )

    def __init__(self, db_path: str = "saga_state.db"):
        """Remember the database path and create the schema if it does not exist."""
        self.db_path = db_path
        self._local = threading.local()
        self._init_db()

    def _get_connection(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening it on first use."""
        if not hasattr(self._local, 'connection'):
            self._local.connection = sqlite3.connect(self.db_path, check_same_thread=False)
            self._local.connection.row_factory = sqlite3.Row
        return self._local.connection

    @contextmanager
    def _transaction(self):
        """
        Run the enclosed block inside a ``BEGIN IMMEDIATE`` transaction.

        Commits when the block finishes normally; rolls back and re-raises
        when it raises.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise

    def _init_db(self):
        """Create all tables and their indexes (idempotent)."""
        with self._transaction() as conn:
            for table_ddl in self._TABLE_DDL:
                conn.execute(table_ddl)
            # Identifiers come from the class-level constant table above,
            # never from user input, so formatting them into SQL is safe.
            for index_name, table, column in self._INDEX_SPECS:
                conn.execute(
                    f"CREATE INDEX IF NOT EXISTS {index_name} ON {table} ({column})"
                )

    def save_execution(self, execution: Dict[str, Any]) -> bool:
        """
        Insert or replace a saga execution row.

        ``execution`` is a dict shaped like ``SagaExecution.to_dict()``;
        ``user_id`` and ``total_amount`` are read from its ``context``.
        Returns True on success, False if anything failed (error is logged).
        """
        try:
            with self._transaction() as conn:
                context = execution.get("context", {})
                conn.execute("""
                    INSERT OR REPLACE INTO saga_executions (
                        execution_id, saga_id, status, steps_json, current_step_index,
                        context_json, created_at, updated_at, completed_at, error_message,
                        compensation_triggered_at, suspended_at, suspend_reason, user_id, total_amount
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    execution["execution_id"],
                    execution["saga_id"],
                    execution["status"],
                    json.dumps(execution.get("steps", [])),
                    execution.get("current_step_index", 0),
                    json.dumps(context),
                    execution["created_at"],
                    execution["updated_at"],
                    execution.get("completed_at"),
                    execution.get("error_message"),
                    execution.get("compensation_triggered_at"),
                    execution.get("suspended_at"),
                    execution.get("suspend_reason"),
                    context.get("user_id"),
                    context.get("total_amount", 0)
                ))
                return True
        except Exception as e:
            logger.error(f"Failed to save execution: {e}")
            return False

    def load_execution(self, execution_id: str) -> Optional[Dict[str, Any]]:
        """Return the execution with this id as a dict, or None if absent or on error."""
        try:
            with self._transaction() as conn:
                row = conn.execute(
                    "SELECT * FROM saga_executions WHERE execution_id = ?",
                    (execution_id,)
                ).fetchone()
                return self._row_to_dict(row) if row else None
        except Exception as e:
            logger.error(f"Failed to load execution: {e}")
            return None

    def load_unfinished_executions(self) -> List[Dict[str, Any]]:
        """
        Return every execution whose status is not COMPLETED, COMPENSATED or
        FAILED, oldest first (used for recovery after a restart).
        Returns an empty list on error.
        """
        try:
            with self._transaction() as conn:
                cursor = conn.execute("""
                    SELECT * FROM saga_executions 
                    WHERE status NOT IN ('COMPLETED', 'COMPENSATED', 'FAILED')
                    ORDER BY created_at ASC
                """)
                return [self._row_to_dict(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Failed to load unfinished executions: {e}")
            return []

    def _row_to_dict(self, row) -> Dict[str, Any]:
        """Convert a ``saga_executions`` row into a dict, decoding its JSON columns."""
        return {
            "execution_id": row["execution_id"],
            "saga_id": row["saga_id"],
            "status": row["status"],
            "steps": json.loads(row["steps_json"]),
            "current_step_index": row["current_step_index"],
            "context": json.loads(row["context_json"]),
            "created_at": row["created_at"],
            "updated_at": row["updated_at"],
            "completed_at": row["completed_at"],
            "error_message": row["error_message"],
            "compensation_triggered_at": row["compensation_triggered_at"],
            "suspended_at": row["suspended_at"],
            "suspend_reason": row["suspend_reason"]
        }

    def save_step(self, step: Dict[str, Any]) -> bool:
        """
        Insert or replace a saga step row.

        ``step`` is a dict shaped like ``SagaStep.to_dict()``, optionally
        carrying ``execution_id`` and ``saga_id``. Returns True on success,
        False on error.
        """
        try:
            with self._transaction() as conn:
                compensation = step.get("compensation_payload")
                conn.execute("""
                    INSERT OR REPLACE INTO saga_steps (
                        step_id, execution_id, saga_id, step_number, service_type,
                        operation, status, request_payload, response_payload, error_message,
                        started_at, completed_at, retry_count, max_retries, timeout_seconds,
                        compensation_payload, dependencies
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    step["step_id"],
                    step.get("execution_id", ""),
                    step.get("saga_id", ""),
                    step["step_number"],
                    step["service_type"],
                    step["operation"],
                    step["status"],
                    json.dumps(step.get("request_payload", {})),
                    json.dumps(step.get("response_payload", {})),
                    step.get("error_message"),
                    step.get("started_at"),
                    step.get("completed_at"),
                    step.get("retry_count", 0),
                    step.get("max_retries", 3),
                    step.get("timeout_seconds", 30),
                    # A missing or empty compensation payload is stored as NULL.
                    json.dumps(compensation) if compensation else None,
                    json.dumps(step.get("dependencies", []))
                ))
                return True
        except Exception as e:
            logger.error(f"Failed to save step: {e}")
            return False

    def check_idempotency(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the stored record for an idempotency key, or None if absent or on error."""
        try:
            with self._transaction() as conn:
                row = conn.execute(
                    "SELECT * FROM idempotency_keys WHERE key = ?",
                    (key,)
                ).fetchone()
                if row:
                    return {
                        "key": row["key"],
                        "operation": row["operation"],
                        "created_at": row["created_at"],
                        "processed": bool(row["processed"]),
                        "result": json.loads(row["result_json"]) if row["result_json"] else None,
                        "saga_id": row["saga_id"]
                    }
                return None
        except Exception as e:
            logger.error(f"Failed to check idempotency: {e}")
            return None

    def save_idempotency_key(self, key: str, operation: str, saga_id: str = "") -> bool:
        """
        Register an idempotency key; an already-existing key is left untouched.

        Returns True unless the statement itself failed.
        """
        try:
            with self._transaction() as conn:
                conn.execute("""
                    INSERT OR IGNORE INTO idempotency_keys (key, operation, created_at, saga_id)
                    VALUES (?, ?, ?, ?)
                """, (key, operation, datetime.utcnow().isoformat(), saga_id))
                return True
        except Exception as e:
            logger.error(f"Failed to save idempotency key: {e}")
            return False

    def mark_idempotency_processed(self, key: str, result: Dict[str, Any]) -> bool:
        """Mark an idempotency key as processed and store its JSON-encoded result."""
        try:
            with self._transaction() as conn:
                conn.execute("""
                    UPDATE idempotency_keys 
                    SET processed = 1, result_json = ?
                    WHERE key = ?
                """, (json.dumps(result), key))
                return True
        except Exception as e:
            logger.error(f"Failed to mark idempotency processed: {e}")
            return False

    def log_event(self, event: Dict[str, Any]) -> bool:
        """
        Append an event to the event log.

        ``event_type`` is required; a missing timestamp defaults to the
        current UTC time. Returns False on error (for example when the
        ``event_id`` primary key already exists).
        """
        try:
            with self._transaction() as conn:
                conn.execute("""
                    INSERT INTO event_log (event_id, event_type, saga_id, step_id, execution_id, payload, timestamp, partition_key)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    event.get("event_id", ""),
                    event["event_type"],
                    event.get("saga_id", ""),
                    event.get("step_id", ""),
                    event.get("execution_id", ""),
                    json.dumps(event.get("payload", {})),
                    event.get("timestamp", datetime.utcnow().isoformat()),
                    event.get("partition_key", "")
                ))
                return True
        except Exception as e:
            logger.error(f"Failed to log event: {e}")
            return False

    def record_metric(self, metric: Dict[str, Any]) -> bool:
        """
        Record one monitoring metric.

        ``event_type`` is required; ``success`` is stored as 1/0 by
        truthiness and a missing timestamp defaults to the current UTC time.
        """
        try:
            with self._transaction() as conn:
                conn.execute("""
                    INSERT INTO saga_metrics (saga_id, event_type, service_type, step_number, success, processing_time_ms, timestamp)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (
                    metric.get("saga_id"),
                    metric["event_type"],
                    metric.get("service_type"),
                    metric.get("step_number"),
                    1 if metric.get("success") else 0,
                    metric.get("processing_time_ms", 0),
                    metric.get("timestamp", datetime.utcnow().isoformat())
                ))
                return True
        except Exception as e:
            logger.error(f"Failed to record metric: {e}")
            return False

    def create_suspension_alert(self, saga_id: str, execution_id: str, reason: str) -> bool:
        """Insert an unresolved suspended-transaction alert stamped with the current UTC time."""
        try:
            with self._transaction() as conn:
                conn.execute("""
                    INSERT INTO suspended_transactions (saga_id, execution_id, suspended_at, reason)
                    VALUES (?, ?, ?, ?)
                """, (saga_id, execution_id, datetime.utcnow().isoformat(), reason))
                return True
        except Exception as e:
            logger.error(f"Failed to create suspension alert: {e}")
            return False

    def get_suspended_transactions(self, alert_sent_only: bool = False) -> List[Dict[str, Any]]:
        """
        Return unresolved suspended transactions, oldest first.

        NOTE(review): despite its name, ``alert_sent_only=True`` restricts
        the result to rows with ``alert_sent = 0`` (alerts not yet sent).
        The behavior is kept as-is; confirm the intent with the callers.
        Returns an empty list on error.
        """
        try:
            with self._transaction() as conn:
                if alert_sent_only:
                    cursor = conn.execute("""
                        SELECT * FROM suspended_transactions 
                        WHERE alert_sent = 0 AND resolved = 0
                        ORDER BY suspended_at ASC
                    """)
                else:
                    cursor = conn.execute("""
                        SELECT * FROM suspended_transactions 
                        WHERE resolved = 0
                        ORDER BY suspended_at ASC
                    """)
                return [{
                    "id": row["id"],
                    "saga_id": row["saga_id"],
                    "execution_id": row["execution_id"],
                    "suspended_at": row["suspended_at"],
                    "reason": row["reason"],
                    "alert_sent": bool(row["alert_sent"]),
                    "resolved": bool(row["resolved"])
                } for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Failed to get suspended transactions: {e}")
            return []

    def resolve_suspension(self, saga_id: str, resolved_by: str, notes: str) -> bool:
        """
        Manually resolve the open suspension alerts of a saga.

        Marks the saga's unresolved alerts as resolved and, in the same
        transaction, moves its SUSPENDED executions to FAILED with an
        explanatory error message. Returns True on success, False on error.
        """
        try:
            with self._transaction() as conn:
                resolved_at = datetime.utcnow().isoformat()
                conn.execute("""
                    UPDATE suspended_transactions 
                    SET resolved = 1, resolved_at = ?, resolved_by = ?, resolution_notes = ?
                    WHERE saga_id = ? AND resolved = 0
                """, (resolved_at, resolved_by, notes, saga_id))

                # Update the saga execution status as part of the same transaction.
                conn.execute("""
                    UPDATE saga_executions 
                    SET status = 'FAILED', updated_at = ?, error_message = ?
                    WHERE saga_id = ? AND status = 'SUSPENDED'
                """, (resolved_at, f"Manually resolved by {resolved_by}: {notes}", saga_id))
                return True
        except Exception as e:
            logger.error(f"Failed to resolve suspension: {e}")
            return False

    def get_metrics_summary(self, hours: int = 24) -> Dict[str, Any]:
        """
        Summarize the metrics recorded during the last ``hours`` hours.

        Returns a dict with the total / success / failure counts, the average
        processing time, a per-event-type breakdown and the period length;
        returns an empty dict on error.
        """
        # Local import: the module itself only imports ``datetime``.
        from datetime import timedelta

        try:
            with self._transaction() as conn:
                # Timestamps are stored as ISO-8601 text with a 'T' separator,
                # so the cutoff is built in the same format and bound as a
                # parameter. SQLite's datetime('now', ...) uses a space
                # separator and would miscompare on the cutoff date.
                cutoff = (datetime.utcnow() - timedelta(hours=hours)).isoformat()
                totals = conn.execute("""
                    SELECT 
                        COUNT(*) as total,
                        SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) as success_count,
                        AVG(processing_time_ms) as avg_time
                    FROM saga_metrics
                    WHERE timestamp > ?
                """, (cutoff,)).fetchone()

                breakdown_rows = conn.execute("""
                    SELECT event_type, COUNT(*) as count
                    FROM saga_metrics
                    WHERE timestamp > ?
                    GROUP BY event_type
                """, (cutoff,)).fetchall()
                event_counts = {r["event_type"]: r["count"] for r in breakdown_rows}

                total = totals["total"] or 0
                success_count = totals["success_count"] or 0
                return {
                    "total_events": total,
                    "success_count": success_count,
                    "failure_count": total - success_count,
                    "avg_processing_time_ms": totals["avg_time"] or 0,
                    "event_breakdown": event_counts,
                    "period_hours": hours
                }
        except Exception as e:
            logger.error(f"Failed to get metrics summary: {e}")
            return {}

shared/message_bus.py

Python

复制代码
"""
消息基础设施 - 至少一次交付、去重消费、顺序保证
"""
import json
import logging
import queue
import threading
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Any, List, Callable, Optional, Set
import uuid

logger = logging.getLogger(__name__)


class PartitionedMessageQueue:
    """
    Partitioned in-memory message queue.

    Messages published with the same ``partition_key`` go to the same
    partition, and every partition is drained by one consumer thread, so they
    are handled in publish order.  Delivery is at-least-once: a message that
    no handler processed successfully is put back on its partition (at the
    tail, so a retried message can be overtaken by later ones) until
    ``max_delivery_attempts`` is reached.  Ids of handled messages are
    remembered so that duplicates are skipped.
    """

    # De-duplication window: once more than _DEDUP_MAX_IDS ids are remembered,
    # only the _DEDUP_KEEP_IDS most recently recorded ones are retained.
    _DEDUP_MAX_IDS = 10000
    _DEDUP_KEEP_IDS = 5000

    def __init__(self, partition_count: int = 16, max_delivery_attempts: int = 3,
                 retry_backoff: float = 0.1):
        """
        Args:
            partition_count: Number of partitions (and consumer threads).
            max_delivery_attempts: How many times one message is handed to the
                subscribers before it is given up on.
            retry_backoff: Base delay in seconds before a re-queue; the actual
                delay is ``retry_backoff * attempt_number``.
        """
        self.partition_count = partition_count
        self._queues: Dict[int, queue.Queue] = {
            i: queue.Queue() for i in range(partition_count)
        }
        self._lock = threading.Lock()
        # Handled message ids.  A dict is used as an insertion-ordered set so
        # that trimming can really keep the *most recent* ids.
        self._delivered: Dict[str, None] = {}
        self._delivered_lock = threading.Lock()
        self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
        self._subscriber_lock = threading.Lock()
        self._running = False
        self._threads: List[threading.Thread] = []
        # Delivery attempts of messages that are still being retried.
        self._delivery_attempts: Dict[str, int] = {}
        self._max_delivery_attempts = max_delivery_attempts
        self._retry_backoff = retry_backoff

    def _get_partition(self, partition_key: str) -> int:
        """Map a partition key to a partition index (empty key -> partition 0)."""
        if not partition_key:
            return 0
        return hash(partition_key) % self.partition_count

    def publish(self, message: Dict[str, Any], partition_key: str = "") -> bool:
        """
        Publish a message to the partition selected by ``partition_key``.

        The message dict is annotated in place with ``_msg_id`` (its
        ``event_id``, else its ``command_id``, else a fresh UUID),
        ``_partition_key`` and ``_publish_time``.

        Returns:
            True when the message was enqueued, False if enqueuing failed.
        """
        try:
            msg_id = message.get("event_id") or message.get("command_id") or str(uuid.uuid4())
            message["_msg_id"] = msg_id
            message["_partition_key"] = partition_key
            message["_publish_time"] = datetime.utcnow().isoformat()

            partition = self._get_partition(partition_key)
            self._queues[partition].put(message)
            logger.debug(f"Published message {msg_id} to partition {partition}")
            return True
        except Exception as e:
            logger.error(f"Failed to publish message: {e}")
            return False

    def subscribe(self, event_type: str, handler: Callable[[Dict[str, Any]], None]):
        """Register ``handler`` for messages whose event/command type is ``event_type``."""
        with self._subscriber_lock:
            self._subscribers[event_type].append(handler)
            logger.info(f"Subscribed handler to event type: {event_type}")

    def start_consuming(self):
        """Start one daemon consumer thread per partition (no-op if already running)."""
        # A second set of consumers on the same partitions would break the
        # one-consumer-per-partition ordering guarantee.
        if self._running:
            return
        self._running = True
        for i in range(self.partition_count):
            thread = threading.Thread(
                target=self._consume_partition,
                args=(i,),
                name=f"PartitionConsumer-{i}",
                daemon=True
            )
            thread.start()
            self._threads.append(thread)
        logger.info(f"Started {self.partition_count} partition consumers")

    def stop_consuming(self):
        """Ask the consumers to stop and wait up to 5 seconds for each of them."""
        self._running = False
        for thread in self._threads:
            thread.join(timeout=5)
        # Forget the finished threads so a later start begins from a clean list.
        self._threads.clear()
        logger.info("Stopped all partition consumers")

    def _consume_partition(self, partition: int):
        """Consumer loop of one partition; runs until ``stop_consuming`` is called."""
        q = self._queues[partition]
        while self._running:
            try:
                # The timeout lets the loop notice a stop request.
                message = q.get(timeout=1)
                self._process_message(message)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Error consuming partition {partition}: {e}")

    def _process_message(self, message: Dict[str, Any]):
        """
        Hand one message to its subscribers with de-duplication and retry.

        The message counts as handled when at least one subscriber returned
        without raising.  Otherwise it is re-queued after a short delay until
        the attempt limit is reached.
        """
        msg_id = message.get("_msg_id", "")

        with self._delivered_lock:
            if msg_id in self._delivered:
                logger.debug(f"Message {msg_id} already delivered, skipping")
                return
            # Claim the id before running the handlers so that a duplicate
            # picked up concurrently by another partition is skipped.
            self._delivered[msg_id] = None

            if len(self._delivered) > self._DEDUP_MAX_IDS:
                recent = list(self._delivered)[-self._DEDUP_KEEP_IDS:]
                self._delivered = dict.fromkeys(recent)

        event_type = message.get("event_type") or message.get("command_type", "")

        # Snapshot the handler list so it is not iterated while being mutated.
        with self._subscriber_lock:
            handlers = list(self._subscribers.get(event_type, []))

        success = False
        for handler in handlers:
            try:
                handler(message)
                success = True
            except Exception as e:
                logger.error(f"Handler failed for message {msg_id}: {e}")

        with self._delivered_lock:
            attempts = self._delivery_attempts.get(msg_id, 0) + 1
            retry = not success and attempts < self._max_delivery_attempts
            if retry:
                self._delivery_attempts[msg_id] = attempts
                # Release the claim, otherwise the re-queued copy would be
                # discarded as a duplicate and never retried.
                self._delivered.pop(msg_id, None)
            else:
                # Finished (handled or given up): drop the bookkeeping entry.
                self._delivery_attempts.pop(msg_id, None)

        if retry:
            partition = self._get_partition(message.get("_partition_key", ""))
            # Linear back-off before the message goes back on its partition.
            time.sleep(self._retry_backoff * attempts)
            self._queues[partition].put(message)
            logger.warning(f"Requeued message {msg_id} for retry {attempts}")
        elif not success:
            logger.error(f"Message {msg_id} failed after {self._max_delivery_attempts} attempts")

    def get_queue_depth(self) -> Dict[int, int]:
        """Return the approximate number of queued messages per partition."""
        return {i: q.qsize() for i, q in self._queues.items()}


class InMemoryMessageBus:
    """
    Process-wide in-memory message bus that stands in for a message broker.

    It owns three partitioned queues (commands, events, responses) and lets a
    caller block until the response for a given command id has arrived.  The
    class is a singleton: every ``InMemoryMessageBus()`` call returns the same
    instance, which is initialised only once.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Check, lock, check again: only the first caller creates the instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every InMemoryMessageBus() call; set up only once.
        if self._initialized:
            return
        self._initialized = True

        self.command_queue = PartitionedMessageQueue(partition_count=8)
        self.event_queue = PartitionedMessageQueue(partition_count=16)
        self.response_queue = PartitionedMessageQueue(partition_count=4)

        # Waiters for synchronous request/response, keyed by command id.
        self._pending_responses: Dict[str, threading.Event] = {}
        # Responses that have arrived but were not collected yet.
        self._response_data: Dict[str, Dict[str, Any]] = {}
        self._response_lock = threading.Lock()

        # Register the response handler before any consumer thread runs, so
        # no response can be consumed while there is no subscriber for it.
        self.response_queue.subscribe("RESPONSE", self._handle_response)

        self.command_queue.start_consuming()
        self.event_queue.start_consuming()
        self.response_queue.start_consuming()

    def _handle_response(self, message: Dict[str, Any]):
        """Store a response and wake the caller waiting for its ``command_id``, if any."""
        command_id = message.get("command_id", "")
        with self._response_lock:
            self._response_data[command_id] = message
            event = self._pending_responses.get(command_id)
            if event:
                event.set()

    def publish_command(self, command: Dict[str, Any], partition_key: str = "") -> bool:
        """Publish a command message; returns whether it was enqueued."""
        return self.command_queue.publish(command, partition_key)

    def publish_event(self, event: Dict[str, Any], partition_key: str = "") -> bool:
        """Publish an event message; returns whether it was enqueued."""
        return self.event_queue.publish(event, partition_key)

    def publish_response(self, response: Dict[str, Any], partition_key: str = "") -> bool:
        """Publish a response message; returns whether it was enqueued."""
        return self.response_queue.publish(response, partition_key)

    def subscribe_command(self, command_type: str, handler: Callable[[Dict[str, Any]], None]):
        """Register ``handler`` for commands of type ``command_type``."""
        self.command_queue.subscribe(command_type, handler)

    def subscribe_event(self, event_type: str, handler: Callable[[Dict[str, Any]], None]):
        """Register ``handler`` for events of type ``event_type``."""
        self.event_queue.subscribe(event_type, handler)

    def wait_for_response(self, command_id: str, timeout: float = 30.0) -> Optional[Dict[str, Any]]:
        """
        Block until the response for ``command_id`` arrives.

        Args:
            command_id: Id of the command whose response is awaited.
            timeout: Maximum time to wait, in seconds.

        Returns:
            The response message, or None if none arrived within ``timeout``.
        """
        event = threading.Event()
        with self._response_lock:
            # The response may have been handled before the caller got here;
            # in that case nobody would ever set a freshly created event.
            if command_id in self._response_data:
                return self._response_data.pop(command_id)
            self._pending_responses[command_id] = event

        event.wait(timeout=timeout)

        with self._response_lock:
            # Always unregister the waiter so the table cannot grow forever.
            self._pending_responses.pop(command_id, None)
            # Also picks up a response that landed right at the timeout edge.
            return self._response_data.pop(command_id, None)

    def get_stats(self) -> Dict[str, Any]:
        """Return per-partition queue depths and the number of pending waiters."""
        return {
            "command_queue_depth": self.command_queue.get_queue_depth(),
            "event_queue_depth": self.event_queue.get_queue_depth(),
            "response_queue_depth": self.response_queue.get_queue_depth(),
            "pending_responses": len(self._pending_responses)
        }

shared/utils.py

Python

复制代码
"""
工具函数
"""
import hashlib
import json
import logging
import time
import uuid
from datetime import datetime
from functools import wraps
from typing import Any, Callable, Dict, Optional

logger = logging.getLogger(__name__)


def generate_idempotency_key(*args) -> str:
    """Derive an idempotency key from the string forms of ``args``.

    The arguments are stringified, joined with ``|`` and hashed with SHA-256;
    the first 32 hex characters of the digest are returned.
    """
    joined = "|".join(map(str, args))
    digest = hashlib.sha256(joined.encode())
    return digest.hexdigest()[:32]


def generate_uuid() -> str:
    """Return a freshly generated random (version 4) UUID as a string."""
    fresh = uuid.uuid4()
    return str(fresh)


def now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string (no offset suffix)."""
    current = datetime.utcnow()
    return current.isoformat()


def retry_with_backoff(max_retries: int = 3, base_delay: float = 0.1, max_delay: float = 5.0):
    """Build a decorator that retries a function with exponential back-off.

    Args:
        max_retries: Total number of attempts.  Values below 1 are treated as
            1, so the wrapped function is always called at least once.
        base_delay: Delay in seconds before the first retry; it doubles after
            every failed attempt.
        max_delay: Upper bound in seconds for a single delay.

    Returns:
        A decorator.  The wrapped function returns the first successful result
        and re-raises the last exception once all attempts have failed.
    """
    # With max_retries <= 0 the attempt loop would never run and there would
    # be no exception to re-raise, so guarantee one attempt.
    attempts = max(1, max_retries)

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == attempts - 1:
                        # Out of attempts: surface the last failure unchanged.
                        raise
                    delay = min(base_delay * (2 ** attempt), max_delay)
                    logger.warning(f"Retry {attempt + 1}/{max_retries} for {func.__name__} after {delay}s: {e}")
                    time.sleep(delay)
        return wrapper
    return decorator


def timed_execution(func: Callable) -> Callable:
    """Decorator that measures how long ``func`` takes.

    The wrapped function returns a ``(result, elapsed_ms)`` tuple, where
    ``elapsed_ms`` is the run time in whole milliseconds.  Exceptions raised
    by ``func`` propagate unchanged and no timing is reported for them.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() never goes backwards, unlike the wall clock, which
        # can be adjusted while the call is running.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = int((time.perf_counter() - start) * 1000)
        return result, elapsed
    return wrapper


class CircuitBreaker:
    """Three-state circuit breaker (closed / open / half-open).

    The breaker opens once ``failure_threshold`` failures have been recorded
    without a success in between.  While open it rejects calls until
    ``recovery_timeout`` seconds have passed since the last failure, then it
    becomes half-open and lets calls through again; a success closes it.
    """

    STATE_CLOSED = "CLOSED"        # calls pass through
    STATE_OPEN = "OPEN"            # calls are rejected
    STATE_HALF_OPEN = "HALF_OPEN"  # probing after the recovery timeout

    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 30.0):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.state = self.STATE_CLOSED
        self.failure_count = 0
        self.last_failure_time = 0.0
        self._lock = False

    def can_execute(self) -> bool:
        """Tell whether a call may proceed, moving OPEN -> HALF_OPEN once the timeout elapsed."""
        if self.state != self.STATE_OPEN:
            # CLOSED and HALF_OPEN both admit calls.
            return True
        cooled_down = time.time() - self.last_failure_time > self.recovery_timeout
        if cooled_down:
            self.state = self.STATE_HALF_OPEN
        return cooled_down

    def record_success(self):
        """Reset the failure counter and close the breaker."""
        self.failure_count = 0
        self.state = self.STATE_CLOSED

    def record_failure(self):
        """Count a failure, remember when it happened, and open the breaker at the threshold."""
        self.failure_count += 1
        self.last_failure_time = time.time()
        threshold_reached = self.failure_count >= self.failure_threshold
        if threshold_reached:
            self.state = self.STATE_OPEN

    def call(self, func: Callable, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` through the breaker.

        Raises:
            Exception: "Circuit breaker is OPEN" when the call is rejected;
                otherwise whatever ``func`` raises (recorded as a failure).
        """
        if not self.can_execute():
            raise Exception("Circuit breaker is OPEN")
        try:
            outcome = func(*args, **kwargs)
        except Exception:
            self.record_failure()
            raise
        self.record_success()
        return outcome


class RateLimiter:
    """Sliding-window rate limiter keyed by an arbitrary string.

    Each key may be granted at most ``max_requests`` requests within any
    ``window_seconds`` long window.
    """

    def __init__(self, max_requests: int = 100, window_seconds: float = 60.0):
        """
        Args:
            max_requests: Number of requests allowed per key and window.
            window_seconds: Length of the sliding window, in seconds.
        """
        # threading is needed here, at construction time; importing it only
        # inside is_allowed() left this constructor without the name.
        import threading

        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # key -> timestamps of the requests granted inside the current window
        self.requests: Dict[str, list] = {}
        self._lock = threading.Lock()

    def is_allowed(self, key: str) -> bool:
        """Record a request for ``key`` and tell whether it is within the limit.

        A rejected request is not recorded, so it does not extend the window.
        """
        with self._lock:
            now = time.time()
            # Keep only the requests that still fall inside the window.
            recent = [t for t in self.requests.get(key, []) if now - t < self.window_seconds]
            allowed = len(recent) < self.max_requests
            if allowed:
                recent.append(now)
            self.requests[key] = recent
            return allowed

3. 业务服务实现

services/flight_service/app.py

Python

复制代码
#!/usr/bin/env python3
"""
航班服务 - 座位锁定与释放API
实现TCC模式:Try(锁定) -> Confirm(确认) -> Cancel(取消)
"""
import json
import logging
import sqlite3
import threading
import time
import uuid
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List

import sys
# Put the directory two levels up on the import path so the ``shared`` package
# below can be imported.
# NOTE(review): '../..' is a relative path, so it is resolved against the
# current working directory, not against this file's location - the imports
# below presumably only work when the script is started from its own
# directory; confirm how the service is launched.
sys.path.insert(0, '../..')

from shared.models import FlightReservation, IdempotencyKey
from shared.events import EventType, DomainEvent, CommandMessage, ResponseMessage
from shared.message_bus import InMemoryMessageBus
from shared.utils import generate_idempotency_key, retry_with_backoff, timed_execution, CircuitBreaker

# Configure root logging at import time and use a fixed logger name for this service.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("FlightService")


class FlightDatabase:
    """SQLite-backed store for flights, seat locks, reservations and idempotency records.

    Every method opens its own connection and closes it before returning.
    The seat operations (``lock_seat`` / ``confirm_seat`` / ``cancel_lock`` /
    ``release_seat``) each run inside a ``BEGIN IMMEDIATE`` transaction and
    report their outcome as a dict with a ``success`` flag instead of raising.
    """

    # Seat letters of one cabin row; seats are labelled "<row><letter>".
    _SEAT_LETTERS = "ABCDEF"

    def __init__(self, db_path: str = "flight_service.db"):
        """Remember the database file path and create the schema if needed."""
        self.db_path = db_path
        self._init_db()

    def _get_connection(self):
        """Open a new connection whose rows can be indexed by column name."""
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        return conn

    def _init_db(self):
        """Create tables and indexes, and seed sample flights into an empty ``flights`` table."""
        conn = self._get_connection()
        try:
            conn.executescript("""
                CREATE TABLE IF NOT EXISTS flights (
                    flight_id TEXT PRIMARY KEY,
                    airline TEXT NOT NULL,
                    from_city TEXT NOT NULL,
                    to_city TEXT NOT NULL,
                    departure_time TEXT NOT NULL,
                    arrival_time TEXT NOT NULL,
                    total_seats INTEGER NOT NULL,
                    available_seats INTEGER NOT NULL,
                    price REAL NOT NULL,
                    status TEXT DEFAULT 'ACTIVE'
                );
                
                CREATE TABLE IF NOT EXISTS seat_locks (
                    lock_id TEXT PRIMARY KEY,
                    flight_id TEXT NOT NULL,
                    seat_number TEXT NOT NULL,
                    saga_id TEXT NOT NULL,
                    lock_token TEXT NOT NULL,
                    locked_at TEXT NOT NULL,
                    expires_at TEXT NOT NULL,
                    status TEXT DEFAULT 'LOCKED',
                    passenger_name TEXT,
                    passenger_id TEXT,
                    FOREIGN KEY (flight_id) REFERENCES flights(flight_id)
                );
                
                CREATE TABLE IF NOT EXISTS reservations (
                    reservation_id TEXT PRIMARY KEY,
                    flight_id TEXT NOT NULL,
                    seat_number TEXT NOT NULL,
                    saga_id TEXT NOT NULL,
                    passenger_name TEXT NOT NULL,
                    passenger_id TEXT NOT NULL,
                    price REAL NOT NULL,
                    status TEXT DEFAULT 'CONFIRMED',
                    created_at TEXT NOT NULL,
                    cancelled_at TEXT,
                    refund_amount REAL DEFAULT 0
                );
                
                CREATE TABLE IF NOT EXISTS idempotency_store (
                    key TEXT PRIMARY KEY,
                    operation TEXT NOT NULL,
                    result TEXT,
                    created_at TEXT NOT NULL
                );
                
                CREATE INDEX IF NOT EXISTS idx_flight_id ON seat_locks(flight_id);
                CREATE INDEX IF NOT EXISTS idx_saga_id ON seat_locks(saga_id);
                CREATE INDEX IF NOT EXISTS idx_lock_token ON seat_locks(lock_token);
            """)
            conn.commit()

            # Seed the sample flights only when the table is still empty.
            cursor = conn.execute("SELECT COUNT(*) FROM flights")
            if cursor.fetchone()[0] == 0:
                self._seed_data(conn)
        finally:
            conn.close()

    def _seed_data(self, conn):
        """Insert the sample flights, with times relative to the current local time."""
        now = datetime.now()

        def at(**offset) -> str:
            return (now + timedelta(**offset)).isoformat()

        flights = [
            ("CA1001", "Air China", "Beijing", "Shanghai",
             at(days=1), at(days=1, hours=2), 200, 200, 1200.0),
            ("CA1002", "Air China", "Shanghai", "Beijing",
             at(days=2), at(days=2, hours=2), 200, 200, 1100.0),
            ("MU5001", "China Eastern", "Beijing", "Guangzhou",
             at(days=1), at(days=1, hours=3), 300, 300, 1500.0),
            ("CZ3001", "China Southern", "Shanghai", "Shenzhen",
             at(days=3), at(days=3, hours=2, minutes=30), 250, 250, 1300.0),
        ]
        conn.executemany("""
            INSERT INTO flights (flight_id, airline, from_city, to_city, departure_time, arrival_time, total_seats, available_seats, price)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, flights)
        conn.commit()
        logger.info("Seeded flight data")

    def get_available_seats(self, flight_id: str) -> List[Dict[str, Any]]:
        """Return up to 20 seats of a flight that are neither locked nor reserved.

        Seats are labelled row by row, six per row: "1A".."1F", "2A", ...

        Returns:
            A list of ``{"seat_number", "status": "AVAILABLE"}`` dicts; an
            empty list when the flight does not exist.
        """
        conn = self._get_connection()
        try:
            flight = conn.execute(
                "SELECT total_seats FROM flights WHERE flight_id = ?", (flight_id,)
            ).fetchone()
            if not flight:
                return []

            # Seats held by a lock that has not expired yet ...
            cursor = conn.execute("""
                SELECT seat_number FROM seat_locks 
                WHERE flight_id = ? AND status = 'LOCKED' AND expires_at > ?
            """, (flight_id, datetime.now().isoformat()))
            taken = {row["seat_number"] for row in cursor.fetchall()}

            # ... plus seats with a confirmed reservation.
            cursor = conn.execute("""
                SELECT seat_number FROM reservations 
                WHERE flight_id = ? AND status = 'CONFIRMED'
            """, (flight_id,))
            taken.update(row["seat_number"] for row in cursor.fetchall())

            # The row number advances once per full row of letters, not once
            # per seat.
            letters = self._SEAT_LETTERS
            per_row = len(letters)
            all_seats = [
                f"{index // per_row + 1}{letters[index % per_row]}"
                for index in range(flight["total_seats"])
            ]

            available = [seat for seat in all_seats if seat not in taken]
            return [{"seat_number": seat, "status": "AVAILABLE"} for seat in available[:20]]
        finally:
            conn.close()

    def lock_seat(self, flight_id: str, seat_number: str, saga_id: str,
                  passenger_name: str, passenger_id: str) -> Dict[str, Any]:
        """TCC Try: place a 10-minute lock on a seat.

        Returns:
            On success ``{"success": True, "lock_id", "lock_token",
            "expires_at", "flight_id", "seat_number"}``; otherwise
            ``{"success": False, "error": <reason>}``.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")

            # The flight must exist and be active.
            cursor = conn.execute("SELECT * FROM flights WHERE flight_id = ? AND status = 'ACTIVE'",
                                  (flight_id,))
            if not cursor.fetchone():
                conn.rollback()
                return {"success": False, "error": "Flight not found or inactive"}

            # The seat must not be held by an unexpired lock ...
            cursor = conn.execute("""
                SELECT * FROM seat_locks 
                WHERE flight_id = ? AND seat_number = ? AND status = 'LOCKED' AND expires_at > ?
            """, (flight_id, seat_number, datetime.now().isoformat()))
            if cursor.fetchone():
                conn.rollback()
                return {"success": False, "error": "Seat already locked"}

            # ... nor by a confirmed reservation.
            cursor = conn.execute("""
                SELECT * FROM reservations 
                WHERE flight_id = ? AND seat_number = ? AND status = 'CONFIRMED'
            """, (flight_id, seat_number))
            if cursor.fetchone():
                conn.rollback()
                return {"success": False, "error": "Seat already reserved"}

            # The lock token is what confirm_seat / cancel_lock are called with.
            lock_token = str(uuid.uuid4())
            lock_id = str(uuid.uuid4())
            locked_at = datetime.now()
            expires_at = (locked_at + timedelta(minutes=10)).isoformat()

            conn.execute("""
                INSERT INTO seat_locks (lock_id, flight_id, seat_number, saga_id, lock_token, 
                                       locked_at, expires_at, status, passenger_name, passenger_id)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (lock_id, flight_id, seat_number, saga_id, lock_token,
                  locked_at.isoformat(), expires_at, "LOCKED", passenger_name, passenger_id))

            # A locked seat no longer counts as available.
            conn.execute("""
                UPDATE flights SET available_seats = available_seats - 1 
                WHERE flight_id = ?
            """, (flight_id,))

            conn.commit()

            return {
                "success": True,
                "lock_id": lock_id,
                "lock_token": lock_token,
                "expires_at": expires_at,
                "flight_id": flight_id,
                "seat_number": seat_number
            }
        except Exception as e:
            conn.rollback()
            logger.error(f"Lock seat failed: {e}")
            return {"success": False, "error": str(e)}
        finally:
            conn.close()

    def confirm_seat(self, lock_token: str) -> Dict[str, Any]:
        """TCC Confirm: turn an unexpired lock into a confirmed reservation.

        Returns:
            ``{"success": True, "reservation_id", "lock_token"}`` on success,
            otherwise ``{"success": False, "error": <reason>}``.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")

            cursor = conn.execute("""
                SELECT * FROM seat_locks WHERE lock_token = ? AND status = 'LOCKED'
            """, (lock_token,))
            lock = cursor.fetchone()

            if not lock:
                conn.rollback()
                return {"success": False, "error": "Lock not found or expired"}

            # A lock past its expiry time can no longer be confirmed.
            if datetime.fromisoformat(lock["expires_at"]) < datetime.now():
                conn.rollback()
                return {"success": False, "error": "Lock expired"}

            # Copy the lock's details into a reservation, priced from the flight.
            reservation_id = str(uuid.uuid4())
            conn.execute("""
                INSERT INTO reservations (reservation_id, flight_id, seat_number, saga_id, 
                                        passenger_name, passenger_id, price, status, created_at)
                SELECT ?, flight_id, seat_number, saga_id, passenger_name, passenger_id,
                       (SELECT price FROM flights WHERE flight_id = seat_locks.flight_id), 
                       'CONFIRMED', ?
                FROM seat_locks WHERE lock_token = ?
            """, (reservation_id, datetime.now().isoformat(), lock_token))

            conn.execute("""
                UPDATE seat_locks SET status = 'CONFIRMED' WHERE lock_token = ?
            """, (lock_token,))

            conn.commit()

            return {
                "success": True,
                "reservation_id": reservation_id,
                "lock_token": lock_token
            }
        except Exception as e:
            conn.rollback()
            logger.error(f"Confirm seat failed: {e}")
            return {"success": False, "error": str(e)}
        finally:
            conn.close()

    def cancel_lock(self, lock_token: str) -> Dict[str, Any]:
        """TCC Cancel: release a seat lock and give the seat back.

        Cancelling an already cancelled lock succeeds (idempotent).  A token
        that is unknown, or whose lock is in any other non-LOCKED state, yields
        ``{"success": False, "error": "Lock not found"}``.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")

            cursor = conn.execute("""
                SELECT * FROM seat_locks WHERE lock_token = ? AND status = 'LOCKED'
            """, (lock_token,))
            lock = cursor.fetchone()

            if not lock:
                cursor = conn.execute("""
                    SELECT * FROM seat_locks WHERE lock_token = ?
                """, (lock_token,))
                existing = cursor.fetchone()
                # Nothing was changed: end the transaction explicitly, as the
                # other early exits of this class do.
                conn.rollback()
                if existing and existing["status"] == "CANCELLED":
                    return {"success": True, "message": "Already cancelled"}
                return {"success": False, "error": "Lock not found"}

            # Give the seat back to the flight.
            conn.execute("""
                UPDATE flights SET available_seats = available_seats + 1 
                WHERE flight_id = ?
            """, (lock["flight_id"],))

            conn.execute("""
                UPDATE seat_locks SET status = 'CANCELLED' WHERE lock_token = ?
            """, (lock_token,))

            conn.commit()

            return {"success": True, "lock_token": lock_token}
        except Exception as e:
            conn.rollback()
            logger.error(f"Cancel lock failed: {e}")
            return {"success": False, "error": str(e)}
        finally:
            conn.close()

    def release_seat(self, reservation_id: str) -> Dict[str, Any]:
        """Compensation: cancel a confirmed reservation with a full refund.

        A reservation that does not exist or is no longer confirmed is
        reported as success, so the compensation can be repeated safely.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")

            cursor = conn.execute("""
                SELECT * FROM reservations WHERE reservation_id = ? AND status = 'CONFIRMED'
            """, (reservation_id,))
            reservation = cursor.fetchone()

            if not reservation:
                # Nothing to undo; close the transaction before returning.
                conn.rollback()
                return {"success": True, "message": "Reservation not found or already cancelled"}

            # Simplified refund policy: the full price is refunded.
            refund_amount = reservation["price"]

            conn.execute("""
                UPDATE reservations SET status = 'CANCELLED', cancelled_at = ?, refund_amount = ?
                WHERE reservation_id = ?
            """, (datetime.now().isoformat(), refund_amount, reservation_id))

            # The seat becomes available again.
            conn.execute("""
                UPDATE flights SET available_seats = available_seats + 1 
                WHERE flight_id = ?
            """, (reservation["flight_id"],))

            conn.commit()

            return {
                "success": True,
                "reservation_id": reservation_id,
                "refund_amount": refund_amount
            }
        except Exception as e:
            conn.rollback()
            logger.error(f"Release seat failed: {e}")
            return {"success": False, "error": str(e)}
        finally:
            conn.close()

    def check_idempotency(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the stored result for ``key``, or None if there is none."""
        conn = self._get_connection()
        try:
            cursor = conn.execute("SELECT * FROM idempotency_store WHERE key = ?", (key,))
            row = cursor.fetchone()
            if row and row["result"]:
                return json.loads(row["result"])
            return None
        finally:
            conn.close()

    def save_idempotency(self, key: str, operation: str, result: Dict[str, Any]):
        """Store (or overwrite) the JSON-encoded result of ``operation`` under ``key``."""
        conn = self._get_connection()
        try:
            conn.execute("""
                INSERT OR REPLACE INTO idempotency_store (key, operation, result, created_at)
                VALUES (?, ?, ?, ?)
            """, (key, operation, json.dumps(result), datetime.now().isoformat()))
            conn.commit()
        finally:
            conn.close()


class FlightService:
    """Flight service that executes seat commands delivered over the message bus.

    Construction wires everything up: it creates the database wrapper, the
    in-memory message bus and a circuit breaker, subscribes one handler per
    seat command, and starts a daemon thread that periodically cancels seat
    locks whose ``expires_at`` has passed.
    """

    # Seconds the cleanup thread sleeps between two expired-lock sweeps.
    CLEANUP_INTERVAL_SECONDS = 60

    def __init__(self):
        self.db = FlightDatabase()
        self.bus = InMemoryMessageBus()
        self.circuit_breaker = CircuitBreaker(failure_threshold=5)
        self._setup_subscribers()
        self._start_cleanup_thread()

    def _setup_subscribers(self):
        """Subscribe one handler to each seat command type on the bus."""
        self.bus.subscribe_command(EventType.CMD_LOCK_SEAT, self._handle_lock_seat)
        self.bus.subscribe_command(EventType.CMD_CONFIRM_SEAT, self._handle_confirm_seat)
        self.bus.subscribe_command(EventType.CMD_CANCEL_SEAT, self._handle_cancel_seat)
        self.bus.subscribe_command(EventType.CMD_RELEASE_SEAT, self._handle_release_seat)

    def _start_cleanup_thread(self):
        """Start the daemon thread that sweeps expired seat locks forever."""
        def cleanup():
            while True:
                time.sleep(self.CLEANUP_INTERVAL_SECONDS)
                try:
                    self._cleanup_expired_locks()
                except Exception as e:
                    # An escaping exception would end this thread for good and
                    # expired locks would never be released again.
                    logger.error(f"Cleanup failed: {e}")

        thread = threading.Thread(target=cleanup, daemon=True, name="FlightLockCleanup")
        thread.start()

    def _cleanup_expired_locks(self):
        """Cancel every seat lock that is still ``LOCKED`` but past its expiry.

        Failures are logged and never raised: a failure to read the expired
        locks ends this sweep, and a failure to cancel one lock does not stop
        the remaining locks from being processed.
        """
        try:
            conn = self.db._get_connection()
            try:
                cursor = conn.execute("""
                    SELECT flight_id, lock_token FROM seat_locks
                    WHERE status = 'LOCKED' AND expires_at < ?
                """, (datetime.now().isoformat(),))
                expired_tokens = [row["lock_token"] for row in cursor.fetchall()]
            finally:
                # Release the read connection before cancel_lock is called.
                conn.close()
        except Exception as e:
            logger.error(f"Cleanup failed: {e}")
            return

        for lock_token in expired_tokens:
            try:
                result = self.db.cancel_lock(lock_token)
            except Exception as e:
                logger.error(f"Cleanup failed: {e}")
                continue
            # Only report success when the database did not report a failure.
            if isinstance(result, dict) and not result.get("success", False):
                logger.error(f"Cleanup failed for lock {lock_token}: {result.get('error')}")
            else:
                logger.info(f"Cleaned up expired lock: {lock_token}")

    @staticmethod
    def _elapsed_ms(started_at: float) -> int:
        """Return whole milliseconds elapsed since a ``time.perf_counter()`` reading."""
        return int((time.perf_counter() - started_at) * 1000)

    def _handle_lock_seat(self, command: Dict[str, Any]):
        """Handle the lock-seat command.

        A result already stored under the command's idempotency key is
        answered from the store without locking again. Otherwise the seat is
        locked through the circuit breaker, the result is stored, an
        ``EVT_SEAT_LOCKED`` or ``EVT_SEAT_LOCK_FAILED`` event is published and
        a response is sent. Any exception while doing so is turned into a
        failure response.

        Args:
            command: Serialized command; its payload must contain
                ``flight_id``, ``seat_number``, ``passenger_name`` and
                ``passenger_id``.
        """
        started_at = time.perf_counter()
        cmd = CommandMessage.from_dict(command)
        saga_id = cmd.saga_id
        step_id = cmd.step_id

        # Use the caller-supplied key, or derive one from the command content.
        idempotency_key = cmd.idempotency_key or generate_idempotency_key(
            cmd.command_type, saga_id, step_id, json.dumps(cmd.payload, sort_keys=True)
        )

        cached = self.db.check_idempotency(idempotency_key)
        if cached is not None:
            logger.info(f"Idempotent lock seat return cached result for {saga_id}")
            self._send_response(cmd, cached, self._elapsed_ms(started_at))
            return

        try:
            result = self.circuit_breaker.call(
                self.db.lock_seat,
                cmd.payload["flight_id"],
                cmd.payload["seat_number"],
                saga_id,
                cmd.payload["passenger_name"],
                cmd.payload["passenger_id"]
            )

            # Store the outcome (success or failure) for duplicate deliveries.
            self.db.save_idempotency(idempotency_key, cmd.command_type, result)

            # Both outcomes publish the same event shape; only the type differs.
            if result["success"]:
                event_type = EventType.EVT_SEAT_LOCKED
            else:
                event_type = EventType.EVT_SEAT_LOCK_FAILED
            event = DomainEvent(
                event_type=event_type,
                saga_id=saga_id,
                step_id=step_id,
                correlation_id=cmd.correlation_id,
                payload=result,
                partition_key=cmd.payload.get("user_id", saga_id)
            )
            self.bus.publish_event(event.to_dict(), partition_key=event.partition_key)

            self._send_response(cmd, result, self._elapsed_ms(started_at))

        except Exception as e:
            logger.error(f"Lock seat error: {e}")
            error_result = {"success": False, "error": str(e)}
            self._send_response(cmd, error_result, self._elapsed_ms(started_at))

    def _run_db_command(self, command, operation, payload_key, label):
        """Run a single-argument database operation for a command and respond.

        Args:
            command: Serialized command dictionary.
            operation: Callable taking the payload value and returning a
                result dictionary.
            payload_key: Payload key whose value is passed to ``operation``.
            label: Prefix of the error log message (e.g. ``"Confirm seat"``).
        """
        started_at = time.perf_counter()
        cmd = CommandMessage.from_dict(command)

        try:
            result = operation(cmd.payload[payload_key])
            self._send_response(cmd, result, self._elapsed_ms(started_at))
        except Exception as e:
            logger.error(f"{label} error: {e}")
            self._send_response(cmd, {"success": False, "error": str(e)},
                                self._elapsed_ms(started_at))

    def _handle_confirm_seat(self, command: Dict[str, Any]):
        """Handle the confirm-seat command using the payload's ``lock_token``."""
        self._run_db_command(command, self.db.confirm_seat, "lock_token", "Confirm seat")

    def _handle_cancel_seat(self, command: Dict[str, Any]):
        """Handle the cancel-lock command using the payload's ``lock_token``."""
        self._run_db_command(command, self.db.cancel_lock, "lock_token", "Cancel seat")

    def _handle_release_seat(self, command: Dict[str, Any]):
        """Handle the release-seat (compensation) command using ``reservation_id``."""
        self._run_db_command(command, self.db.release_seat, "reservation_id", "Release seat")

    def _send_response(self, command: "CommandMessage", result: Dict[str, Any], processing_time: int):
        """Publish the response for a command on the bus.

        Args:
            command: The command being answered; supplies the ids copied into
                the response and the correlation id used as partition key.
            result: Result dictionary; ``success`` (default ``False``),
                ``error_code`` and ``error`` are read from it.
            processing_time: Handling time in milliseconds.
        """
        response = ResponseMessage(
            command_id=command.command_id,
            saga_id=command.saga_id,
            step_id=command.step_id,
            correlation_id=command.correlation_id,
            success=result.get("success", False),
            result=result,
            error_code=result.get("error_code"),
            error_message=result.get("error"),
            processing_time_ms=processing_time
        )
        self.bus.publish_response(response.to_dict(), partition_key=command.correlation_id)


if __name__ == "__main__":
    # Constructing the service subscribes its command handlers and starts the
    # daemon cleanup thread.
    service = FlightService()
    logger.info("Flight Service started")
    # Keep the main thread alive: the cleanup thread is a daemon and would be
    # killed as soon as the main thread returned.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop the process; exit without a traceback.
        logger.info("Flight Service stopped")

services/hotel_service/app.py

相关推荐
取名好樊1 小时前
Windows Docker PostgreSQL 端口绑定失败问题记录
windows·docker·postgresql
ai产品老杨1 小时前
深度解析:基于Docker构建的安防视频AI平台——如何通过GB28181/RTSP协议栈统一接入与全套源码交付,破局异构边缘计算芯片内卷
人工智能·docker·音视频
逍遥德1 小时前
PostgreSQL --- 二进制数使用详解
数据库·sql·postgresql
流浪法师解剖鱼1 小时前
CocosCreator制作推箱子游戏
python·cocos2d
Ze3G90nYt1 小时前
Redis 分布式锁进阶第一百三十一篇
数据库·redis·分布式
AI服务老曹1 小时前
基于Docker与边缘计算的企业级AI视频平台架构演进:GB28181/RTSP多协议接入与源码交付深度解析
人工智能·docker·边缘计算
蜀道山老天师1 小时前
OpenClaw 从零部署 + 飞书机器人完整接入(实操篇)
运维·docker·容器·飞书
倔强的石头1061 小时前
《Kingbase护城河》——数据库卡顿急救手册:会话状态深度解析与“僵尸进程”排查实战
数据库
财经资讯数据_灵砚智能2 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年6月9日
人工智能·python·ai·信息可视化·自然语言处理·ai编程·灵砚智能