以下是完整的可执行代码：
1. 项目结构
plain
travel-saga/
├── docker-compose.yml
├── init-db.sql
├── shared/
│ ├── __init__.py
│ ├── models.py
│ ├── events.py
│ ├── saga_state.py
│ ├── message_bus.py
│ └── utils.py
├── services/
│ ├── flight_service/
│ ├── hotel_service/
│ ├── car_service/
│ ├── payment_service/
│ └── saga_orchestrator/
└── tests/
└── chaos_test.py
2. 共享基础设施层
shared/models.py
Python
"""
共享数据模型与领域对象
"""
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum, auto
from typing import Optional, Dict, Any, List
import json
import uuid
class SagaStatus(Enum):
    """Lifecycle status of a saga execution (member values equal their names)."""
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    COMPENSATING = "COMPENSATING"
    COMPENSATED = "COMPENSATED"
    FAILED = "FAILED"
    SUSPENDED = "SUSPENDED"  # Suspended ("hanging") saga; requires manual intervention
    TIMEOUT = "TIMEOUT"
class StepStatus(Enum):
    """Status of a single saga step, covering TCC phases and compensation."""
    PENDING = "PENDING"
    TRYING = "TRYING"  # TCC: Try phase in progress
    TRIED = "TRIED"  # TCC: Try phase succeeded
    CONFIRMING = "CONFIRMING"  # TCC: Confirm phase in progress
    CONFIRMED = "CONFIRMED"  # TCC: Confirm phase succeeded
    CANCELLING = "CANCELLING"  # TCC: Cancel phase in progress
    CANCELLED = "CANCELLED"  # TCC: Cancel phase succeeded
    FAILED = "FAILED"
    COMPENSATING = "COMPENSATING"
    COMPENSATED = "COMPENSATED"
    TIMEOUT = "TIMEOUT"
class ServiceType(Enum):
    """Downstream service targeted by a saga step (see ``SagaStep.service_type``)."""
    FLIGHT = "FLIGHT"
    HOTEL = "HOTEL"
    CAR = "CAR"
    PAYMENT = "PAYMENT"
