Python的抽象基类(ABC):定义接口契约的艺术

目录

  • Python的抽象基类(ABC):定义接口契约的艺术
    • [1. 引言](#1. 引言)
      • [1.1 什么是抽象基类?](#1.1 什么是抽象基类?)
      • [1.2 为什么需要抽象基类?](#1.2 为什么需要抽象基类?)
    • [2. 抽象基类的基础](#2. 抽象基类的基础)
      • [2.1 创建抽象基类](#2.1 创建抽象基类)
      • [2.2 实现抽象基类](#2.2 实现抽象基类)
      • [2.3 抽象属性的使用](#2.3 抽象属性的使用)
    • [3. 高级抽象基类特性](#3. 高级抽象基类特性)
      • [3.1 多重继承与抽象基类](#3.1 多重继承与抽象基类)
      • [3.2 抽象类方法与静态方法](#3.2 抽象类方法与静态方法)
    • [4. 注册机制与虚拟子类](#4. 注册机制与虚拟子类)
      • [4.1 注册虚拟子类](#4.1 注册虚拟子类)
      • [4.2 `subclasshook`方法](#4.2 __subclasshook__方法)
    • [5. 标准库中的抽象基类](#5. 标准库中的抽象基类)
      • [5.1 集合抽象基类](#5.1 集合抽象基类)
      • [5.2 上下文管理器抽象基类](#5.2 上下文管理器抽象基类)
    • [6. 类型检查与抽象基类](#6. 类型检查与抽象基类)
      • [6.1 类型提示与抽象基类](#6.1 类型提示与抽象基类)
    • [7. 设计模式与抽象基类](#7. 设计模式与抽象基类)
      • [7.1 策略模式](#7.1 策略模式)
      • [7.2 观察者模式](#7.2 观察者模式)
    • [8. 完整代码示例:电子商务系统](#8. 完整代码示例:电子商务系统)
    • [9. 最佳实践与注意事项](#9. 最佳实践与注意事项)
      • [9.1 抽象基类设计原则](#9.1 抽象基类设计原则)
      • [9.2 常见陷阱与解决方案](#9.2 常见陷阱与解决方案)
    • [10. 总结](#10. 总结)
      • [10.1 关键要点](#10.1 关键要点)
      • [10.2 适用场景](#10.2 适用场景)
    • [11. 代码自查](#11. 代码自查)

『宝藏代码胶囊开张啦!』------ 我的 CodeCapsule 来咯!✨写代码不再头疼!我的新站点 CodeCapsule 主打一个 "白菜价"+"量身定制 "!无论是卡脖子的毕设/课设/文献复现 ,需要灵光一现的算法改进 ,还是想给项目加个"外挂",这里都有便宜又好用的代码方案等你发现!低成本,高适配,助你轻松通关!速来围观 👉 CodeCapsule官网

Python的抽象基类(ABC):定义接口契约的艺术

1. 引言

在面向对象编程中,抽象基类(Abstract Base Classes,简称ABC)是一种强大的工具,用于定义接口契约和建立类之间的规范。Python通过abc模块提供了对抽象基类的原生支持,这使得开发者能够创建更加健壮、可维护的代码。

1.1 什么是抽象基类?

抽象基类是不能被实例化的类,其主要目的是为子类定义接口契约。它规定了子类必须实现的方法和属性,从而确保了一致性和可预测性。

1.2 为什么需要抽象基类?

  • 强制接口实现:确保子类实现了必要的方法
  • 提供明确的契约:定义清晰的API规范
  • 增强代码可读性:明确表达设计意图
  • 便于类型检查:支持静态类型检查和IDE提示
  • 多态性支持:统一的接口,不同的实现
python 复制代码
from abc import ABC, abstractmethod

class Shape(ABC):
    """形状抽象基类"""
    
    @abstractmethod
    def area(self):
        """计算面积"""
        pass
    
    @abstractmethod
    def perimeter(self):
        """计算周长"""
        pass
    
    def description(self):
        """非抽象方法,可以有默认实现"""
        return "这是一个形状"

2. 抽象基类的基础

2.1 创建抽象基类

Python中使用ABC元类和abstractmethod装饰器来创建抽象基类。

python 复制代码
from abc import ABC, abstractmethod
import math

class Vehicle(ABC):
    """交通工具抽象基类"""
    
    def __init__(self, name, max_speed):
        self.name = name
        self.max_speed = max_speed
        self._current_speed = 0
    
    @abstractmethod
    def start_engine(self):
        """启动引擎"""
        pass
    
    @abstractmethod
    def stop_engine(self):
        """停止引擎"""
        pass
    
    @abstractmethod
    def accelerate(self, increment):
        """加速"""
        pass
    
    def get_current_speed(self):
        """获取当前速度 - 具体方法"""
        return self._current_speed
    
    def __str__(self):
        return f"{self.name} (最大速度: {self.max_speed} km/h)"

2.2 实现抽象基类

子类必须实现所有抽象方法,否则在实例化时会抛出TypeError

python 复制代码
class Car(Vehicle):
    """汽车类 - 实现Vehicle抽象基类"""
    
    def __init__(self, name, max_speed, fuel_type):
        super().__init__(name, max_speed)
        self.fuel_type = fuel_type
        self._engine_started = False
    
    def start_engine(self):
        if not self._engine_started:
            self._engine_started = True
            return f"{self.name}引擎已启动"
        return f"{self.name}引擎已经在运行"
    
    def stop_engine(self):
        if self._engine_started:
            self._engine_started = False
            self._current_speed = 0
            return f"{self.name}引擎已停止"
        return f"{self.name}引擎已经停止"
    
    def accelerate(self, increment):
        if not self._engine_started:
            return "请先启动引擎"
        
        new_speed = self._current_speed + increment
        if new_speed <= self.max_speed:
            self._current_speed = new_speed
            return f"加速到 {self._current_speed} km/h"
        else:
            self._current_speed = self.max_speed
            return f"已达到最大速度 {self.max_speed} km/h"

class Bicycle(Vehicle):
    """自行车类 - 实现Vehicle抽象基类"""
    
    def __init__(self, name, max_speed, gear_count):
        super().__init__(name, max_speed)
        self.gear_count = gear_count
        self._current_gear = 1
    
    def start_engine(self):
        return "自行车没有引擎"
    
    def stop_engine(self):
        return "自行车没有引擎"
    
    def accelerate(self, increment):
        new_speed = self._current_speed + increment
        if new_speed <= self.max_speed:
            self._current_speed = new_speed
            return f"蹬车加速到 {self._current_speed} km/h"
        else:
            self._current_speed = self.max_speed
            return f"已达到最大速度 {self.max_speed} km/h"
    
    def change_gear(self, new_gear):
        if 1 <= new_gear <= self.gear_count:
            self._current_gear = new_gear
            return f"切换到 {new_gear} 档"
        return "无效的档位"

def demonstrate_vehicle_abc():
    """演示抽象基类的使用"""
    # 创建汽车实例
    car = Car("丰田卡罗拉", 180, "汽油")
    print(car)
    print(car.start_engine())
    print(car.accelerate(50))
    print(car.accelerate(100))
    print(car.stop_engine())
    
    print("\n" + "="*50 + "\n")
    
    # 创建自行车实例
    bike = Bicycle("山地自行车", 40, 21)
    print(bike)
    print(bike.start_engine())
    print(bike.accelerate(15))
    print(bike.change_gear(3))
    print(bike.accelerate(10))

if __name__ == "__main__":
    demonstrate_vehicle_abc()

2.3 抽象属性的使用

除了抽象方法,Python还支持抽象属性,使用abstractproperty或结合propertyabstractmethod

python 复制代码
from abc import ABC, abstractmethod

class DatabaseConnection(ABC):
    """数据库连接抽象基类"""
    
    @property
    @abstractmethod
    def is_connected(self):
        """连接状态"""
        pass
    
    @property
    @abstractmethod
    def connection_string(self):
        """连接字符串"""
        pass
    
    @abstractmethod
    def connect(self):
        """建立连接"""
        pass
    
    @abstractmethod
    def disconnect(self):
        """断开连接"""
        pass
    
    @abstractmethod
    def execute_query(self, query):
        """执行查询"""
        pass

class MySQLConnection(DatabaseConnection):
    """MySQL数据库连接实现"""
    
    def __init__(self, host, user, password, database):
        self._host = host
        self._user = user
        self._password = password
        self._database = database
        self._connected = False
    
    @property
    def is_connected(self):
        return self._connected
    
    @property
    def connection_string(self):
        return f"mysql://{self._user}@{self._host}/{self._database}"
    
    def connect(self):
        if not self._connected:
            # 模拟连接建立
            self._connected = True
            return f"已连接到MySQL数据库: {self.connection_string}"
        return "已经连接到数据库"
    
    def disconnect(self):
        if self._connected:
            self._connected = False
            return "MySQL连接已断开"
        return "未连接到数据库"
    
    def execute_query(self, query):
        if not self._connected:
            raise ConnectionError("请先连接到数据库")
        return f"执行MySQL查询: {query}"

def demonstrate_abstract_properties():
    """演示抽象属性的使用"""
    mysql_conn = MySQLConnection("localhost", "admin", "password", "myapp")
    
    print(f"连接字符串: {mysql_conn.connection_string}")
    print(f"连接状态: {mysql_conn.is_connected}")
    
    print(mysql_conn.connect())
    print(f"连接状态: {mysql_conn.is_connected}")
    
    print(mysql_conn.execute_query("SELECT * FROM users"))
    print(mysql_conn.disconnect())

if __name__ == "__main__":
    demonstrate_abstract_properties()

3. 高级抽象基类特性

3.1 多重继承与抽象基类

抽象基类支持多重继承,可以创建复杂的接口层次结构。

python 复制代码
from abc import ABC, abstractmethod
from typing import List

class Readable(ABC):
    """可读接口"""
    
    @abstractmethod
    def read(self) -> str:
        """读取内容"""
        pass
    
    @abstractmethod
    def get_size(self) -> int:
        """获取大小"""
        pass

class Writable(ABC):
    """可写接口"""
    
    @abstractmethod
    def write(self, content: str) -> bool:
        """写入内容"""
        pass
    
    @abstractmethod
    def append(self, content: str) -> bool:
        """追加内容"""
        pass

class ReadWriteFile(Readable, Writable):
    """读写文件实现"""
    
    def __init__(self, filename: str):
        self.filename = filename
        self._content = ""
    
    def read(self) -> str:
        return self._content
    
    def get_size(self) -> int:
        return len(self._content)
    
    def write(self, content: str) -> bool:
        self._content = content
        return True
    
    def append(self, content: str) -> bool:
        self._content += content
        return True
    
    def clear(self):
        """额外的方法"""
        self._content = ""

class NetworkStream(Readable, Writable):
    """网络流实现"""
    
    def __init__(self, buffer_size: int = 1024):
        self.buffer: List[str] = []
        self.buffer_size = buffer_size
    
    def read(self) -> str:
        if self.buffer:
            return self.buffer.pop(0)
        return ""
    
    def get_size(self) -> int:
        return sum(len(item) for item in self.buffer)
    
    def write(self, content: str) -> bool:
        if len(content) <= self.buffer_size:
            self.buffer.append(content)
            return True
        return False
    
    def append(self, content: str) -> bool:
        return self.write(content)  # 对于网络流,write和append行为相同

def demonstrate_multiple_inheritance():
    """演示多重继承的抽象基类"""
    
    # 文件操作
    file_processor = ReadWriteFile("example.txt")
    file_processor.write("Hello, World!")
    file_processor.append(" This is additional content.")
    
    print("文件内容:", file_processor.read())
    print("文件大小:", file_processor.get_size())
    
    print("\n" + "="*50 + "\n")
    
    # 网络流操作
    stream = NetworkStream()
    stream.write("数据包1")
    stream.write("数据包2")
    
    print("流内容:", stream.read())
    print("流大小:", stream.get_size())
    
    # 类型检查
    print(f"\nfile_processor是Readable: {isinstance(file_processor, Readable)}")
    print(f"file_processor是Writable: {isinstance(file_processor, Writable)}")
    print(f"stream是Readable: {isinstance(stream, Readable)}")
    print(f"stream是Writable: {isinstance(stream, Writable)}")

if __name__ == "__main__":
    demonstrate_multiple_inheritance()

3.2 抽象类方法与静态方法

抽象基类也支持抽象类方法和抽象静态方法。

python 复制代码
from abc import ABC, abstractmethod, abstractclassmethod, abstractstaticmethod
from datetime import datetime
import json

class Serializer(ABC):
    """序列化器抽象基类"""
    
    @abstractclassmethod
    def get_format_name(cls) -> str:
        """获取格式名称"""
        pass
    
    @abstractstaticmethod
    def is_valid_content(content: str) -> bool:
        """验证内容格式"""
        pass
    
    @abstractmethod
    def serialize(self, data: dict) -> str:
        """序列化数据"""
        pass
    
    @abstractmethod
    def deserialize(self, content: str) -> dict:
        """反序列化内容"""
        pass
    
    @classmethod
    def get_serializer_info(cls):
        """获取序列化器信息"""
        return {
            "format": cls.get_format_name(),
            "timestamp": datetime.now().isoformat()
        }

class JSONSerializer(Serializer):
    """JSON序列化器"""
    
    @classmethod
    def get_format_name(cls) -> str:
        return "JSON"
    
    @staticmethod
    def is_valid_content(content: str) -> bool:
        try:
            json.loads(content)
            return True
        except (json.JSONDecodeError, TypeError):
            return False
    
    def serialize(self, data: dict) -> str:
        return json.dumps(data, indent=2, ensure_ascii=False)
    
    def deserialize(self, content: str) -> dict:
        if not self.is_valid_content(content):
            raise ValueError("无效的JSON内容")
        return json.loads(content)

class XMLSerializer(Serializer):
    """XML序列化器(简化版)"""
    
    @classmethod
    def get_format_name(cls) -> str:
        return "XML"
    
    @staticmethod
    def is_valid_content(content: str) -> bool:
        # 简化的XML验证
        return content.strip().startswith('<') and content.strip().endswith('>')
    
    def serialize(self, data: dict) -> str:
        xml_parts = ['<root>']
        for key, value in data.items():
            xml_parts.append(f'  <{key}>{value}</{key}>')
        xml_parts.append('</root>')
        return '\n'.join(xml_parts)
    
    def deserialize(self, content: str) -> dict:
        if not self.is_valid_content(content):
            raise ValueError("无效的XML内容")
        # 简化的XML解析
        lines = content.strip().split('\n')
        result = {}
        for line in lines:
            line = line.strip()
            if line.startswith('<') and not line.startswith('</') and not line.startswith('<root'):
                key = line.split('>')[0][1:]
                value = line.split('>')[1].split('<')[0]
                result[key] = value
        return result

def demonstrate_abstract_class_static_methods():
    """演示抽象类方法和静态方法"""
    
    # 测试JSON序列化器
    json_serializer = JSONSerializer()
    print(f"序列化器格式: {JSONSerializer.get_format_name()}")
    print(f"序列化器信息: {JSONSerializer.get_serializer_info()}")
    
    test_data = {"name": "Alice", "age": 30, "city": "Beijing"}
    json_content = json_serializer.serialize(test_data)
    print(f"\nJSON序列化结果:\n{json_content}")
    print(f"内容是否有效: {JSONSerializer.is_valid_content(json_content)}")
    
    deserialized_data = json_serializer.deserialize(json_content)
    print(f"反序列化数据: {deserialized_data}")
    
    print("\n" + "="*50 + "\n")
    
    # 测试XML序列化器
    xml_serializer = XMLSerializer()
    print(f"序列化器格式: {XMLSerializer.get_format_name()}")
    
    xml_content = xml_serializer.serialize(test_data)
    print(f"\nXML序列化结果:\n{xml_content}")
    print(f"内容是否有效: {XMLSerializer.is_valid_content(xml_content)}")
    
    deserialized_xml = xml_serializer.deserialize(xml_content)
    print(f"反序列化数据: {deserialized_xml}")

if __name__ == "__main__":
    demonstrate_abstract_class_static_methods()

4. 注册机制与虚拟子类

Python的抽象基类支持注册机制,可以将现有的类注册为抽象基类的虚拟子类。

4.1 注册虚拟子类

python 复制代码
from abc import ABC, abstractmethod
from collections.abc import Sequence

class DataProcessor(ABC):
    """数据处理器抽象基类"""
    
    @abstractmethod
    def process(self, data):
        """处理数据"""
        pass
    
    @abstractmethod
    def validate(self, data) -> bool:
        """验证数据"""
        pass
    
    def get_info(self):
        """获取处理器信息"""
        return "通用数据处理器"

# 现有的类,没有继承DataProcessor
class ListProcessor:
    """列表处理器"""
    
    def __init__(self, multiplier=1):
        self.multiplier = multiplier
    
    def process(self, data):
        if isinstance(data, (list, tuple)):
            return [item * self.multiplier for item in data]
        raise TypeError("数据必须是列表或元组")
    
    def validate(self, data) -> bool:
        return isinstance(data, (list, tuple)) and all(
            isinstance(item, (int, float)) for item in data
        )
    
    def get_info(self):
        return f"列表处理器 (乘数: {self.multiplier})"

# 注册为虚拟子类
DataProcessor.register(ListProcessor)

def demonstrate_virtual_subclasses():
    """演示虚拟子类"""
    
    # 创建列表处理器实例
    processor = ListProcessor(multiplier=2)
    
    # 检查类型
    print(f"processor是DataProcessor: {isinstance(processor, DataProcessor)}")
    print(f"ListProcessor是DataProcessor子类: {issubclass(ListProcessor, DataProcessor)}")
    
    # 使用处理器
    test_data = [1, 2, 3, 4, 5]
    print(f"验证数据: {processor.validate(test_data)}")
    print(f"处理结果: {processor.process(test_data)}")
    print(f"处理器信息: {processor.get_info()}")

class DictProcessor:
    """字典处理器"""
    
    def process(self, data):
        if isinstance(data, dict):
            return {k: v.upper() if isinstance(v, str) else v 
                   for k, v in data.items()}
        raise TypeError("数据必须是字典")
    
    def validate(self, data) -> bool:
        return isinstance(data, dict)

# 使用register方法注册
DataProcessor.register(DictProcessor)

def test_virtual_subclass_behavior():
    """测试虚拟子类行为"""
    
    dict_processor = DictProcessor()
    
    print(f"dict_processor是DataProcessor: {isinstance(dict_processor, DataProcessor)}")
    print(f"DictProcessor是DataProcessor子类: {issubclass(DictProcessor, DataProcessor)}")
    
    test_data = {"name": "alice", "age": 30}
    if dict_processor.validate(test_data):
        result = dict_processor.process(test_data)
        print(f"字典处理结果: {result}")

if __name__ == "__main__":
    demonstrate_virtual_subclasses()
    print("\n" + "="*50 + "\n")
    test_virtual_subclass_behavior()

4.2 __subclasshook__方法

__subclasshook__方法允许自定义子类检查逻辑。

python 复制代码
from abc import ABC, abstractmethod, ABCMeta

class Container(ABC):
    """容器接口"""
    
    @abstractmethod
    def add(self, item):
        """添加元素"""
        pass
    
    @abstractmethod
    def remove(self, item):
        """移除元素"""
        pass
    
    @abstractmethod
    def contains(self, item) -> bool:
        """检查是否包含元素"""
        pass
    
    @classmethod
    def __subclasshook__(cls, C):
        """自定义子类检查"""
        if cls is Container:
            # 检查类是否具有必要的方法
            required_methods = {'add', 'remove', 'contains'}
            if all(any(method in B.__dict__ for B in C.__mro__) 
                   for method in required_methods):
                return True
        return NotImplemented

# 现有的类,没有显式继承Container
class CustomCollection:
    """自定义集合类"""
    
    def __init__(self):
        self._items = []
    
    def add(self, item):
        self._items.append(item)
        return True
    
    def remove(self, item):
        if item in self._items:
            self._items.remove(item)
            return True
        return False
    
    def contains(self, item) -> bool:
        return item in self._items
    
    def __len__(self):
        return len(self._items)
    
    def __iter__(self):
        return iter(self._items)

class IncompleteCollection:
    """不完整的集合类(缺少某些方法)"""
    
    def __init__(self):
        self._items = []
    
    def add(self, item):
        self._items.append(item)

def demonstrate_subclasshook():
    """演示__subclasshook__方法"""
    
    custom_collection = CustomCollection()
    incomplete_collection = IncompleteCollection()
    
    # 测试子类关系
    print(f"CustomCollection是Container子类: {issubclass(CustomCollection, Container)}")
    print(f"IncompleteCollection是Container子类: {issubclass(IncompleteCollection, Container)}")
    
    print(f"custom_collection是Container实例: {isinstance(custom_collection, Container)}")
    print(f"incomplete_collection是Container实例: {isinstance(incomplete_collection, Container)}")
    
    # 使用自定义集合
    custom_collection.add("apple")
    custom_collection.add("banana")
    custom_collection.add("orange")
    
    print(f"集合包含'apple': {custom_collection.contains('apple')}")
    print(f"集合大小: {len(custom_collection)}")
    
    custom_collection.remove("banana")
    print(f"移除后集合大小: {len(custom_collection)}")
    
    print("集合内容:")
    for item in custom_collection:
        print(f"  - {item}")

if __name__ == "__main__":
    demonstrate_subclasshook()

5. 标准库中的抽象基类

Python标准库在collections.abc模块中提供了许多有用的抽象基类。

5.1 集合抽象基类

python 复制代码
from collections.abc import Container, Sized, Iterable, Iterator, Sequence, Mapping
import collections.abc

def demonstrate_stdlib_abc():
    """演示标准库中的抽象基类"""
    
    # 测试各种内置类型
    test_cases = [
        (list, "列表"),
        (tuple, "元组"),
        (str, "字符串"),
        (dict, "字典"),
        (set, "集合"),
    ]
    
    abc_types = [
        (Container, "Container"),
        (Sized, "Sized"),
        (Iterable, "Iterable"),
        (Iterator, "Iterator"),
        (Sequence, "Sequence"),
        (Mapping, "Mapping"),
    ]
    
    print("标准库抽象基类兼容性检查:")
    print("=" * 60)
    
    for abc_class, abc_name in abc_types:
        print(f"\n{abc_name}:")
        for test_class, test_name in test_cases:
            result = issubclass(test_class, abc_class)
            print(f"  {test_name:8} -> {result}")

class CustomSequence(collections.abc.Sequence):
    """自定义序列实现"""
    
    def __init__(self, *args):
        self._data = list(args)
    
    def __getitem__(self, index):
        return self._data[index]
    
    def __len__(self):
        return len(self._data)
    
    def __repr__(self):
        return f"CustomSequence({self._data})"
    
    def count(self, value):
        return self._data.count(value)
    
    def index(self, value, start=0, stop=None):
        if stop is None:
            stop = len(self._data)
        try:
            return self._data.index(value, start, stop)
        except ValueError:
            return -1

class CustomMapping(collections.abc.Mapping):
    """自定义映射实现"""
    
    def __init__(self, **kwargs):
        self._data = kwargs
    
    def __getitem__(self, key):
        return self._data[key]
    
    def __iter__(self):
        return iter(self._data)
    
    def __len__(self):
        return len(self._data)
    
    def __repr__(self):
        return f"CustomMapping({self._data})"

def demonstrate_custom_abc_implementations():
    """演示自定义抽象基类实现"""
    
    print("\n" + "=" * 60)
    print("自定义抽象基类实现演示")
    print("=" * 60)
    
    # 自定义序列
    custom_seq = CustomSequence(1, 2, 3, 4, 5, 2, 3)
    print(f"自定义序列: {custom_seq}")
    print(f"序列长度: {len(custom_seq)}")
    print(f"索引2: {custom_seq[2]}")
    print(f"切片[1:4]: {custom_seq[1:4]}")
    print(f"计数2: {custom_seq.count(2)}")
    print(f"索引3: {custom_seq.index(3)}")
    
    print(f"\n自定义序列是Sequence: {isinstance(custom_seq, Sequence)}")
    print(f"自定义序列是Iterable: {isinstance(custom_seq, Iterable)}")
    print(f"自定义序列是Sized: {isinstance(custom_seq, Sized)}")
    
    print("\n" + "-" * 40)
    
    # 自定义映射
    custom_map = CustomMapping(name="Alice", age=30, city="Beijing")
    print(f"自定义映射: {custom_map}")
    print(f"映射大小: {len(custom_map)}")
    print(f"键'name': {custom_map['name']}")
    print(f"所有键: {list(custom_map.keys())}")
    print(f"所有值: {list(custom_map.values())}")
    
    print(f"\n自定义映射是Mapping: {isinstance(custom_map, Mapping)}")
    print(f"自定义映射是Iterable: {isinstance(custom_map, Iterable)}")
    print(f"自定义映射是Sized: {isinstance(custom_map, Sized)}")

if __name__ == "__main__":
    demonstrate_stdlib_abc()
    demonstrate_custom_abc_implementations()

5.2 上下文管理器抽象基类

python 复制代码
from abc import ABC, abstractmethod
from collections.abc import Generator
from contextlib import contextmanager, AbstractContextManager
import time

class DatabaseTransaction(AbstractContextManager):
    """数据库事务上下文管理器"""
    
    def __init__(self, db_connection):
        self.db_connection = db_connection
        self._in_transaction = False
    
    def __enter__(self):
        self._in_transaction = True
        print(f"开始数据库事务: {self.db_connection}")
        # 模拟开始事务
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        self._in_transaction = False
        if exc_type is None:
            print(f"提交数据库事务: {self.db_connection}")
            # 模拟提交事务
        else:
            print(f"回滚数据库事务: {self.db_connection} (错误: {exc_val})")
            # 模拟回滚事务
        return False  # 不抑制异常
    
    def execute(self, query):
        """在事务中执行查询"""
        if not self._in_transaction:
            raise RuntimeError("不在事务中")
        print(f"执行查询: {query}")
        return f"结果: {query}"

class FileProcessor(ABC):
    """文件处理器抽象基类"""
    
    @abstractmethod
    def process_file(self, filename):
        """处理文件"""
        pass
    
    @contextmanager
    def file_context(self, filename):
        """文件处理上下文"""
        print(f"打开文件: {filename}")
        start_time = time.time()
        try:
            yield self
        except Exception as e:
            print(f"处理文件时出错: {e}")
            raise
        finally:
            end_time = time.time()
            print(f"关闭文件: {filename} (耗时: {end_time - start_time:.2f}秒)")

class TextFileProcessor(FileProcessor):
    """文本文件处理器"""
    
    def process_file(self, filename):
        with self.file_context(filename):
            print(f"处理文本文件: {filename}")
            # 模拟文件处理
            time.sleep(0.1)
            return f"处理完成: {filename}"

class BinaryFileProcessor(FileProcessor):
    """二进制文件处理器"""
    
    def process_file(self, filename):
        with self.file_context(filename):
            print(f"处理二进制文件: {filename}")
            # 模拟文件处理
            time.sleep(0.2)
            return f"处理完成: {filename}"

def demonstrate_context_managers():
    """演示上下文管理器抽象基类"""
    
    print("数据库事务演示:")
    print("-" * 40)
    
    # 数据库事务演示
    db_conn = "MySQL连接@localhost"
    
    # 正常执行
    with DatabaseTransaction(db_conn) as transaction:
        result1 = transaction.execute("SELECT * FROM users")
        result2 = transaction.execute("UPDATE users SET active = 1")
        print(result1)
        print(result2)
    
    print("\n带异常的事务:")
    try:
        with DatabaseTransaction(db_conn) as transaction:
            transaction.execute("BEGIN TRANSACTION")
            transaction.execute("INSERT INTO users VALUES (...)")
            raise ValueError("模拟的错误")
    except ValueError as e:
        print(f"捕获异常: {e}")
    
    print("\n" + "=" * 50)
    print("文件处理器演示:")
    print("-" * 40)
    
    # 文件处理器演示
    text_processor = TextFileProcessor()
    binary_processor = BinaryFileProcessor()
    
    print(text_processor.process_file("document.txt"))
    print()
    print(binary_processor.process_file("image.jpg"))

if __name__ == "__main__":
    demonstrate_context_managers()

6. 类型检查与抽象基类

抽象基类在类型检查和静态分析中非常有用。

6.1 类型提示与抽象基类

python 复制代码
from abc import ABC, abstractmethod
from typing import List, Dict, Any, TypeVar, Generic, Optional
from dataclasses import dataclass

T = TypeVar('T')
U = TypeVar('U')

class Repository(ABC, Generic[T]):
    """泛型仓库抽象基类"""
    
    @abstractmethod
    def add(self, item: T) -> bool:
        """添加项目"""
        pass
    
    @abstractmethod
    def get(self, identifier: Any) -> Optional[T]:
        """根据标识符获取项目"""
        pass
    
    @abstractmethod
    def get_all(self) -> List[T]:
        """获取所有项目"""
        pass
    
    @abstractmethod
    def update(self, item: T) -> bool:
        """更新项目"""
        pass
    
    @abstractmethod
    def delete(self, identifier: Any) -> bool:
        """删除项目"""
        pass

@dataclass
class User:
    """用户数据类"""
    id: int
    name: str
    email: str
    age: int

class UserRepository(Repository[User]):
    """用户仓库实现"""
    
    def __init__(self):
        self._storage: Dict[int, User] = {}
        self._next_id = 1
    
    def add(self, user: User) -> bool:
        if user.id in self._storage:
            return False
        self._storage[user.id] = user
        return True
    
    def get(self, user_id: int) -> Optional[User]:
        return self._storage.get(user_id)
    
    def get_all(self) -> List[User]:
        return list(self._storage.values())
    
    def update(self, user: User) -> bool:
        if user.id not in self._storage:
            return False
        self._storage[user.id] = user
        return True
    
    def delete(self, user_id: int) -> bool:
        if user_id not in self._storage:
            return False
        del self._storage[user_id]
        return True
    
    def create_user(self, name: str, email: str, age: int) -> User:
        """创建新用户"""
        user = User(id=self._next_id, name=name, email=email, age=age)
        self._next_id += 1
        self.add(user)
        return user

def demonstrate_generic_repository():
    """演示泛型仓库模式"""
    
    user_repo = UserRepository()
    
    # 创建用户
    users = [
        user_repo.create_user("Alice", "alice@example.com", 30),
        user_repo.create_user("Bob", "bob@example.com", 25),
        user_repo.create_user("Charlie", "charlie@example.com", 35)
    ]
    
    print("所有用户:")
    for user in user_repo.get_all():
        print(f"  ID: {user.id}, 姓名: {user.name}, 邮箱: {user.email}, 年龄: {user.age}")
    
    print(f"\n获取ID为2的用户: {user_repo.get(2)}")
    
    # 更新用户
    bob = user_repo.get(2)
    if bob:
        bob.age = 26
        user_repo.update(bob)
        print(f"更新后的Bob: {user_repo.get(2)}")
    
    # 删除用户
    user_repo.delete(3)
    print(f"\n删除ID为3的用户后,总用户数: {len(user_repo.get_all())}")

class DataValidator(ABC):
    """数据验证器抽象基类"""
    
    @abstractmethod
    def validate(self, data: Any) -> bool:
        """验证数据"""
        pass
    
    @abstractmethod
    def get_errors(self) -> List[str]:
        """获取错误信息"""
        pass

class UserValidator(DataValidator):
    """用户数据验证器"""
    
    def __init__(self):
        self._errors: List[str] = []
    
    def validate(self, user: Any) -> bool:
        self._errors.clear()
        
        if not isinstance(user, User):
            self._errors.append("数据必须是User类型")
            return False
        
        if not user.name or len(user.name.strip()) == 0:
            self._errors.append("姓名不能为空")
        
        if not user.email or '@' not in user.email:
            self._errors.append("邮箱格式不正确")
        
        if user.age < 0 or user.age > 150:
            self._errors.append("年龄必须在0-150之间")
        
        return len(self._errors) == 0
    
    def get_errors(self) -> List[str]:
        return self._errors.copy()

def demonstrate_validator_pattern():
    """演示验证器模式"""
    
    validator = UserValidator()
    
    test_users = [
        User(1, "Alice", "alice@example.com", 30),  # 有效
        User(2, "", "invalid-email", -5),           # 无效
        User(3, "Bob", "bob@example.com", 200),     # 无效年龄
    ]
    
    for user in test_users:
        print(f"\n验证用户: {user.name}")
        if validator.validate(user):
            print("  ✓ 验证通过")
        else:
            print("  ✗ 验证失败:")
            for error in validator.get_errors():
                print(f"    - {error}")

if __name__ == "__main__":
    demonstrate_generic_repository()
    print("\n" + "=" * 60)
    demonstrate_validator_pattern()

7. 设计模式与抽象基类

抽象基类在设计模式实现中发挥着重要作用。

7.1 策略模式

python 复制代码
from abc import ABC, abstractmethod
from typing import List
from dataclasses import dataclass

class CompressionStrategy(ABC):
    """压缩策略抽象基类"""
    
    @abstractmethod
    def compress(self, data: str) -> str:
        """压缩数据"""
        pass
    
    @abstractmethod
    def decompress(self, compressed_data: str) -> str:
        """解压缩数据"""
        pass
    
    @abstractmethod
    def get_compression_ratio(self, original_data: str, compressed_data: str) -> float:
        """获取压缩率"""
        pass

class RLECompression(CompressionStrategy):
    """游程编码压缩策略"""
    
    def compress(self, data: str) -> str:
        if not data:
            return ""
        
        compressed = []
        count = 1
        current_char = data[0]
        
        for char in data[1:]:
            if char == current_char:
                count += 1
            else:
                compressed.append(f"{count}{current_char}")
                current_char = char
                count = 1
        
        compressed.append(f"{count}{current_char}")
        return "".join(compressed)
    
    def decompress(self, compressed_data: str) -> str:
        decompressed = []
        i = 0
        
        while i < len(compressed_data):
            count_str = ""
            while i < len(compressed_data) and compressed_data[i].isdigit():
                count_str += compressed_data[i]
                i += 1
            
            if i < len(compressed_data):
                char = compressed_data[i]
                count = int(count_str) if count_str else 1
                decompressed.append(char * count)
                i += 1
        
        return "".join(decompressed)
    
    def get_compression_ratio(self, original_data: str, compressed_data: str) -> float:
        if not original_data:
            return 0.0
        return len(compressed_data) / len(original_data)

class DictionaryCompression(CompressionStrategy):
    """字典压缩策略"""
    
    def __init__(self):
        self.dictionary = {}
        self.next_code = 0
    
    def compress(self, data: str) -> str:
        self.dictionary.clear()
        self.next_code = 0
        
        words = data.split()
        compressed = []
        
        for word in words:
            if word not in self.dictionary:
                self.dictionary[word] = self.next_code
                self.next_code += 1
            compressed.append(str(self.dictionary[word]))
        
        return " ".join(compressed)
    
    def decompress(self, compressed_data: str) -> str:
        if not self.dictionary:
            raise ValueError("字典未初始化")
        
        # 反转字典
        reverse_dict = {v: k for k, v in self.dictionary.items()}
        
        codes = compressed_data.split()
        decompressed = []
        
        for code in codes:
            decompressed.append(reverse_dict[int(code)])
        
        return " ".join(decompressed)
    
    def get_compression_ratio(self, original_data: str, compressed_data: str) -> float:
        if not original_data:
            return 0.0
        return len(compressed_data) / len(original_data)

class CompressionContext:
    """压缩上下文"""
    
    def __init__(self, strategy: CompressionStrategy):
        self._strategy = strategy
    
    def set_strategy(self, strategy: CompressionStrategy):
        """设置压缩策略"""
        self._strategy = strategy
    
    def compress_data(self, data: str) -> str:
        """压缩数据"""
        return self._strategy.compress(data)
    
    def decompress_data(self, compressed_data: str) -> str:
        """解压缩数据"""
        return self._strategy.decompress(compressed_data)
    
    def get_compression_info(self, original_data: str) -> dict:
        """获取压缩信息"""
        compressed_data = self.compress_data(original_data)
        ratio = self._strategy.get_compression_ratio(original_data, compressed_data)
        
        return {
            "original_size": len(original_data),
            "compressed_size": len(compressed_data),
            "compression_ratio": ratio,
            "compressed_data": compressed_data
        }

def demonstrate_strategy_pattern():
    """演示策略模式"""
    
    test_data = "AAAABBBCCDAA"
    print(f"测试数据: {test_data}")
    print("=" * 50)
    
    # RLE压缩
    rle_strategy = RLECompression()
    rle_context = CompressionContext(rle_strategy)
    
    rle_info = rle_context.get_compression_info(test_data)
    print("RLE压缩:")
    print(f"  原始大小: {rle_info['original_size']}")
    print(f"  压缩后大小: {rle_info['compressed_size']}")
    print(f"  压缩率: {rle_info['compression_ratio']:.2%}")
    print(f"  压缩数据: {rle_info['compressed_data']}")
    
    decompressed = rle_context.decompress_data(rle_info['compressed_data'])
    print(f"  解压数据: {decompressed}")
    print(f"  解压成功: {decompressed == test_data}")
    
    print("\n" + "-" * 50)
    
    # 字典压缩
    text_data = "hello world hello python world python"
    print(f"文本数据: {text_data}")
    
    dict_strategy = DictionaryCompression()
    dict_context = CompressionContext(dict_strategy)
    
    dict_info = dict_context.get_compression_info(text_data)
    print("\n字典压缩:")
    print(f"  原始大小: {dict_info['original_size']}")
    print(f"  压缩后大小: {dict_info['compressed_size']}")
    print(f"  压缩率: {dict_info['compression_ratio']:.2%}")
    print(f"  压缩数据: {dict_info['compressed_data']}")
    
    decompressed_text = dict_context.decompress_data(dict_info['compressed_data'])
    print(f"  解压数据: {decompressed_text}")
    print(f"  解压成功: {decompressed_text == text_data}")

if __name__ == "__main__":
    demonstrate_strategy_pattern()

7.2 观察者模式

python 复制代码
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from datetime import datetime
import time

class Observer(ABC):
    """观察者抽象基类"""
    
    @abstractmethod
    def update(self, subject: Any, event: str, data: Any):
        """接收更新通知"""
        pass

class Subject(ABC):
    """主题抽象基类"""
    
    def __init__(self):
        self._observers: List[Observer] = []
    
    def attach(self, observer: Observer):
        """附加观察者"""
        if observer not in self._observers:
            self._observers.append(observer)
    
    def detach(self, observer: Observer):
        """分离观察者"""
        if observer in self._observers:
            self._observers.remove(observer)
    
    def notify(self, event: str, data: Any = None):
        """通知所有观察者"""
        for observer in self._observers:
            observer.update(self, event, data)

class StockMarket(Subject):
    """股票市场主题"""
    
    def __init__(self, name: str):
        super().__init__()
        self.name = name
        self._stocks: Dict[str, float] = {}
        self._price_history: Dict[str, List[float]] = {}
    
    def set_stock_price(self, symbol: str, price: float):
        """设置股票价格"""
        old_price = self._stocks.get(symbol)
        self._stocks[symbol] = price
        
        # 记录价格历史
        if symbol not in self._price_history:
            self._price_history[symbol] = []
        self._price_history[symbol].append(price)
        
        # 通知观察者
        event_data = {
            'symbol': symbol,
            'old_price': old_price,
            'new_price': price,
            'timestamp': datetime.now(),
            'change': price - old_price if old_price else 0
        }
        
        if old_price is not None:
            if price > old_price:
                self.notify('price_increase', event_data)
            elif price < old_price:
                self.notify('price_decrease', event_data)
        else:
            self.notify('price_initial', event_data)
    
    def get_stock_price(self, symbol: str) -> float:
        """获取股票价格"""
        return self._stocks.get(symbol, 0.0)
    
    def get_price_history(self, symbol: str) -> List[float]:
        """获取价格历史"""
        return self._price_history.get(symbol, [])

class StockTrader(Observer):
    """股票交易者观察者"""
    
    def __init__(self, name: str, budget: float):
        self.name = name
        self.budget = budget
        self.portfolio: Dict[str, int] = {}
        self.transaction_history: List[Dict] = []
    
    def update(self, subject: StockMarket, event: str, data: Any):
        """接收市场更新"""
        if event == 'price_decrease':
            self._consider_buying(subject, data)
        elif event == 'price_increase':
            self._consider_selling(subject, data)
    
    def _consider_buying(self, market: StockMarket, data: Dict):
        """考虑买入"""
        symbol = data['symbol']
        price = data['new_price']
        
        # 简单的买入策略:价格下降超过10%且预算足够
        if data['change'] < -price * 0.1 and self.budget >= price:
            shares_to_buy = min(10, int(self.budget // price))
            cost = shares_to_buy * price
            
            # 执行买入
            self.portfolio[symbol] = self.portfolio.get(symbol, 0) + shares_to_buy
            self.budget -= cost
            
            transaction = {
                'type': 'BUY',
                'symbol': symbol,
                'shares': shares_to_buy,
                'price': price,
                'total': cost,
                'timestamp': datetime.now()
            }
            self.transaction_history.append(transaction)
            
            print(f"{self.name} 买入 {shares_to_buy} 股 {symbol} @ ${price:.2f}")
    
    def _consider_selling(self, market: StockMarket, data: Dict):
        """考虑卖出"""
        symbol = data['symbol']
        price = data['new_price']
        
        # 简单的卖出策略:价格上涨超过15%且持有该股票
        if symbol in self.portfolio and data['change'] > price * 0.15:
            shares_owned = self.portfolio[symbol]
            shares_to_sell = min(5, shares_owned)  # 每次最多卖出5股
            revenue = shares_to_sell * price
            
            # 执行卖出
            self.portfolio[symbol] -= shares_to_sell
            if self.portfolio[symbol] == 0:
                del self.portfolio[symbol]
            self.budget += revenue
            
            transaction = {
                'type': 'SELL',
                'symbol': symbol,
                'shares': shares_to_sell,
                'price': price,
                'total': revenue,
                'timestamp': datetime.now()
            }
            self.transaction_history.append(transaction)
            
            print(f"{self.name} 卖出 {shares_to_sell} 股 {symbol} @ ${price:.2f}")
    
    def get_portfolio_value(self, market: StockMarket) -> float:
        """获取投资组合价值"""
        stock_value = sum(
            shares * market.get_stock_price(symbol)
            for symbol, shares in self.portfolio.items()
        )
        return stock_value + self.budget
    
    def print_status(self, market: StockMarket):
        """打印状态"""
        total_value = self.get_portfolio_value(market)
        print(f"\n{self.name} 状态:")
        print(f"  现金: ${self.budget:.2f}")
        print(f"  投资组合:")
        for symbol, shares in self.portfolio.items():
            price = market.get_stock_price(symbol)
            value = shares * price
            print(f"    {symbol}: {shares} 股 @ ${price:.2f} = ${value:.2f}")
        print(f"  总资产: ${total_value:.2f}")

class MarketAnalyst(Observer):
    """市场分析师观察者"""
    
    def __init__(self, name: str):
        self.name = name
        self.analysis_report: List[str] = []
    
    def update(self, subject: StockMarket, event: str, data: Any):
        """分析市场变化"""
        symbol = data['symbol']
        timestamp = data['timestamp'].strftime("%Y-%m-%d %H:%M:%S")
        
        if event == 'price_increase':
            analysis = f"{timestamp} - {symbol} 上涨 ${data['change']:.2f} (新价格: ${data['new_price']:.2f})"
        elif event == 'price_decrease':
            analysis = f"{timestamp} - {symbol} 下跌 ${abs(data['change']):.2f} (新价格: ${data['new_price']:.2f})"
        else:
            analysis = f"{timestamp} - {symbol} 初始价格: ${data['new_price']:.2f}"
        
        self.analysis_report.append(analysis)
        print(f"分析师 {self.name}: {analysis}")
    
    def print_report(self):
        """打印分析报告"""
        print(f"\n{self.name} 的分析报告:")
        for i, analysis in enumerate(self.analysis_report[-5:], 1):  # 只显示最后5条
            print(f"  {i}. {analysis}")

def demonstrate_observer_pattern():
    """演示观察者模式"""
    
    # 创建股票市场
    market = StockMarket("NASDAQ")
    
    # 创建交易者和分析师
    trader1 = StockTrader("Alice", 10000.0)
    trader2 = StockTrader("Bob", 8000.0)
    analyst = MarketAnalyst("Dr. Smith")
    
    # 注册观察者
    market.attach(trader1)
    market.attach(trader2)
    market.attach(analyst)
    
    print("股票市场模拟开始...")
    print("=" * 50)
    
    # 模拟股票价格变化
    stocks = ['AAPL', 'GOOGL', 'MSFT', 'TSLA']
    prices = {
        'AAPL': 150.0,
        'GOOGL': 2800.0,
        'MSFT': 300.0,
        'TSLA': 700.0
    }
    
    import random
    
    for day in range(1, 6):
        print(f"\n第 {day} 天:")
        print("-" * 30)
        
        for symbol in stocks:
            # 随机价格变化 (-5% 到 +5%)
            change_percent = random.uniform(-0.05, 0.05)
            old_price = prices[symbol]
            new_price = old_price * (1 + change_percent)
            prices[symbol] = new_price
            
            market.set_stock_price(symbol, new_price)
            time.sleep(0.1)  # 短暂暂停以便观察
    
    # 显示最终状态
    print("\n" + "=" * 50)
    print("模拟结束")
    print("=" * 50)
    
    trader1.print_status(market)
    trader2.print_status(market)
    analyst.print_report()

if __name__ == "__main__":
    demonstrate_observer_pattern()

8. 完整代码示例:电子商务系统

下面是一个完整的电子商务系统示例,展示抽象基类在实际项目中的应用。

python 复制代码
"""
电子商务系统完整示例
演示抽象基类在实际项目中的应用
"""

from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union
from datetime import datetime
import uuid
from dataclasses import dataclass
from enum import Enum

class OrderStatus(Enum):
    PENDING = "pending"
    CONFIRMED = "confirmed"
    SHIPPED = "shipped"
    DELIVERED = "delivered"
    CANCELLED = "cancelled"

class PaymentStatus(Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    REFUNDED = "refunded"

@dataclass
class Product:
    id: str
    name: str
    price: float
    description: str
    stock_quantity: int
    category: str

@dataclass
class Customer:
    id: str
    name: str
    email: str
    address: str
    phone: str

@dataclass
class OrderItem:
    product_id: str
    product_name: str
    quantity: int
    unit_price: float
    
    @property
    def total_price(self) -> float:
        return self.quantity * self.unit_price

class Repository(ABC):
    """泛型仓库抽象基类"""
    
    @abstractmethod
    def get_by_id(self, id: str):
        pass
    
    @abstractmethod
    def get_all(self) -> List:
        pass
    
    @abstractmethod
    def add(self, entity) -> bool:
        pass
    
    @abstractmethod
    def update(self, entity) -> bool:
        pass
    
    @abstractmethod
    def delete(self, id: str) -> bool:
        pass

class ProductRepository(Repository):
    """产品仓库"""
    
    def __init__(self):
        self._products: Dict[str, Product] = {}
    
    def get_by_id(self, id: str) -> Optional[Product]:
        return self._products.get(id)
    
    def get_all(self) -> List[Product]:
        return list(self._products.values())
    
    def add(self, product: Product) -> bool:
        if product.id in self._products:
            return False
        self._products[product.id] = product
        return True
    
    def update(self, product: Product) -> bool:
        if product.id not in self._products:
            return False
        self._products[product.id] = product
        return True
    
    def delete(self, id: str) -> bool:
        if id not in self._products:
            return False
        del self._products[id]
        return True
    
    def get_by_category(self, category: str) -> List[Product]:
        return [p for p in self._products.values() if p.category == category]
    
    def update_stock(self, product_id: str, quantity: int) -> bool:
        product = self.get_by_id(product_id)
        if product and product.stock_quantity >= quantity:
            product.stock_quantity -= quantity
            return self.update(product)
        return False

class CustomerRepository(Repository):
    """客户仓库"""
    
    def __init__(self):
        self._customers: Dict[str, Customer] = {}
    
    def get_by_id(self, id: str) -> Optional[Customer]:
        return self._customers.get(id)
    
    def get_all(self) -> List[Customer]:
        return list(self._customers.values())
    
    def add(self, customer: Customer) -> bool:
        if customer.id in self._customers:
            return False
        self._customers[customer.id] = customer
        return True
    
    def update(self, customer: Customer) -> bool:
        if customer.id not in self._customers:
            return False
        self._customers[customer.id] = customer
        return True
    
    def delete(self, id: str) -> bool:
        if id not in self._customers:
            return False
        del self._customers[id]
        return True

class Order:
    """订单类"""
    
    def __init__(self, customer_id: str):
        self.id = str(uuid.uuid4())
        self.customer_id = customer_id
        self.items: List[OrderItem] = []
        self.status = OrderStatus.PENDING
        self.payment_status = PaymentStatus.PENDING
        self.created_at = datetime.now()
        self.updated_at = datetime.now()
        self.shipping_address = ""
        self.total_amount = 0.0
    
    def add_item(self, product: Product, quantity: int) -> bool:
        """添加订单项"""
        if quantity <= 0 or product.stock_quantity < quantity:
            return False
        
        # 检查是否已存在该商品
        for item in self.items:
            if item.product_id == product.id:
                return False
        
        order_item = OrderItem(
            product_id=product.id,
            product_name=product.name,
            quantity=quantity,
            unit_price=product.price
        )
        
        self.items.append(order_item)
        self._update_totals()
        return True
    
    def remove_item(self, product_id: str) -> bool:
        """移除订单项"""
        for i, item in enumerate(self.items):
            if item.product_id == product_id:
                self.items.pop(i)
                self._update_totals()
                return True
        return False
    
    def update_quantity(self, product_id: str, quantity: int) -> bool:
        """更新商品数量"""
        for item in self.items:
            if item.product_id == product_id:
                if quantity <= 0:
                    return self.remove_item(product_id)
                item.quantity = quantity
                self._update_totals()
                return True
        return False
    
    def _update_totals(self):
        """更新总金额"""
        self.total_amount = sum(item.total_price for item in self.items)
        self.updated_at = datetime.now()
    
    def get_order_summary(self) -> Dict:
        """获取订单摘要"""
        return {
            'order_id': self.id,
            'customer_id': self.customer_id,
            'total_amount': self.total_amount,
            'item_count': len(self.items),
            'status': self.status.value,
            'payment_status': self.payment_status.value,
            'created_at': self.created_at.isoformat()
        }

class OrderRepository(Repository):
    """订单仓库"""
    
    def __init__(self):
        self._orders: Dict[str, Order] = {}
    
    def get_by_id(self, id: str) -> Optional[Order]:
        return self._orders.get(id)
    
    def get_all(self) -> List[Order]:
        return list(self._orders.values())
    
    def add(self, order: Order) -> bool:
        if order.id in self._orders:
            return False
        self._orders[order.id] = order
        return True
    
    def update(self, order: Order) -> bool:
        if order.id not in self._orders:
            return False
        self._orders[order.id] = order
        return True
    
    def delete(self, id: str) -> bool:
        if id not in self._orders:
            return False
        del self._orders[id]
        return True
    
    def get_customer_orders(self, customer_id: str) -> List[Order]:
        return [order for order in self._orders.values() 
                if order.customer_id == customer_id]
    
    def get_orders_by_status(self, status: OrderStatus) -> List[Order]:
        return [order for order in self._orders.values() 
                if order.status == status]

class PaymentProcessor(ABC):
    """支付处理器抽象基类"""
    
    @abstractmethod
    def process_payment(self, order: Order, payment_details: Dict) -> bool:
        """处理支付"""
        pass
    
    @abstractmethod
    def refund_payment(self, order: Order, amount: float) -> bool:
        """退款"""
        pass
    
    @abstractmethod
    def get_payment_status(self, order: Order) -> PaymentStatus:
        """获取支付状态"""
        pass

class CreditCardProcessor(PaymentProcessor):
    """信用卡支付处理器"""
    
    def process_payment(self, order: Order, payment_details: Dict) -> bool:
        # 模拟信用卡支付处理
        card_number = payment_details.get('card_number')
        expiry_date = payment_details.get('expiry_date')
        cvv = payment_details.get('cvv')
        
        if not all([card_number, expiry_date, cvv]):
            return False
        
        # 模拟支付处理延迟
        import time
        time.sleep(0.1)
        
        # 模拟支付成功(在实际应用中这里会有真实的支付逻辑)
        order.payment_status = PaymentStatus.COMPLETED
        return True
    
    def refund_payment(self, order: Order, amount: float) -> bool:
        if order.payment_status != PaymentStatus.COMPLETED:
            return False
        
        # 模拟退款处理
        order.payment_status = PaymentStatus.REFUNDED
        return True
    
    def get_payment_status(self, order: Order) -> PaymentStatus:
        return order.payment_status

class NotificationService(ABC):
    """通知服务抽象基类"""
    
    @abstractmethod
    def send_notification(self, recipient: str, message: str, subject: str = "") -> bool:
        """发送通知"""
        pass

class EmailNotificationService(NotificationService):
    """邮件通知服务"""
    
    def send_notification(self, recipient: str, message: str, subject: str = "") -> bool:
        print(f"发送邮件到: {recipient}")
        print(f"主题: {subject}")
        print(f"内容: {message}")
        print("-" * 50)
        return True

class OrderService:
    """订单服务"""
    
    def __init__(
        self,
        order_repository: OrderRepository,
        product_repository: ProductRepository,
        customer_repository: CustomerRepository,
        payment_processor: PaymentProcessor,
        notification_service: NotificationService
    ):
        self.order_repository = order_repository
        self.product_repository = product_repository
        self.customer_repository = customer_repository
        self.payment_processor = payment_processor
        self.notification_service = notification_service
    
    def create_order(self, customer_id: str) -> Optional[Order]:
        """创建新订单"""
        customer = self.customer_repository.get_by_id(customer_id)
        if not customer:
            return None
        
        order = Order(customer_id)
        if self.order_repository.add(order):
            return order
        return None
    
    def add_product_to_order(self, order_id: str, product_id: str, quantity: int) -> bool:
        """添加商品到订单"""
        order = self.order_repository.get_by_id(order_id)
        product = self.product_repository.get_by_id(product_id)
        
        if not order or not product:
            return False
        
        if order.status != OrderStatus.PENDING:
            return False
        
        success = order.add_item(product, quantity)
        if success:
            self.order_repository.update(order)
        
        return success
    
    def place_order(self, order_id: str, payment_details: Dict, shipping_address: str) -> bool:
        """下单"""
        order = self.order_repository.get_by_id(order_id)
        if not order or order.status != OrderStatus.PENDING:
            return False
        
        # 设置配送地址
        order.shipping_address = shipping_address
        
        # 处理支付
        if not self.payment_processor.process_payment(order, payment_details):
            order.payment_status = PaymentStatus.FAILED
            self.order_repository.update(order)
            return False
        
        # 更新库存
        for item in order.items:
            if not self.product_repository.update_stock(item.product_id, item.quantity):
                # 库存不足,回滚
                self.payment_processor.refund_payment(order, order.total_amount)
                return False
        
        # 更新订单状态
        order.status = OrderStatus.CONFIRMED
        self.order_repository.update(order)
        
        # 发送确认邮件
        customer = self.customer_repository.get_by_id(order.customer_id)
        if customer:
            message = f"您的订单 #{order.id} 已确认。总金额: ${order.total_amount:.2f}"
            self.notification_service.send_notification(
                customer.email,
                message,
                "订单确认"
            )
        
        return True
    
    def get_order_status(self, order_id: str) -> Optional[Dict]:
        """获取订单状态"""
        order = self.order_repository.get_by_id(order_id)
        if not order:
            return None
        
        return order.get_order_summary()

class ECommerceSystem:
    """电子商务系统"""
    
    def __init__(self):
        # 初始化所有组件
        self.product_repository = ProductRepository()
        self.customer_repository = CustomerRepository()
        self.order_repository = OrderRepository()
        self.payment_processor = CreditCardProcessor()
        self.notification_service = EmailNotificationService()
        
        self.order_service = OrderService(
            self.order_repository,
            self.product_repository,
            self.customer_repository,
            self.payment_processor,
            self.notification_service
        )
        
        self._initialize_sample_data()
    
    def _initialize_sample_data(self):
        """初始化示例数据"""
        # 添加示例产品
        products = [
            Product("1", "笔记本电脑", 999.99, "高性能笔记本电脑", 10, "电子产品"),
            Product("2", "智能手机", 699.99, "最新款智能手机", 20, "电子产品"),
            Product("3", "书籍", 29.99, "编程书籍", 50, "图书"),
            Product("4", "耳机", 199.99, "无线降噪耳机", 15, "电子产品"),
        ]
        
        for product in products:
            self.product_repository.add(product)
        
        # 添加示例客户
        customers = [
            Customer("1", "张三", "zhangsan@example.com", "北京市朝阳区", "13800138000"),
            Customer("2", "李四", "lisi@example.com", "上海市浦东新区", "13900139000"),
        ]
        
        for customer in customers:
            self.customer_repository.add(customer)
    
    def run_demo(self):
        """运行演示"""
        print("电子商务系统演示")
        print("=" * 50)
        
        # 显示可用产品
        print("\n可用产品:")
        for product in self.product_repository.get_all():
            print(f"  {product.id}. {product.name} - ${product.price:.2f} (库存: {product.stock_quantity})")
        
        # 创建订单
        print("\n1. 创建订单...")
        order = self.order_service.create_order("1")
        if not order:
            print("创建订单失败")
            return
        
        print(f"订单创建成功: {order.id}")
        
        # 添加商品到订单
        print("\n2. 添加商品到订单...")
        self.order_service.add_product_to_order(order.id, "1", 1)  # 笔记本电脑
        self.order_service.add_product_to_order(order.id, "3", 2)  # 书籍
        
        # 显示订单摘要
        order_summary = self.order_service.get_order_status(order.id)
        print(f"订单摘要: {order_summary}")
        
        # 下单
        print("\n3. 下单...")
        payment_details = {
            'card_number': '4111111111111111',
            'expiry_date': '12/25',
            'cvv': '123'
        }
        
        shipping_address = "北京市朝阳区某某街道123号"
        
        if self.order_service.place_order(order.id, payment_details, shipping_address):
            print("下单成功!")
            
            # 显示最终订单状态
            final_summary = self.order_service.get_order_status(order.id)
            print(f"最终订单状态: {final_summary}")
        else:
            print("下单失败")
        
        # 显示库存变化
        print("\n4. 库存状态:")
        for product in self.product_repository.get_all():
            print(f"  {product.name}: {product.stock_quantity} 件")

def main():
    """主函数"""
    ecommerce_system = ECommerceSystem()
    ecommerce_system.run_demo()

if __name__ == "__main__":
    main()

9. 最佳实践与注意事项

9.1 抽象基类设计原则

  1. 单一职责原则:每个抽象基类应该只有一个明确的职责
  2. 接口隔离原则:不要强迫客户端依赖它们不需要的方法
  3. 里氏替换原则:子类应该能够替换它们的父类
  4. 依赖倒置原则:依赖于抽象而不是具体实现

9.2 常见陷阱与解决方案

python 复制代码
from abc import ABC, abstractmethod

class CommonMistakesDemo:
    """常见陷阱演示"""
    
    class BadAbstractClass(ABC):
        """不好的抽象基类设计"""
        
        @abstractmethod
        def method1(self):
            pass
        
        @abstractmethod
        def method2(self):
            pass
        
        def too_many_concrete_methods(self):
            """过多的具体方法"""
            # 这违反了接口隔离原则
            pass
        
        def implementation_details(self):
            """包含实现细节"""
            # 抽象基类应该关注接口,而不是实现
            pass
    
    class GoodAbstractClass(ABC):
        """良好的抽象基类设计"""
        
        @abstractmethod
        def essential_operation(self):
            """必要的操作"""
            pass
        
        def optional_operation(self):
            """可选的操作,提供默认实现"""
            raise NotImplementedError("子类可以选择实现此方法")
        
        def template_method(self):
            """模板方法模式"""
            self.essential_operation()
            self._hook_method()
        
        def _hook_method(self):
            """钩子方法,子类可以重写"""
            pass

def demonstrate_best_practices():
    """演示最佳实践"""
    
    print("抽象基类最佳实践:")
    print("1. 保持抽象基类简洁,只定义必要的抽象方法")
    print("2. 使用模板方法模式提供算法骨架")
    print("3. 为可选操作提供默认实现或抛出NotImplementedError")
    print("4. 使用钩子方法允许子类扩展行为")
    print("5. 避免在抽象基类中包含具体实现细节")

if __name__ == "__main__":
    demonstrate_best_practices()

10. 总结

Python的抽象基类(ABC)是定义接口契约的强大工具,它提供了以下关键优势:

  1. 明确的接口定义:通过抽象方法强制子类实现特定接口
  2. 类型检查支持:增强代码的可靠性和可维护性
  3. 设计模式实现:支持策略模式、观察者模式等经典设计模式
  4. 代码组织:促进清晰的代码结构和架构设计
  5. 多态性:实现统一的接口,不同的行为

10.1 关键要点

  • 使用@abstractmethod定义抽象方法
  • 抽象基类不能被实例化
  • 子类必须实现所有抽象方法
  • 可以使用register方法创建虚拟子类
  • __subclasshook__方法允许自定义子类检查逻辑

10.2 适用场景

  • 定义框架或库的扩展点
  • 实现插件系统
  • 创建清晰的API契约
  • 实现设计模式
  • 大型项目的架构设计

通过合理使用抽象基类,您可以创建更加健壮、可维护和可扩展的Python应用程序。抽象基类不仅是一种技术工具,更是一种设计哲学,它鼓励开发者思考接口设计和组件之间的关系。

11. 代码自查

在完成本文的所有代码示例后,我们进行了以下自查以确保代码质量:

  1. 语法正确性:所有代码都通过Python语法检查
  2. 抽象基类规范 :确保所有抽象方法都正确使用@abstractmethod装饰器
  3. 类型提示:在适当的地方使用了类型提示
  4. 异常处理:关键操作都有适当的异常处理
  5. 代码注释:所有复杂逻辑都有清晰的注释说明
  6. 设计模式应用:正确实现了各种设计模式
  7. 实际应用场景:代码示例基于真实的应用场景

这些代码示例可以直接运行,并且包含了适当的设计模式和最佳实践,可以作为学习和实际项目参考。

相关推荐
qq_172805591 小时前
Go 语言结构型设计模式深度解析
开发语言·设计模式·golang
vx_dmxq2111 小时前
【微信小程序学习交流平台】(免费领源码+演示录像)|可做计算机毕设Java、Python、PHP、小程序APP、C#、爬虫大数据、单片机、文案
java·spring boot·python·mysql·微信小程序·小程序·idea
无垠的广袤2 小时前
【工业树莓派 CM0 NANO 单板计算机】本地部署 EMQX
linux·python·嵌入式硬件·物联网·树莓派·emqx·工业物联网
lkbhua莱克瓦242 小时前
集合进阶8——Stream流
java·开发语言·笔记·github·stream流·学习方法·集合
20岁30年经验的码农2 小时前
Java Elasticsearch 实战指南
java·开发语言·elasticsearch
雾岛听蓝2 小时前
C++ 类和对象(一):从概念到实践,吃透类的核心基础
开发语言·c++·经验分享·笔记
CoderYanger2 小时前
优选算法-优先级队列(堆):75.数据流中的第K大元素
java·开发语言·算法·leetcode·职场和发展·1024程序员节
TracyCoder1233 小时前
MySQL 实战宝典(八):Java后端MySQL分库分表工具解析与选型秘籍
java·开发语言·mysql
非凡的世界3 小时前
为什么我和越来越多的PHP程序员,选择了 Webman ?
开发语言·php·workman·webman