@dataclass
class SagaContext:
    """Context object carried through an entire saga transaction.

    ``created_at`` and ``updated_at`` default to ``datetime.utcnow()``;
    ``metadata`` holds arbitrary extra data.
    """
    saga_id: str
    user_id: str
    correlation_id: str
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict with timestamps as ISO-8601 strings."""
        serialized: Dict[str, Any] = dict(
            saga_id=self.saga_id,
            user_id=self.user_id,
            correlation_id=self.correlation_id,
        )
        for stamp_name in ("created_at", "updated_at"):
            serialized[stamp_name] = getattr(self, stamp_name).isoformat()
        serialized["metadata"] = self.metadata
        return serialized
@dataclass
class FlightReservation:
    """A flight seat reservation."""
    flight_id: str
    seat_number: str
    passenger_name: str
    passenger_id: str
    departure_time: datetime
    arrival_time: datetime
    from_city: str
    to_city: str
    price: float
    lock_token: Optional[str] = None
    status: str = "PENDING"

    def to_dict(self):
        """Return all fields as a dict, with both times rendered as ISO-8601 strings."""
        data = asdict(self)
        # dict.update keeps the original key positions, so field order is preserved.
        data.update(
            departure_time=self.departure_time.isoformat(),
            arrival_time=self.arrival_time.isoformat(),
        )
        return data
@dataclass
class HotelReservation:
    """A hotel room reservation."""
    hotel_id: str
    room_type: str
    check_in: datetime
    check_out: datetime
    guest_name: str
    guest_count: int
    price_per_night: float
    total_nights: int
    lock_token: Optional[str] = None
    status: str = "PENDING"

    def to_dict(self):
        """Return all fields as a dict; check-in/check-out become ISO-8601 strings."""
        snapshot = asdict(self)
        for date_attr in ("check_in", "check_out"):
            snapshot[date_attr] = getattr(self, date_attr).isoformat()
        return snapshot
@dataclass
class CarReservation:
    """A rental car reservation."""
    car_id: str
    car_model: str
    pickup_location: str
    dropoff_location: str
    pickup_time: datetime
    dropoff_time: datetime
    driver_name: str
    driver_license: str
    daily_rate: float
    total_days: int
    lock_token: Optional[str] = None
    status: str = "PENDING"

    def to_dict(self):
        """Return all fields as a dict; pickup/dropoff times become ISO-8601 strings."""
        as_iso = {
            "pickup_time": self.pickup_time.isoformat(),
            "dropoff_time": self.dropoff_time.isoformat(),
        }
        snapshot = asdict(self)
        for key, text in as_iso.items():
            snapshot[key] = text
        return snapshot
@dataclass
class PaymentAuthorization:
    """A payment pre-authorization and its captured/refunded amounts."""
    payment_id: str
    amount: float
    currency: str = "CNY"
    card_token: str = ""
    auth_code: Optional[str] = None
    status: str = "PENDING"
    captured_amount: float = 0.0
    refunded_amount: float = 0.0

    def to_dict(self):
        """Return all fields as a plain dict (every field is already JSON-friendly)."""
        payload = asdict(self)
        return payload
@dataclass
class TravelBooking:
    """A complete travel booking request with optional flight, hotel, car and payment parts."""
    booking_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    user_id: str = ""
    flight: Optional[FlightReservation] = None
    hotel: Optional[HotelReservation] = None
    car: Optional[CarReservation] = None
    payment: Optional[PaymentAuthorization] = None
    total_amount: float = 0.0
    status: SagaStatus = SagaStatus.PENDING
    created_at: datetime = field(default_factory=datetime.utcnow)

    def to_dict(self):
        """Serialize the booking; absent (falsy) sub-reservations are emitted as ``None``."""
        def nested(part):
            return part.to_dict() if part else None

        body = {"booking_id": self.booking_id, "user_id": self.user_id}
        for slot in ("flight", "hotel", "car", "payment"):
            body[slot] = nested(getattr(self, slot))
        body["total_amount"] = self.total_amount
        body["status"] = self.status.value
        body["created_at"] = self.created_at.isoformat()
        return body
@dataclass
class SagaStep:
    """Definition and runtime state of a single saga step."""
    step_id: str
    step_number: int
    service_type: ServiceType
    operation: str  # "TRY", "CONFIRM", "CANCEL"
    status: StepStatus = StepStatus.PENDING
    request_payload: Dict[str, Any] = field(default_factory=dict)
    response_payload: Dict[str, Any] = field(default_factory=dict)
    error_message: Optional[str] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    retry_count: int = 0
    max_retries: int = 3
    timeout_seconds: int = 30
    compensation_payload: Optional[Dict[str, Any]] = None
    dependencies: List[str] = field(default_factory=list)  # IDs of steps this one depends on

    def to_dict(self):
        """Serialize the step; enums become their values, unset timestamps become ``None``."""
        def iso_or_none(moment):
            return moment.isoformat() if moment else None

        return dict(
            step_id=self.step_id,
            step_number=self.step_number,
            service_type=self.service_type.value,
            operation=self.operation,
            status=self.status.value,
            request_payload=self.request_payload,
            response_payload=self.response_payload,
            error_message=self.error_message,
            started_at=iso_or_none(self.started_at),
            completed_at=iso_or_none(self.completed_at),
            retry_count=self.retry_count,
            max_retries=self.max_retries,
            timeout_seconds=self.timeout_seconds,
            compensation_payload=self.compensation_payload,
            dependencies=self.dependencies,
        )
@dataclass
class SagaExecution:
    """A single run of a saga, including its steps and progress markers."""
    execution_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    saga_id: str = ""
    status: SagaStatus = SagaStatus.PENDING
    steps: List[SagaStep] = field(default_factory=list)
    current_step_index: int = 0
    context: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)
    completed_at: Optional[datetime] = None
    error_message: Optional[str] = None
    compensation_triggered_at: Optional[datetime] = None
    suspended_at: Optional[datetime] = None
    suspend_reason: Optional[str] = None

    def to_dict(self):
        """Serialize the execution and every step; unset optional timestamps become ``None``."""
        def stamp(value):
            return value.isoformat() if value else None

        serialized_steps = [step.to_dict() for step in self.steps]
        return dict(
            execution_id=self.execution_id,
            saga_id=self.saga_id,
            status=self.status.value,
            steps=serialized_steps,
            current_step_index=self.current_step_index,
            context=self.context,
            created_at=self.created_at.isoformat(),
            updated_at=self.updated_at.isoformat(),
            completed_at=stamp(self.completed_at),
            error_message=self.error_message,
            compensation_triggered_at=stamp(self.compensation_triggered_at),
            suspended_at=stamp(self.suspended_at),
            suspend_reason=self.suspend_reason,
        )
class IdempotencyKey:
    """Idempotency-control record for one operation key.

    ``created_at`` defaults to ``datetime.utcnow()`` when not supplied (or falsy);
    ``processed`` starts False and ``result`` starts as ``None``.
    """

    def __init__(self, key: str, operation: str, created_at: Optional[datetime] = None):
        self.key = key
        self.operation = operation
        self.created_at = created_at if created_at else datetime.utcnow()
        self.processed = False
        self.result: Optional[Dict[str, Any]] = None

    def to_dict(self):
        """Return the record as a dict with ``created_at`` as an ISO-8601 string."""
        snapshot = dict(key=self.key, operation=self.operation)
        snapshot["created_at"] = self.created_at.isoformat()
        snapshot.update(processed=self.processed, result=self.result)
        return snapshot
shared/events.py
Python
"""
事件定义 - 命令事件与领域事件
"""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Any, Optional
import uuid
class EventType:
    """Message type names: commands (``CMD_*``) and domain events (``EVT_*``)."""
    # Flight service commands
    CMD_LOCK_SEAT = "CMD_LOCK_SEAT"
    CMD_CONFIRM_SEAT = "CMD_CONFIRM_SEAT"
    CMD_CANCEL_SEAT = "CMD_CANCEL_SEAT"
    CMD_RELEASE_SEAT = "CMD_RELEASE_SEAT"
    # Hotel service commands
    CMD_LOCK_ROOM = "CMD_LOCK_ROOM"
    CMD_CONFIRM_ROOM = "CMD_CONFIRM_ROOM"
    CMD_CANCEL_ROOM = "CMD_CANCEL_ROOM"
    CMD_RELEASE_ROOM = "CMD_RELEASE_ROOM"
    # Car rental service commands
    CMD_LOCK_CAR = "CMD_LOCK_CAR"
    CMD_CONFIRM_CAR = "CMD_CONFIRM_CAR"
    CMD_CANCEL_CAR = "CMD_CANCEL_CAR"
    CMD_RELEASE_CAR = "CMD_RELEASE_CAR"
    # Payment service commands
    CMD_PREAUTH_PAYMENT = "CMD_PREAUTH_PAYMENT"
    CMD_CAPTURE_PAYMENT = "CMD_CAPTURE_PAYMENT"
    CMD_REFUND_PAYMENT = "CMD_REFUND_PAYMENT"
    CMD_CANCEL_PREAUTH = "CMD_CANCEL_PREAUTH"
    # Saga orchestrator commands
    CMD_START_SAGA = "CMD_START_SAGA"
    CMD_COMPENSATE_SAGA = "CMD_COMPENSATE_SAGA"
    CMD_RETRY_STEP = "CMD_RETRY_STEP"
    CMD_SUSPEND_SAGA = "CMD_SUSPEND_SAGA"
    CMD_RESUME_SAGA = "CMD_RESUME_SAGA"
    # Domain events
    EVT_SEAT_LOCKED = "EVT_SEAT_LOCKED"
    EVT_SEAT_CONFIRMED = "EVT_SEAT_CONFIRMED"
    EVT_SEAT_CANCELLED = "EVT_SEAT_CANCELLED"
    EVT_SEAT_LOCK_FAILED = "EVT_SEAT_LOCK_FAILED"
    EVT_ROOM_LOCKED = "EVT_ROOM_LOCKED"
    EVT_ROOM_CONFIRMED = "EVT_ROOM_CONFIRMED"
    EVT_ROOM_CANCELLED = "EVT_ROOM_CANCELLED"
    EVT_ROOM_LOCK_FAILED = "EVT_ROOM_LOCK_FAILED"
    EVT_CAR_LOCKED = "EVT_CAR_LOCKED"
    EVT_CAR_CONFIRMED = "EVT_CAR_CONFIRMED"
    EVT_CAR_CANCELLED = "EVT_CAR_CANCELLED"
    EVT_CAR_LOCK_FAILED = "EVT_CAR_LOCK_FAILED"
    EVT_PAYMENT_PREAUTHED = "EVT_PAYMENT_PREAUTHED"
    EVT_PAYMENT_CAPTURED = "EVT_PAYMENT_CAPTURED"
    EVT_PAYMENT_REFUNDED = "EVT_PAYMENT_REFUNDED"
    EVT_PAYMENT_PREAUTH_FAILED = "EVT_PAYMENT_PREAUTH_FAILED"
    # Saga lifecycle events
    EVT_SAGA_STARTED = "EVT_SAGA_STARTED"
    EVT_SAGA_STEP_COMPLETED = "EVT_SAGA_STEP_COMPLETED"
    EVT_SAGA_STEP_FAILED = "EVT_SAGA_STEP_FAILED"
    EVT_SAGA_COMPLETED = "EVT_SAGA_COMPLETED"
    EVT_SAGA_COMPENSATION_STARTED = "EVT_SAGA_COMPENSATION_STARTED"
    EVT_SAGA_COMPENSATED = "EVT_SAGA_COMPENSATED"
    EVT_SAGA_FAILED = "EVT_SAGA_FAILED"
    EVT_SAGA_SUSPENDED = "EVT_SAGA_SUSPENDED"
@dataclass
class DomainEvent:
    """Base class for domain events.

    ``partition_key`` is intended to pick the message partition so that
    related events keep their relative order.
    """
    event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    event_type: str = ""
    saga_id: str = ""
    step_id: str = ""
    correlation_id: str = ""
    timestamp: datetime = field(default_factory=datetime.utcnow)
    payload: Dict[str, Any] = field(default_factory=dict)
    partition_key: str = ""  # message partition key (ordering)

    def to_dict(self):
        """Serialize the event; ``timestamp`` becomes an ISO-8601 string."""
        serialized = {
            attr: getattr(self, attr)
            for attr in ("event_id", "event_type", "saga_id", "step_id", "correlation_id")
        }
        serialized["timestamp"] = self.timestamp.isoformat()
        serialized["payload"] = self.payload
        serialized["partition_key"] = self.partition_key
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'DomainEvent':
        """Rebuild an event from ``to_dict`` output.

        ``event_type`` is required (KeyError otherwise). A missing ``event_id``
        gets a fresh UUID and a missing ``timestamp`` defaults to ``utcnow()``;
        unknown keys are ignored.
        """
        event_id = data.get("event_id", str(uuid.uuid4()))
        event_type = data["event_type"]
        saga_id = data.get("saga_id", "")
        step_id = data.get("step_id", "")
        correlation_id = data.get("correlation_id", "")
        if "timestamp" in data:
            when = datetime.fromisoformat(data["timestamp"])
        else:
            when = datetime.utcnow()
        return cls(
            event_id=event_id,
            event_type=event_type,
            saga_id=saga_id,
            step_id=step_id,
            correlation_id=correlation_id,
            timestamp=when,
            payload=data.get("payload", {}),
            partition_key=data.get("partition_key", ""),
        )
@dataclass
class CommandMessage:
    """Command message sent to a target service."""
    command_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    command_type: str = ""
    saga_id: str = ""
    step_id: str = ""
    correlation_id: str = ""
    idempotency_key: str = ""
    target_service: str = ""
    payload: Dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.utcnow)
    ttl_seconds: int = 300  # message time-to-live in seconds
    priority: int = 5  # 1-10; a smaller number means higher priority

    def to_dict(self):
        """Serialize the command; ``timestamp`` becomes an ISO-8601 string."""
        out = {
            attr: getattr(self, attr)
            for attr in (
                "command_id", "command_type", "saga_id", "step_id",
                "correlation_id", "idempotency_key", "target_service", "payload",
            )
        }
        out["timestamp"] = self.timestamp.isoformat()
        out["ttl_seconds"] = self.ttl_seconds
        out["priority"] = self.priority
        return out

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CommandMessage':
        """Rebuild a command from ``to_dict`` output.

        ``command_type`` is required (KeyError otherwise); other fields fall
        back to the dataclass defaults, and a missing ``command_id`` gets a
        fresh UUID.
        """
        kwargs: Dict[str, Any] = {
            "command_id": data.get("command_id", str(uuid.uuid4())),
            "command_type": data["command_type"],
        }
        for name in ("saga_id", "step_id", "correlation_id", "idempotency_key", "target_service"):
            kwargs[name] = data.get(name, "")
        kwargs["payload"] = data.get("payload", {})
        if "timestamp" in data:
            kwargs["timestamp"] = datetime.fromisoformat(data["timestamp"])
        else:
            kwargs["timestamp"] = datetime.utcnow()
        kwargs["ttl_seconds"] = data.get("ttl_seconds", 300)
        kwargs["priority"] = data.get("priority", 5)
        return cls(**kwargs)
@dataclass
class ResponseMessage:
    """Response to a command, correlated by ``command_id``."""
    response_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    command_id: str = ""
    saga_id: str = ""
    step_id: str = ""
    correlation_id: str = ""
    success: bool = False
    result: Dict[str, Any] = field(default_factory=dict)
    error_code: Optional[str] = None
    error_message: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.utcnow)
    processing_time_ms: int = 0

    def to_dict(self):
        """Serialize the response; ``timestamp`` becomes an ISO-8601 string."""
        leading = (
            "response_id", "command_id", "saga_id", "step_id", "correlation_id",
            "success", "result", "error_code", "error_message",
        )
        out = {name: getattr(self, name) for name in leading}
        out["timestamp"] = self.timestamp.isoformat()
        out["processing_time_ms"] = self.processing_time_ms
        return out
shared/saga_state.py
Python
"""
Saga状态持久化 - 数据库记录与恢复
"""
import json
import logging
from datetime import datetime
from typing import Optional, Dict, Any, List
import sqlite3
import threading
from contextlib import contextmanager
logger = logging.getLogger(__name__)
class SagaStateStore:
    """SQLite-backed persistence for saga state, idempotency and audit data.

    Stores saga executions and steps, idempotency keys, an event log,
    suspended-transaction alerts and metrics, and offers the queries used for
    recovery (unfinished executions) and monitoring.

    Each thread lazily opens its own ``sqlite3`` connection (kept in a
    ``threading.local``). Every public method runs inside a
    ``BEGIN IMMEDIATE`` transaction; on failure it logs the error and returns
    ``False`` / ``None`` / an empty container instead of raising.
    """

    # Indexes to create, as (table, column). SQLite rejects MySQL-style inline
    # "INDEX name (col)" clauses inside CREATE TABLE, and index names are
    # global to the database, so indexes are created separately with
    # table-qualified names.
    _INDEXES = (
        ("saga_executions", "saga_id"),
        ("saga_executions", "status"),
        ("saga_executions", "user_id"),
        ("saga_executions", "created_at"),
        ("saga_steps", "execution_id"),
        ("saga_steps", "saga_id"),
        ("idempotency_keys", "operation"),
        ("idempotency_keys", "created_at"),
        ("event_log", "saga_id"),
        ("event_log", "event_type"),
        ("event_log", "timestamp"),
        ("suspended_transactions", "saga_id"),
        ("suspended_transactions", "alert_sent"),
        ("suspended_transactions", "suspended_at"),
        ("saga_metrics", "timestamp"),
        ("saga_metrics", "event_type"),
    )

    def __init__(self, db_path: str = "saga_state.db"):
        """Open (or create) the database at *db_path* and ensure the schema exists.

        Args:
            db_path: SQLite database path. ``":memory:"`` yields a private
                database per connection, i.e. per thread here.
        """
        self.db_path = db_path
        self._local = threading.local()
        self._init_db()

    def _get_connection(self) -> sqlite3.Connection:
        """Return this thread's connection, creating it on first use."""
        conn = getattr(self._local, "connection", None)
        if conn is None:
            # isolation_level=None turns off the sqlite3 module's implicit
            # BEGIN, so transaction boundaries are controlled only by
            # _transaction() below.
            conn = sqlite3.connect(self.db_path, check_same_thread=False, isolation_level=None)
            conn.row_factory = sqlite3.Row
            self._local.connection = conn
        return conn

    @contextmanager
    def _transaction(self):
        """Run the ``with`` body in a BEGIN IMMEDIATE transaction.

        Commits on success; rolls back and re-raises on any exception
        (including a failing COMMIT), so the connection is never left inside
        an open transaction.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")
            yield conn
            conn.commit()
        except BaseException:
            conn.rollback()
            raise

    def _init_db(self):
        """Create all tables and indexes if they do not exist yet."""
        with self._transaction() as conn:
            # Saga execution records
            conn.execute("""
                CREATE TABLE IF NOT EXISTS saga_executions (
                    execution_id TEXT PRIMARY KEY,
                    saga_id TEXT NOT NULL,
                    status TEXT NOT NULL,
                    steps_json TEXT NOT NULL,
                    current_step_index INTEGER DEFAULT 0,
                    context_json TEXT DEFAULT '{}',
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    completed_at TEXT,
                    error_message TEXT,
                    compensation_triggered_at TEXT,
                    suspended_at TEXT,
                    suspend_reason TEXT,
                    user_id TEXT,
                    total_amount REAL DEFAULT 0
                )
            """)
            # Saga step execution records
            conn.execute("""
                CREATE TABLE IF NOT EXISTS saga_steps (
                    step_id TEXT PRIMARY KEY,
                    execution_id TEXT NOT NULL,
                    saga_id TEXT NOT NULL,
                    step_number INTEGER NOT NULL,
                    service_type TEXT NOT NULL,
                    operation TEXT NOT NULL,
                    status TEXT NOT NULL,
                    request_payload TEXT,
                    response_payload TEXT,
                    error_message TEXT,
                    started_at TEXT,
                    completed_at TEXT,
                    retry_count INTEGER DEFAULT 0,
                    max_retries INTEGER DEFAULT 3,
                    timeout_seconds INTEGER DEFAULT 30,
                    compensation_payload TEXT,
                    dependencies TEXT,
                    FOREIGN KEY (execution_id) REFERENCES saga_executions(execution_id)
                )
            """)
            # Idempotency control
            conn.execute("""
                CREATE TABLE IF NOT EXISTS idempotency_keys (
                    key TEXT PRIMARY KEY,
                    operation TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    processed INTEGER DEFAULT 0,
                    result_json TEXT,
                    saga_id TEXT
                )
            """)
            # Event log (audit and recovery)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS event_log (
                    event_id TEXT PRIMARY KEY,
                    event_type TEXT NOT NULL,
                    saga_id TEXT NOT NULL,
                    step_id TEXT,
                    execution_id TEXT,
                    payload TEXT,
                    timestamp TEXT NOT NULL,
                    partition_key TEXT
                )
            """)
            # Suspended-transaction alerts
            conn.execute("""
                CREATE TABLE IF NOT EXISTS suspended_transactions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    saga_id TEXT NOT NULL,
                    execution_id TEXT NOT NULL,
                    suspended_at TEXT NOT NULL,
                    reason TEXT NOT NULL,
                    alert_sent INTEGER DEFAULT 0,
                    resolved INTEGER DEFAULT 0,
                    resolved_at TEXT,
                    resolved_by TEXT,
                    resolution_notes TEXT
                )
            """)
            # Monitoring metrics
            conn.execute("""
                CREATE TABLE IF NOT EXISTS saga_metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    saga_id TEXT,
                    event_type TEXT NOT NULL,
                    service_type TEXT,
                    step_number INTEGER,
                    success INTEGER,
                    processing_time_ms INTEGER,
                    timestamp TEXT NOT NULL
                )
            """)
            # Table/column names come from the constant _INDEXES tuple above,
            # never from user input, so interpolation is safe here.
            for table, column in self._INDEXES:
                conn.execute(
                    f"CREATE INDEX IF NOT EXISTS idx_{table}_{column} ON {table} ({column})"
                )

    def save_execution(self, execution: Dict[str, Any]) -> bool:
        """Insert or replace a saga execution record.

        Args:
            execution: Dict with at least ``execution_id``, ``saga_id``,
                ``status``, ``created_at`` and ``updated_at``. ``user_id`` and
                ``total_amount`` are copied out of ``execution["context"]``
                into their own columns.

        Returns:
            True on success, False if the write failed (the error is logged).
        """
        try:
            context = execution.get("context") or {}
            with self._transaction() as conn:
                conn.execute("""
                    INSERT OR REPLACE INTO saga_executions (
                        execution_id, saga_id, status, steps_json, current_step_index,
                        context_json, created_at, updated_at, completed_at, error_message,
                        compensation_triggered_at, suspended_at, suspend_reason, user_id, total_amount
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    execution["execution_id"],
                    execution["saga_id"],
                    execution["status"],
                    json.dumps(execution.get("steps", [])),
                    execution.get("current_step_index", 0),
                    json.dumps(context),
                    execution["created_at"],
                    execution["updated_at"],
                    execution.get("completed_at"),
                    execution.get("error_message"),
                    execution.get("compensation_triggered_at"),
                    execution.get("suspended_at"),
                    execution.get("suspend_reason"),
                    context.get("user_id"),
                    context.get("total_amount", 0),
                ))
            return True
        except Exception as e:
            logger.error(f"Failed to save execution: {e}")
            return False

    def load_execution(self, execution_id: str) -> Optional[Dict[str, Any]]:
        """Load one execution by id; None if it is missing or the read fails."""
        try:
            with self._transaction() as conn:
                row = conn.execute(
                    "SELECT * FROM saga_executions WHERE execution_id = ?",
                    (execution_id,),
                ).fetchone()
            return self._row_to_dict(row) if row else None
        except Exception as e:
            logger.error(f"Failed to load execution: {e}")
            return None

    def load_unfinished_executions(self) -> List[Dict[str, Any]]:
        """Return executions not in a terminal state, oldest first (for recovery)."""
        try:
            with self._transaction() as conn:
                rows = conn.execute("""
                    SELECT * FROM saga_executions
                    WHERE status NOT IN ('COMPLETED', 'COMPENSATED', 'FAILED')
                    ORDER BY created_at ASC
                """).fetchall()
            return [self._row_to_dict(row) for row in rows]
        except Exception as e:
            logger.error(f"Failed to load unfinished executions: {e}")
            return []

    def _row_to_dict(self, row) -> Dict[str, Any]:
        """Convert a ``saga_executions`` row into the execution dict shape."""
        return {
            "execution_id": row["execution_id"],
            "saga_id": row["saga_id"],
            "status": row["status"],
            "steps": json.loads(row["steps_json"]),
            "current_step_index": row["current_step_index"],
            "context": json.loads(row["context_json"]),
            "created_at": row["created_at"],
            "updated_at": row["updated_at"],
            "completed_at": row["completed_at"],
            "error_message": row["error_message"],
            "compensation_triggered_at": row["compensation_triggered_at"],
            "suspended_at": row["suspended_at"],
            "suspend_reason": row["suspend_reason"],
        }

    def save_step(self, step: Dict[str, Any]) -> bool:
        """Insert or replace a saga step record; True on success, False on error."""
        try:
            with self._transaction() as conn:
                conn.execute("""
                    INSERT OR REPLACE INTO saga_steps (
                        step_id, execution_id, saga_id, step_number, service_type,
                        operation, status, request_payload, response_payload, error_message,
                        started_at, completed_at, retry_count, max_retries, timeout_seconds,
                        compensation_payload, dependencies
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    step["step_id"],
                    step.get("execution_id", ""),
                    step.get("saga_id", ""),
                    step["step_number"],
                    step["service_type"],
                    step["operation"],
                    step["status"],
                    json.dumps(step.get("request_payload", {})),
                    json.dumps(step.get("response_payload", {})),
                    step.get("error_message"),
                    step.get("started_at"),
                    step.get("completed_at"),
                    step.get("retry_count", 0),
                    step.get("max_retries", 3),
                    step.get("timeout_seconds", 30),
                    json.dumps(step.get("compensation_payload")) if step.get("compensation_payload") else None,
                    json.dumps(step.get("dependencies", [])),
                ))
            return True
        except Exception as e:
            logger.error(f"Failed to save step: {e}")
            return False

    def check_idempotency(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the stored idempotency record for *key*, or None if absent/on error."""
        try:
            with self._transaction() as conn:
                row = conn.execute(
                    "SELECT * FROM idempotency_keys WHERE key = ?",
                    (key,),
                ).fetchone()
            if not row:
                return None
            return {
                "key": row["key"],
                "operation": row["operation"],
                "created_at": row["created_at"],
                "processed": bool(row["processed"]),
                "result": json.loads(row["result_json"]) if row["result_json"] else None,
                "saga_id": row["saga_id"],
            }
        except Exception as e:
            logger.error(f"Failed to check idempotency: {e}")
            return None

    def save_idempotency_key(self, key: str, operation: str, saga_id: str = "") -> bool:
        """Register *key* (unprocessed). An existing key is left untouched (INSERT OR IGNORE)."""
        try:
            with self._transaction() as conn:
                conn.execute("""
                    INSERT OR IGNORE INTO idempotency_keys (key, operation, created_at, saga_id)
                    VALUES (?, ?, ?, ?)
                """, (key, operation, datetime.utcnow().isoformat(), saga_id))
            return True
        except Exception as e:
            logger.error(f"Failed to save idempotency key: {e}")
            return False

    def mark_idempotency_processed(self, key: str, result: Dict[str, Any]) -> bool:
        """Mark *key* processed and store *result* as JSON; True on success."""
        try:
            with self._transaction() as conn:
                conn.execute("""
                    UPDATE idempotency_keys
                    SET processed = 1, result_json = ?
                    WHERE key = ?
                """, (json.dumps(result), key))
            return True
        except Exception as e:
            logger.error(f"Failed to mark idempotency processed: {e}")
            return False

    def log_event(self, event: Dict[str, Any]) -> bool:
        """Append an event to the audit log; True on success.

        A missing/empty ``event_id`` gets a fresh UUID (an empty-string default
        would collide on the PRIMARY KEY after the first such event), and a
        missing/None ``timestamp`` defaults to now (UTC).
        """
        import uuid  # local import: only needed for the event_id fallback

        try:
            with self._transaction() as conn:
                conn.execute("""
                    INSERT INTO event_log (event_id, event_type, saga_id, step_id, execution_id, payload, timestamp, partition_key)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    event.get("event_id") or str(uuid.uuid4()),
                    event["event_type"],
                    event.get("saga_id", ""),
                    event.get("step_id", ""),
                    event.get("execution_id", ""),
                    json.dumps(event.get("payload", {})),
                    event.get("timestamp") or datetime.utcnow().isoformat(),
                    event.get("partition_key", ""),
                ))
            return True
        except Exception as e:
            logger.error(f"Failed to log event: {e}")
            return False

    def record_metric(self, metric: Dict[str, Any]) -> bool:
        """Store one metric row; ``timestamp`` defaults to now (UTC, ISO-8601)."""
        try:
            with self._transaction() as conn:
                conn.execute("""
                    INSERT INTO saga_metrics (saga_id, event_type, service_type, step_number, success, processing_time_ms, timestamp)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (
                    metric.get("saga_id"),
                    metric["event_type"],
                    metric.get("service_type"),
                    metric.get("step_number"),
                    1 if metric.get("success") else 0,
                    metric.get("processing_time_ms", 0),
                    metric.get("timestamp") or datetime.utcnow().isoformat(),
                ))
            return True
        except Exception as e:
            logger.error(f"Failed to record metric: {e}")
            return False

    def create_suspension_alert(self, saga_id: str, execution_id: str, reason: str) -> bool:
        """Record a suspended (hanging) transaction alert; True on success."""
        try:
            with self._transaction() as conn:
                conn.execute("""
                    INSERT INTO suspended_transactions (saga_id, execution_id, suspended_at, reason)
                    VALUES (?, ?, ?, ?)
                """, (saga_id, execution_id, datetime.utcnow().isoformat(), reason))
            return True
        except Exception as e:
            logger.error(f"Failed to create suspension alert: {e}")
            return False

    def get_suspended_transactions(self, alert_sent_only: bool = False) -> List[Dict[str, Any]]:
        """List unresolved suspended transactions, oldest first.

        NOTE(review): despite its name, ``alert_sent_only=True`` selects rows
        whose alert has NOT been sent yet (``alert_sent = 0``). Behavior is kept
        for existing callers; confirm the intent before renaming.
        """
        try:
            with self._transaction() as conn:
                if alert_sent_only:
                    cursor = conn.execute("""
                        SELECT * FROM suspended_transactions
                        WHERE alert_sent = 0 AND resolved = 0
                        ORDER BY suspended_at ASC
                    """)
                else:
                    cursor = conn.execute("""
                        SELECT * FROM suspended_transactions
                        WHERE resolved = 0
                        ORDER BY suspended_at ASC
                    """)
                rows = cursor.fetchall()
            return [{
                "id": row["id"],
                "saga_id": row["saga_id"],
                "execution_id": row["execution_id"],
                "suspended_at": row["suspended_at"],
                "reason": row["reason"],
                "alert_sent": bool(row["alert_sent"]),
                "resolved": bool(row["resolved"]),
            } for row in rows]
        except Exception as e:
            logger.error(f"Failed to get suspended transactions: {e}")
            return []

    def resolve_suspension(self, saga_id: str, resolved_by: str, notes: str) -> bool:
        """Manually resolve a suspended saga.

        Marks the saga's open suspension rows resolved and moves any SUSPENDED
        execution of that saga to FAILED, both in one transaction.
        """
        try:
            now = datetime.utcnow().isoformat()
            with self._transaction() as conn:
                conn.execute("""
                    UPDATE suspended_transactions
                    SET resolved = 1, resolved_at = ?, resolved_by = ?, resolution_notes = ?
                    WHERE saga_id = ? AND resolved = 0
                """, (now, resolved_by, notes, saga_id))
                conn.execute("""
                    UPDATE saga_executions
                    SET status = 'FAILED', updated_at = ?, error_message = ?
                    WHERE saga_id = ? AND status = 'SUSPENDED'
                """, (now, f"Manually resolved by {resolved_by}: {notes}", saga_id))
            return True
        except Exception as e:
            logger.error(f"Failed to resolve suspension: {e}")
            return False

    def get_metrics_summary(self, hours: int = 24) -> Dict[str, Any]:
        """Summarize metrics recorded within the last *hours* hours.

        The cut-off is built in Python as a naive-UTC ISO-8601 string, the same
        format ``record_metric`` writes, so the textual ``>`` comparison is
        correct. SQLite's ``datetime('now', ...)`` uses a space instead of
        ``'T'``, which sorts lower and would admit same-day rows outside the
        window. Assumes callers of ``record_metric`` pass timestamps in that
        same format.

        Returns:
            Dict with total/success/failure counts, average processing time,
            per-event-type counts and the period; ``{}`` on error.
        """
        from datetime import timedelta  # local import: only needed here

        try:
            cutoff = (datetime.utcnow() - timedelta(hours=hours)).isoformat()
            with self._transaction() as conn:
                totals = conn.execute("""
                    SELECT
                        COUNT(*) AS total,
                        SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) AS success_count,
                        AVG(processing_time_ms) AS avg_time
                    FROM saga_metrics
                    WHERE timestamp > ?
                """, (cutoff,)).fetchone()
                breakdown_rows = conn.execute("""
                    SELECT event_type, COUNT(*) AS count
                    FROM saga_metrics
                    WHERE timestamp > ?
                    GROUP BY event_type
                """, (cutoff,)).fetchall()
            event_counts = {r["event_type"]: r["count"] for r in breakdown_rows}
            total = totals["total"] or 0
            success_count = totals["success_count"] or 0
            return {
                "total_events": total,
                "success_count": success_count,
                "failure_count": total - success_count,
                "avg_processing_time_ms": totals["avg_time"] or 0,
                "event_breakdown": event_counts,
                "period_hours": hours,
            }
        except Exception as e:
            logger.error(f"Failed to get metrics summary: {e}")
            return {}
shared/message_bus.py
Python
"""
消息基础设施 - 至少一次交付、去重消费、顺序保证
"""
import json
import logging
import queue
import threading
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Any, List, Callable, Optional, Set
import uuid
logger = logging.getLogger(__name__)
class PartitionedMessageQueue:
    """In-process partitioned message queue.

    Messages are routed to one of ``partition_count`` FIFO queues by their
    partition key, and each partition is drained by a single consumer thread,
    so messages sharing a key reach subscribers one at a time in publish
    order. Failed messages are requeued to the back of their partition, so a
    retry can overtake ordering.

    Delivery is at-least-once with in-memory de-duplication. A message id is
    recorded as delivered only after at least one handler succeeded, so a
    failed message can be retried up to ``_max_delivery_attempts`` times.
    """

    # Max remembered delivered ids before trimming, and how many of the most
    # recent ones a trim keeps.
    _DEDUP_CAPACITY = 10000
    _DEDUP_KEEP = 5000

    def __init__(self, partition_count: int = 16):
        self.partition_count = partition_count
        self._queues: Dict[int, queue.Queue] = {
            i: queue.Queue() for i in range(partition_count)
        }
        self._lock = threading.Lock()
        # Insertion-ordered dict used as an ordered set of delivered message
        # ids, so trimming can keep the most recent ones (a set has no order).
        self._delivered: Dict[str, None] = {}
        self._delivered_lock = threading.Lock()  # also guards _delivery_attempts
        self._subscribers: Dict[str, List[Callable]] = defaultdict(list)
        self._subscriber_lock = threading.Lock()
        self._running = False
        self._threads: List[threading.Thread] = []
        self._delivery_attempts: Dict[str, int] = {}  # failed attempts per pending message
        self._max_delivery_attempts = 3

    def _get_partition(self, partition_key: str) -> int:
        """Map a partition key to a partition index (empty key -> partition 0).

        Uses the built-in ``hash``; Python randomizes ``str`` hashing per
        process, so the mapping is only stable within one process.
        """
        if not partition_key:
            return 0
        return hash(partition_key) % self.partition_count

    def publish(self, message: Dict[str, Any], partition_key: str = "") -> bool:
        """Enqueue *message* on the partition selected by *partition_key*.

        The message dict is annotated in place with ``_msg_id`` (taken from
        ``event_id``/``command_id`` or a fresh UUID), ``_partition_key`` and
        ``_publish_time``. Returns True on success, False on error.
        """
        try:
            msg_id = message.get("event_id") or message.get("command_id") or str(uuid.uuid4())
            message["_msg_id"] = msg_id
            message["_partition_key"] = partition_key
            message["_publish_time"] = datetime.utcnow().isoformat()
            partition = self._get_partition(partition_key)
            self._queues[partition].put(message)
            logger.debug(f"Published message {msg_id} to partition {partition}")
            return True
        except Exception as e:
            logger.error(f"Failed to publish message: {e}")
            return False

    def subscribe(self, event_type: str, handler: Callable[[Dict[str, Any]], None]):
        """Register *handler* for messages whose ``event_type``/``command_type`` matches."""
        with self._subscriber_lock:
            self._subscribers[event_type].append(handler)
        logger.info(f"Subscribed handler to event type: {event_type}")

    def start_consuming(self):
        """Start one daemon consumer thread per partition (no-op if already running).

        A second set of consumers would let two threads drain the same
        partition and break per-key ordering, hence the guard.
        """
        if self._running:
            return
        self._running = True
        for i in range(self.partition_count):
            thread = threading.Thread(
                target=self._consume_partition,
                args=(i,),
                name=f"PartitionConsumer-{i}",
                daemon=True,
            )
            thread.start()
            self._threads.append(thread)
        logger.info(f"Started {self.partition_count} partition consumers")

    def stop_consuming(self):
        """Signal consumers to stop and wait up to 5 s for each thread."""
        self._running = False
        for thread in self._threads:
            thread.join(timeout=5)
        self._threads.clear()
        logger.info("Stopped all partition consumers")

    def _consume_partition(self, partition: int):
        """Consumer loop: poll one partition (1 s timeout) while running."""
        q = self._queues[partition]
        while self._running:
            try:
                message = q.get(timeout=1)
                self._process_message(message)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Error consuming partition {partition}: {e}")

    def _mark_delivered(self, msg_id: str):
        """Remember *msg_id* as delivered; caller must hold ``_delivered_lock``."""
        self._delivered[msg_id] = None
        if len(self._delivered) > self._DEDUP_CAPACITY:
            recent = list(self._delivered)[-self._DEDUP_KEEP:]
            self._delivered = dict.fromkeys(recent)

    def _process_message(self, message: Dict[str, Any]):
        """Dispatch one message: skip duplicates, call handlers, retry on failure.

        The message counts as handled if at least one handler returns without
        raising (or is dropped as a duplicate). Otherwise it is requeued after
        a linear back-off (0.1 s x attempt) until ``_max_delivery_attempts``
        failed attempts have been made.
        """
        msg_id = message.get("_msg_id", "")

        with self._delivered_lock:
            if msg_id in self._delivered:
                logger.debug(f"Message {msg_id} already delivered, skipping")
                return

        event_type = message.get("event_type") or message.get("command_type", "")
        # Copy under the lock so a concurrent subscribe() cannot mutate the
        # list while we iterate it.
        with self._subscriber_lock:
            handlers = list(self._subscribers.get(event_type, []))

        success = False
        for handler in handlers:
            try:
                handler(message)
                success = True
            except Exception as e:
                logger.error(f"Handler failed for message {msg_id}: {e}")

        with self._delivered_lock:
            if success:
                # Only now is the id recorded as delivered; marking it before
                # handling would make every requeued retry look like a
                # duplicate and silently drop it.
                self._mark_delivered(msg_id)
                self._delivery_attempts.pop(msg_id, None)
                return
            attempts = self._delivery_attempts.get(msg_id, 0) + 1
            if attempts < self._max_delivery_attempts:
                self._delivery_attempts[msg_id] = attempts
            else:
                # Give up; drop the counter so it does not accumulate forever.
                self._delivery_attempts.pop(msg_id, None)

        if attempts < self._max_delivery_attempts:
            partition = self._get_partition(message.get("_partition_key", ""))
            # Delayed retry: note this sleep blocks this partition's consumer.
            time.sleep(0.1 * attempts)
            self._queues[partition].put(message)
            logger.warning(f"Requeued message {msg_id} for retry {attempts}")
        else:
            logger.error(f"Message {msg_id} failed after {self._max_delivery_attempts} attempts")

    def get_queue_depth(self) -> Dict[int, int]:
        """Return the approximate number of queued messages per partition."""
        return {i: q.qsize() for i, q in self._queues.items()}
class InMemoryMessageBus:
    """Process-wide singleton in-memory message bus.

    Bundles three partitioned queues (commands: 8 partitions, events: 16,
    responses: 4) whose consumers start on first construction, and offers a
    synchronous request/response helper keyed by ``command_id``.
    """
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked creation of the single shared instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every InMemoryMessageBus() call; the lock makes sure
        # only the first caller builds the queues and starts consumer threads,
        # even when several threads construct the bus concurrently.
        with self._lock:
            if self._initialized:
                return
            self.command_queue = PartitionedMessageQueue(partition_count=8)
            self.event_queue = PartitionedMessageQueue(partition_count=16)
            self.response_queue = PartitionedMessageQueue(partition_count=4)

            # Response waiters used for synchronous request/response.
            self._pending_responses: Dict[str, threading.Event] = {}
            self._response_data: Dict[str, Dict[str, Any]] = {}
            self._response_lock = threading.Lock()

            # Subscribe before consumers start so no early response is missed.
            # NOTE(review): dispatch matches on the message's event_type /
            # command_type, so responses presumably carry
            # event_type="RESPONSE" when published; verify against publishers.
            self.response_queue.subscribe("RESPONSE", self._handle_response)

            self.command_queue.start_consuming()
            self.event_queue.start_consuming()
            self.response_queue.start_consuming()
            self._initialized = True

    def _handle_response(self, message: Dict[str, Any]):
        """Store a response by ``command_id`` and wake a waiter, if any.

        The data is stored even without a waiter, so that a later
        ``wait_for_response`` can pick up a response that arrived early.
        """
        command_id = message.get("command_id", "")
        with self._response_lock:
            self._response_data[command_id] = message
            waiter = self._pending_responses.get(command_id)
        if waiter:
            waiter.set()

    def publish_command(self, command: Dict[str, Any], partition_key: str = "") -> bool:
        """Publish a command message; True on success."""
        return self.command_queue.publish(command, partition_key)

    def publish_event(self, event: Dict[str, Any], partition_key: str = "") -> bool:
        """Publish an event message; True on success."""
        return self.event_queue.publish(event, partition_key)

    def publish_response(self, response: Dict[str, Any], partition_key: str = "") -> bool:
        """Publish a response message; True on success."""
        return self.response_queue.publish(response, partition_key)

    def subscribe_command(self, command_type: str, handler: Callable[[Dict[str, Any]], None]):
        """Subscribe *handler* to a command type."""
        self.command_queue.subscribe(command_type, handler)

    def subscribe_event(self, event_type: str, handler: Callable[[Dict[str, Any]], None]):
        """Subscribe *handler* to an event type."""
        self.event_queue.subscribe(event_type, handler)

    def wait_for_response(self, command_id: str, timeout: float = 30.0) -> Optional[Dict[str, Any]]:
        """Block until the response for *command_id* arrives or *timeout* seconds pass.

        Returns the response message, or None on timeout. A response that
        arrived before this call is returned immediately, and the waiter entry
        is always removed afterwards (success or timeout).
        """
        with self._response_lock:
            if command_id in self._response_data:
                return self._response_data.pop(command_id)
            waiter = threading.Event()
            self._pending_responses[command_id] = waiter
        waiter.wait(timeout=timeout)
        with self._response_lock:
            self._pending_responses.pop(command_id, None)
            # Also catches a response that landed right at the timeout boundary.
            return self._response_data.pop(command_id, None)

    def get_stats(self) -> Dict[str, Any]:
        """Return queue depths per partition and the number of pending waiters."""
        with self._response_lock:
            pending = len(self._pending_responses)
        return {
            "command_queue_depth": self.command_queue.get_queue_depth(),
            "event_queue_depth": self.event_queue.get_queue_depth(),
            "response_queue_depth": self.response_queue.get_queue_depth(),
            "pending_responses": pending,
        }
shared/utils.py
Python
"""
工具函数
"""
import hashlib
import json
import logging
import time
import uuid
from datetime import datetime
from functools import wraps
from typing import Any, Callable, Dict, Optional
logger = logging.getLogger(__name__)
def generate_idempotency_key(*args) -> str:
    """Derive a deterministic 32-hex-character idempotency key from ``args``.

    Each argument is rendered with ``str()``, the parts are joined with
    ``"|"``, and the first 32 hex digits of the SHA-256 digest are returned.

    NOTE: because of the plain ``"|"`` separator, ``("a|b",)`` and
    ``("a", "b")`` produce the same key.
    """
    joined = "|".join(map(str, args))
    digest = hashlib.sha256(joined.encode()).hexdigest()
    return digest[:32]
def generate_uuid() -> str:
    """Return a fresh random (version 4) UUID in canonical string form."""
    new_id = uuid.uuid4()
    return str(new_id)
def now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string (no UTC offset suffix).

    The output format matches ``datetime.utcnow().isoformat()``, but the
    deprecated (Python 3.12+) ``utcnow`` is no longer called.
    """
    # Local import: only the ``datetime`` class is imported at module level.
    from datetime import timezone
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
def retry_with_backoff(max_retries: int = 3, base_delay: float = 0.1, max_delay: float = 5.0,
                       retry_on: tuple = (Exception,)):
    """Decorator factory: retry the wrapped call with exponential backoff.

    Args:
        max_retries: total number of attempts. Values below 1 are treated as
            1, so the function is always called at least once.
        base_delay: delay in seconds before the first retry; doubled each time.
        max_delay: upper bound for a single delay in seconds.
        retry_on: exception types that trigger a retry. Any other exception
            propagates immediately. Defaults to ``(Exception,)``.

    Raises:
        The last exception raised by the wrapped function once attempts run out.
    """
    # Previously max_retries <= 0 skipped the call entirely and ended in
    # ``raise None`` (a TypeError); always make at least one attempt.
    attempts = max(1, max_retries)

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except retry_on as e:
                    if attempt == attempts - 1:
                        raise  # out of attempts: re-raise with the original traceback
                    delay = min(base_delay * (2 ** attempt), max_delay)
                    logger.warning(f"Retry {attempt + 1}/{attempts} for {func.__name__} after {delay}s: {e}")
                    time.sleep(delay)
        return wrapper
    return decorator
def timed_execution(func: Callable) -> Callable:
    """Decorator: call ``func`` and return ``(result, elapsed_ms)``.

    NOTE: the decorated function's return value becomes a 2-tuple whose second
    item is the elapsed wall time in whole milliseconds. Exceptions propagate
    unchanged; no timing is reported for failed calls.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the measurement cannot go negative or
        # jump when the system clock is adjusted (time.time() can).
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = int((time.perf_counter() - start) * 1000)
        return result, elapsed
    return wrapper
class CircuitBreaker:
    """Circuit breaker with CLOSED -> OPEN -> HALF_OPEN -> CLOSED transitions.

    - CLOSED: calls pass through. Each failure increments ``failure_count``
      (a success resets it), and reaching ``failure_threshold`` opens the breaker.
    - OPEN: calls are rejected with ``Exception("Circuit breaker is OPEN")``
      until more than ``recovery_timeout`` seconds have passed since the last
      failure. The breaker then moves to HALF_OPEN.
    - HALF_OPEN: calls pass through. A success closes the breaker; a failure
      re-opens it, because ``failure_count`` is still at or above the threshold.
    """
    STATE_CLOSED = "CLOSED"  # normal operation
    STATE_OPEN = "OPEN"  # tripped: calls are rejected
    STATE_HALF_OPEN = "HALF_OPEN"  # probing whether the dependency recovered

    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 30.0):
        # Local import: this module has no top-level ``import threading``.
        import threading
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.state = self.STATE_CLOSED
        self.failure_count = 0
        self.last_failure_time = 0.0
        # A real mutex (previously a placeholder ``False``) so the
        # read-modify-write state transitions stay consistent if one breaker
        # is shared by several threads.
        self._lock = threading.Lock()

    def can_execute(self) -> bool:
        """Return True if a call may proceed; moves OPEN -> HALF_OPEN once the recovery timeout elapsed."""
        with self._lock:
            if self.state != self.STATE_OPEN:
                return True  # CLOSED or HALF_OPEN
            if time.time() - self.last_failure_time > self.recovery_timeout:
                self.state = self.STATE_HALF_OPEN
                return True
            return False

    def record_success(self):
        """Reset the failure count and close the breaker."""
        with self._lock:
            self.failure_count = 0
            self.state = self.STATE_CLOSED

    def record_failure(self):
        """Count a failure and open the breaker once the threshold is reached."""
        with self._lock:
            self.failure_count += 1
            self.last_failure_time = time.time()
            if self.failure_count >= self.failure_threshold:
                self.state = self.STATE_OPEN

    def call(self, func: Callable, *args, **kwargs):
        """Invoke ``func(*args, **kwargs)`` through the breaker.

        Raises:
            Exception: "Circuit breaker is OPEN" when the call is rejected.
            Any exception from ``func`` is recorded as a failure and re-raised.
        """
        if not self.can_execute():
            raise Exception("Circuit breaker is OPEN")
        try:
            result = func(*args, **kwargs)
        except Exception:
            self.record_failure()
            raise  # bare raise keeps the original traceback intact
        self.record_success()
        return result
class RateLimiter:
    """Per-key sliding-window rate limiter.

    A request for ``key`` is allowed if fewer than ``max_requests`` allowed
    requests for that key happened within the trailing ``window_seconds``.
    """
    def __init__(self, max_requests: int = 100, window_seconds: float = 60.0):
        # Local import: this module has no top-level ``import threading``.
        # Without it, creating the lock below raised NameError.
        import threading
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # key -> timestamps (time.time()) of allowed requests still in the window
        self.requests: Dict[str, list] = {}
        self._lock = threading.Lock()

    def is_allowed(self, key: str) -> bool:
        """Return True and record the request if ``key`` is under its limit, else return False.

        Rejected requests are not recorded.
        """
        with self._lock:
            now = time.time()
            # Drop timestamps that have slid out of the window.
            recent = [t for t in self.requests.get(key, []) if now - t < self.window_seconds]
            allowed = len(recent) < self.max_requests
            if allowed:
                recent.append(now)
            self.requests[key] = recent
            return allowed
3. 业务服务实现
services/flight_service/app.py
Python
#!/usr/bin/env python3
"""
航班服务 - 座位锁定与释放API
实现TCC模式:Try(锁定) -> Confirm(确认) -> Cancel(取消)
"""
import json
import logging
import sqlite3
import threading
import time
import uuid
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List
import sys
sys.path.insert(0, '../..')
from shared.models import FlightReservation, IdempotencyKey
from shared.events import EventType, DomainEvent, CommandMessage, ResponseMessage
from shared.message_bus import InMemoryMessageBus
from shared.utils import generate_idempotency_key, retry_with_backoff, timed_execution, CircuitBreaker
# Root logging configuration for the flight service process.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("FlightService")
class FlightDatabase:
    """SQLite persistence for the flight service.

    Stores flights, Try-phase seat locks (``seat_locks``), confirmed
    reservations and cached idempotent results. Every public method opens its
    own connection and closes it before returning. The TCC write paths start
    with ``BEGIN IMMEDIATE``, which takes SQLite's write lock up front, so each
    check-then-write sequence runs without interleaving with another writer.
    """

    # How long a Try-phase seat lock stays valid.
    LOCK_TTL = timedelta(minutes=10)
    # Seat letters within one row; seat labels are "<row><letter>", e.g. "12C".
    SEAT_LETTERS = "ABCDEF"

    def __init__(self, db_path: str = "flight_service.db"):
        self.db_path = db_path
        self._init_db()

    def _get_connection(self):
        """Open a new connection whose rows are ``sqlite3.Row`` (accessible by column name)."""
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        return conn

    def _init_db(self):
        """Create the schema if missing and seed sample flights into an empty ``flights`` table."""
        conn = self._get_connection()
        try:
            conn.executescript("""
                CREATE TABLE IF NOT EXISTS flights (
                    flight_id TEXT PRIMARY KEY,
                    airline TEXT NOT NULL,
                    from_city TEXT NOT NULL,
                    to_city TEXT NOT NULL,
                    departure_time TEXT NOT NULL,
                    arrival_time TEXT NOT NULL,
                    total_seats INTEGER NOT NULL,
                    available_seats INTEGER NOT NULL,
                    price REAL NOT NULL,
                    status TEXT DEFAULT 'ACTIVE'
                );
                CREATE TABLE IF NOT EXISTS seat_locks (
                    lock_id TEXT PRIMARY KEY,
                    flight_id TEXT NOT NULL,
                    seat_number TEXT NOT NULL,
                    saga_id TEXT NOT NULL,
                    lock_token TEXT NOT NULL,
                    locked_at TEXT NOT NULL,
                    expires_at TEXT NOT NULL,
                    status TEXT DEFAULT 'LOCKED',
                    passenger_name TEXT,
                    passenger_id TEXT,
                    FOREIGN KEY (flight_id) REFERENCES flights(flight_id)
                );
                CREATE TABLE IF NOT EXISTS reservations (
                    reservation_id TEXT PRIMARY KEY,
                    flight_id TEXT NOT NULL,
                    seat_number TEXT NOT NULL,
                    saga_id TEXT NOT NULL,
                    passenger_name TEXT NOT NULL,
                    passenger_id TEXT NOT NULL,
                    price REAL NOT NULL,
                    status TEXT DEFAULT 'CONFIRMED',
                    created_at TEXT NOT NULL,
                    cancelled_at TEXT,
                    refund_amount REAL DEFAULT 0
                );
                CREATE TABLE IF NOT EXISTS idempotency_store (
                    key TEXT PRIMARY KEY,
                    operation TEXT NOT NULL,
                    result TEXT,
                    created_at TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_flight_id ON seat_locks(flight_id);
                CREATE INDEX IF NOT EXISTS idx_saga_id ON seat_locks(saga_id);
                CREATE INDEX IF NOT EXISTS idx_lock_token ON seat_locks(lock_token);
            """)
            conn.commit()
            # Seed sample flights on first start only.
            cursor = conn.execute("SELECT COUNT(*) FROM flights")
            if cursor.fetchone()[0] == 0:
                self._seed_data(conn)
        finally:
            conn.close()

    def _seed_data(self, conn):
        """Insert four sample flights departing 1-3 days from now, all seats available."""
        now = datetime.now()
        flights = [
            ("CA1001", "Air China", "Beijing", "Shanghai",
             (now + timedelta(days=1)).isoformat(),
             (now + timedelta(days=1, hours=2)).isoformat(),
             200, 200, 1200.0),
            ("CA1002", "Air China", "Shanghai", "Beijing",
             (now + timedelta(days=2)).isoformat(),
             (now + timedelta(days=2, hours=2)).isoformat(),
             200, 200, 1100.0),
            ("MU5001", "China Eastern", "Beijing", "Guangzhou",
             (now + timedelta(days=1)).isoformat(),
             (now + timedelta(days=1, hours=3)).isoformat(),
             300, 300, 1500.0),
            ("CZ3001", "China Southern", "Shanghai", "Shenzhen",
             (now + timedelta(days=3)).isoformat(),
             (now + timedelta(days=3, hours=2, minutes=30)).isoformat(),
             250, 250, 1300.0),
        ]
        conn.executemany("""
            INSERT INTO flights (flight_id, airline, from_city, to_city, departure_time, arrival_time, total_seats, available_seats, price)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, flights)
        conn.commit()
        logger.info("Seeded flight data")

    def get_available_seats(self, flight_id: str, limit: int = 20) -> List[Dict[str, Any]]:
        """Return up to ``limit`` seats that are neither locked (unexpired) nor reserved.

        Seats are labelled row by row: ``1A``..``1F``, ``2A``.. up to the
        flight's ``total_seats``. An unknown flight yields ``[]``.
        """
        conn = self._get_connection()
        try:
            row = conn.execute(
                "SELECT total_seats FROM flights WHERE flight_id = ?", (flight_id,)
            ).fetchone()
            if not row:
                return []
            locked_seats = {r["seat_number"] for r in conn.execute("""
                SELECT seat_number FROM seat_locks
                WHERE flight_id = ? AND status = 'LOCKED' AND expires_at > ?
            """, (flight_id, datetime.now().isoformat()))}
            reserved_seats = {r["seat_number"] for r in conn.execute("""
                SELECT seat_number FROM reservations
                WHERE flight_id = ? AND status = 'CONFIRMED'
            """, (flight_id,))}
            taken = locked_seats | reserved_seats
            letters = self.SEAT_LETTERS
            per_row = len(letters)
            # Previously seat i was labelled f"{i+1}<letter>", which put a single
            # seat in every row ("1A", "2B", ...); label by row/letter instead.
            all_seats = [f"{i // per_row + 1}{letters[i % per_row]}" for i in range(row["total_seats"])]
            available = [s for s in all_seats if s not in taken]
            return [{"seat_number": s, "status": "AVAILABLE"} for s in available[:limit]]
        finally:
            conn.close()

    def lock_seat(self, flight_id: str, seat_number: str, saga_id: str,
                  passenger_name: str, passenger_id: str) -> Dict[str, Any]:
        """TCC Try: lock a seat for ``saga_id`` for ``LOCK_TTL``.

        Idempotent per saga: if the same saga already holds an unexpired lock
        on this seat, that lock is returned again and availability is not
        decremented a second time.

        Returns ``{"success": True, lock_id, lock_token, expires_at, flight_id,
        seat_number}`` or ``{"success": False, "error": ...}``.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")
            # The flight must exist and be active.
            flight = conn.execute(
                "SELECT * FROM flights WHERE flight_id = ? AND status = 'ACTIVE'", (flight_id,)
            ).fetchone()
            if not flight:
                conn.rollback()
                return {"success": False, "error": "Flight not found or inactive"}
            now = datetime.now()
            # An unexpired lock on this seat blocks other sagas.
            existing = conn.execute("""
                SELECT * FROM seat_locks
                WHERE flight_id = ? AND seat_number = ? AND status = 'LOCKED' AND expires_at > ?
            """, (flight_id, seat_number, now.isoformat())).fetchone()
            if existing:
                conn.rollback()
                if existing["saga_id"] == saga_id:
                    # Repeated Try from the same saga (e.g. a redelivered or
                    # retried command): hand back the lock it already holds.
                    return {
                        "success": True,
                        "lock_id": existing["lock_id"],
                        "lock_token": existing["lock_token"],
                        "expires_at": existing["expires_at"],
                        "flight_id": flight_id,
                        "seat_number": seat_number
                    }
                return {"success": False, "error": "Seat already locked"}
            reserved = conn.execute("""
                SELECT * FROM reservations
                WHERE flight_id = ? AND seat_number = ? AND status = 'CONFIRMED'
            """, (flight_id, seat_number)).fetchone()
            if reserved:
                conn.rollback()
                return {"success": False, "error": "Seat already reserved"}
            # Never drive available_seats below zero.
            if flight["available_seats"] <= 0:
                conn.rollback()
                return {"success": False, "error": "No seats available"}
            lock_token = str(uuid.uuid4())
            lock_id = str(uuid.uuid4())
            expires_at = (now + self.LOCK_TTL).isoformat()
            conn.execute("""
                INSERT INTO seat_locks (lock_id, flight_id, seat_number, saga_id, lock_token,
                                        locked_at, expires_at, status, passenger_name, passenger_id)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (lock_id, flight_id, seat_number, saga_id, lock_token,
                  now.isoformat(), expires_at, "LOCKED", passenger_name, passenger_id))
            # The Try phase already takes the seat out of availability.
            conn.execute("""
                UPDATE flights SET available_seats = available_seats - 1
                WHERE flight_id = ?
            """, (flight_id,))
            conn.commit()
            return {
                "success": True,
                "lock_id": lock_id,
                "lock_token": lock_token,
                "expires_at": expires_at,
                "flight_id": flight_id,
                "seat_number": seat_number
            }
        except Exception as e:
            conn.rollback()
            logger.error(f"Lock seat failed: {e}")
            return {"success": False, "error": str(e)}
        finally:
            conn.close()

    def confirm_seat(self, lock_token: str) -> Dict[str, Any]:
        """TCC Confirm: turn an unexpired LOCKED seat lock into a CONFIRMED reservation.

        Idempotent: confirming an already-confirmed lock returns the existing
        reservation with ``"message": "Already confirmed"``. This mirrors the
        way cancel_lock treats repeats. Unknown, cancelled or expired locks
        return an error.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")
            lock = conn.execute(
                "SELECT * FROM seat_locks WHERE lock_token = ?", (lock_token,)
            ).fetchone()
            if lock and lock["status"] == "CONFIRMED":
                # A redelivered Confirm must not report failure after success.
                # NOTE: reservations do not store the lock token, so the match
                # is by saga/flight/seat.
                existing = conn.execute("""
                    SELECT reservation_id FROM reservations
                    WHERE saga_id = ? AND flight_id = ? AND seat_number = ? AND status = 'CONFIRMED'
                """, (lock["saga_id"], lock["flight_id"], lock["seat_number"])).fetchone()
                conn.rollback()
                if existing:
                    return {
                        "success": True,
                        "reservation_id": existing["reservation_id"],
                        "lock_token": lock_token,
                        "message": "Already confirmed"
                    }
                return {"success": False, "error": "Lock not found or expired"}
            if not lock or lock["status"] != "LOCKED":
                conn.rollback()
                return {"success": False, "error": "Lock not found or expired"}
            # Expired locks are left for the cleanup sweep to cancel.
            if datetime.fromisoformat(lock["expires_at"]) < datetime.now():
                conn.rollback()
                return {"success": False, "error": "Lock expired"}
            reservation_id = str(uuid.uuid4())
            conn.execute("""
                INSERT INTO reservations (reservation_id, flight_id, seat_number, saga_id,
                                          passenger_name, passenger_id, price, status, created_at)
                SELECT ?, flight_id, seat_number, saga_id, passenger_name, passenger_id,
                       (SELECT price FROM flights WHERE flight_id = seat_locks.flight_id),
                       'CONFIRMED', ?
                FROM seat_locks WHERE lock_token = ?
            """, (reservation_id, datetime.now().isoformat(), lock_token))
            # available_seats was already decremented in the Try phase.
            conn.execute("""
                UPDATE seat_locks SET status = 'CONFIRMED' WHERE lock_token = ?
            """, (lock_token,))
            conn.commit()
            return {
                "success": True,
                "reservation_id": reservation_id,
                "lock_token": lock_token
            }
        except Exception as e:
            conn.rollback()
            logger.error(f"Confirm seat failed: {e}")
            return {"success": False, "error": str(e)}
        finally:
            conn.close()

    def cancel_lock(self, lock_token: str) -> Dict[str, Any]:
        """TCC Cancel: release a LOCKED seat and restore one available seat.

        Idempotent: an already-cancelled lock returns success with
        ``"Already cancelled"``. Unknown tokens and locks in other states
        (e.g. CONFIRMED) return ``{"success": False, "error": "Lock not found"}``.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")
            lock = conn.execute(
                "SELECT * FROM seat_locks WHERE lock_token = ?", (lock_token,)
            ).fetchone()
            if not lock or lock["status"] != "LOCKED":
                # Nothing to release; end the transaction explicitly instead of
                # relying on close() to discard it.
                conn.rollback()
                if lock and lock["status"] == "CANCELLED":
                    return {"success": True, "message": "Already cancelled"}
                return {"success": False, "error": "Lock not found"}
            conn.execute("""
                UPDATE flights SET available_seats = available_seats + 1
                WHERE flight_id = ?
            """, (lock["flight_id"],))
            conn.execute("""
                UPDATE seat_locks SET status = 'CANCELLED' WHERE lock_token = ?
            """, (lock_token,))
            conn.commit()
            return {"success": True, "lock_token": lock_token}
        except Exception as e:
            conn.rollback()
            logger.error(f"Cancel lock failed: {e}")
            return {"success": False, "error": str(e)}
        finally:
            conn.close()

    def release_seat(self, reservation_id: str) -> Dict[str, Any]:
        """Compensation: cancel a CONFIRMED reservation with a full refund and restore the seat.

        A missing or already-cancelled reservation counts as success.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")
            reservation = conn.execute("""
                SELECT * FROM reservations WHERE reservation_id = ? AND status = 'CONFIRMED'
            """, (reservation_id,)).fetchone()
            if not reservation:
                conn.rollback()
                return {"success": True, "message": "Reservation not found or already cancelled"}
            # Simplified refund policy: refund the full price.
            refund_amount = reservation["price"]
            conn.execute("""
                UPDATE reservations SET status = 'CANCELLED', cancelled_at = ?, refund_amount = ?
                WHERE reservation_id = ?
            """, (datetime.now().isoformat(), refund_amount, reservation_id))
            conn.execute("""
                UPDATE flights SET available_seats = available_seats + 1
                WHERE flight_id = ?
            """, (reservation["flight_id"],))
            conn.commit()
            return {
                "success": True,
                "reservation_id": reservation_id,
                "refund_amount": refund_amount
            }
        except Exception as e:
            conn.rollback()
            logger.error(f"Release seat failed: {e}")
            return {"success": False, "error": str(e)}
        finally:
            conn.close()

    def check_idempotency(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the cached result stored under ``key``, or None if there is none."""
        conn = self._get_connection()
        try:
            row = conn.execute("SELECT * FROM idempotency_store WHERE key = ?", (key,)).fetchone()
            if row:
                return json.loads(row["result"]) if row["result"] else None
            return None
        finally:
            conn.close()

    def save_idempotency(self, key: str, operation: str, result: Dict[str, Any]):
        """Store (or overwrite) the JSON-encoded ``result`` under ``key``."""
        conn = self._get_connection()
        try:
            conn.execute("""
                INSERT OR REPLACE INTO idempotency_store (key, operation, result, created_at)
                VALUES (?, ?, ?, ?)
            """, (key, operation, json.dumps(result), datetime.now().isoformat()))
            conn.commit()
        finally:
            conn.close()
class FlightService:
    """Flight service: connects FlightDatabase TCC operations to the message bus.

    Subscribes to the lock / confirm / cancel / release seat commands, runs
    the matching FlightDatabase operation and publishes a ResponseMessage for
    every command. Lock attempts additionally publish EVT_SEAT_LOCKED or
    EVT_SEAT_LOCK_FAILED. A daemon thread periodically cancels expired locks.
    """

    # Seconds between sweeps for expired seat locks.
    CLEANUP_INTERVAL_SECONDS = 60

    def __init__(self):
        self.db = FlightDatabase()
        self.bus = InMemoryMessageBus()
        self.circuit_breaker = CircuitBreaker(failure_threshold=5)
        self._setup_subscribers()
        self._start_cleanup_thread()

    def _setup_subscribers(self):
        """Register the command handlers on the bus."""
        self.bus.subscribe_command(EventType.CMD_LOCK_SEAT, self._handle_lock_seat)
        self.bus.subscribe_command(EventType.CMD_CONFIRM_SEAT, self._handle_confirm_seat)
        self.bus.subscribe_command(EventType.CMD_CANCEL_SEAT, self._handle_cancel_seat)
        self.bus.subscribe_command(EventType.CMD_RELEASE_SEAT, self._handle_release_seat)

    def _start_cleanup_thread(self):
        """Start a daemon thread that sweeps expired locks every CLEANUP_INTERVAL_SECONDS."""
        def cleanup():
            while True:
                time.sleep(self.CLEANUP_INTERVAL_SECONDS)
                self._cleanup_expired_locks()
        thread = threading.Thread(target=cleanup, daemon=True, name="FlightLockCleanup")
        thread.start()

    def _cleanup_expired_locks(self):
        """Cancel every LOCKED seat lock whose expiry time has passed.

        All errors are logged and swallowed so the sweeping thread survives.
        """
        try:
            # Connection acquisition is inside the try: an exception here used
            # to escape and kill the cleanup thread for good.
            conn = self.db._get_connection()
            try:
                expired = conn.execute("""
                    SELECT flight_id, lock_token FROM seat_locks
                    WHERE status = 'LOCKED' AND expires_at < ?
                """, (datetime.now().isoformat(),)).fetchall()
            finally:
                # Close before cancelling; cancel_lock opens its own connection.
                conn.close()
            for row in expired:
                result = self.db.cancel_lock(row["lock_token"])
                if result.get("success"):
                    logger.info(f"Cleaned up expired lock: {row['lock_token']}")
                else:
                    logger.warning(f"Failed to clean up expired lock {row['lock_token']}: {result.get('error')}")
        except Exception as e:
            logger.error(f"Cleanup failed: {e}")

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Whole milliseconds elapsed since ``start`` (a ``time.monotonic()`` reading)."""
        return int((time.monotonic() - start) * 1000)

    def _handle_lock_seat(self, command: Dict[str, Any]):
        """Handle CMD_LOCK_SEAT: idempotent Try through the circuit breaker, then event + response.

        A result cached under the command's idempotency key is replayed as the
        response without touching the database or publishing an event.
        """
        start_time = time.monotonic()
        cmd = CommandMessage.from_dict(command)
        saga_id = cmd.saga_id
        step_id = cmd.step_id
        idempotency_key = cmd.idempotency_key or generate_idempotency_key(
            cmd.command_type, saga_id, step_id, json.dumps(cmd.payload, sort_keys=True)
        )
        try:
            # Inside the try so a database error still produces a response
            # instead of leaving the orchestrator waiting for its timeout.
            cached = self.db.check_idempotency(idempotency_key)
            if cached is not None:
                logger.info(f"Idempotent lock seat return cached result for {saga_id}")
                self._send_response(cmd, cached, self._elapsed_ms(start_time))
                return
            result = self.circuit_breaker.call(
                self.db.lock_seat,
                cmd.payload["flight_id"],
                cmd.payload["seat_number"],
                saga_id,
                cmd.payload["passenger_name"],
                cmd.payload["passenger_id"]
            )
            self.db.save_idempotency(idempotency_key, cmd.command_type, result)
            event_type = EventType.EVT_SEAT_LOCKED if result["success"] else EventType.EVT_SEAT_LOCK_FAILED
            event = DomainEvent(
                event_type=event_type,
                saga_id=saga_id,
                step_id=step_id,
                correlation_id=cmd.correlation_id,
                payload=result,
                partition_key=cmd.payload.get("user_id", saga_id)
            )
            self.bus.publish_event(event.to_dict(), partition_key=event.partition_key)
            self._send_response(cmd, result, self._elapsed_ms(start_time))
        except Exception as e:
            logger.error(f"Lock seat error: {e}")
            error_result = {"success": False, "error": str(e)}
            self._send_response(cmd, error_result, self._elapsed_ms(start_time))

    def _run_and_respond(self, command: Dict[str, Any], action, operation_name: str):
        """Decode ``command``, run ``action(cmd)`` and publish its result as the response.

        ``action`` takes the decoded CommandMessage and returns a result dict.
        Any exception becomes a ``{"success": False, "error": ...}`` response
        and is logged as "<operation_name> error: ...".
        """
        start_time = time.monotonic()
        cmd = CommandMessage.from_dict(command)
        try:
            result = action(cmd)
            self._send_response(cmd, result, self._elapsed_ms(start_time))
        except Exception as e:
            logger.error(f"{operation_name} error: {e}")
            self._send_response(cmd, {"success": False, "error": str(e)},
                                self._elapsed_ms(start_time))

    def _handle_confirm_seat(self, command: Dict[str, Any]):
        """Handle CMD_CONFIRM_SEAT (TCC Confirm)."""
        self._run_and_respond(
            command, lambda cmd: self.db.confirm_seat(cmd.payload["lock_token"]), "Confirm seat")

    def _handle_cancel_seat(self, command: Dict[str, Any]):
        """Handle CMD_CANCEL_SEAT (TCC Cancel)."""
        self._run_and_respond(
            command, lambda cmd: self.db.cancel_lock(cmd.payload["lock_token"]), "Cancel seat")

    def _handle_release_seat(self, command: Dict[str, Any]):
        """Handle CMD_RELEASE_SEAT (compensation of a confirmed reservation)."""
        self._run_and_respond(
            command, lambda cmd: self.db.release_seat(cmd.payload["reservation_id"]), "Release seat")

    def _send_response(self, command: CommandMessage, result: Dict[str, Any], processing_time: int):
        """Publish a ResponseMessage for ``command``, partitioned by its correlation_id."""
        response = ResponseMessage(
            command_id=command.command_id,
            saga_id=command.saga_id,
            step_id=command.step_id,
            correlation_id=command.correlation_id,
            success=result.get("success", False),
            result=result,
            error_code=result.get("error_code"),
            error_message=result.get("error"),
            processing_time_ms=processing_time
        )
        self.bus.publish_response(response.to_dict(), partition_key=command.correlation_id)
if __name__ == "__main__":
    # Constructing the service subscribes its command handlers and starts the
    # lock-cleanup daemon thread; keep a reference for the life of the process.
    flight_service = FlightService()
    logger.info("Flight Service started")
    # Park the main thread so the process keeps serving.
    while True:
        time.sleep(1)