目录
- Python的抽象基类(ABC):定义接口契约的艺术
-
- [1. 引言](#1. 引言)
-
- [1.1 什么是抽象基类?](#1.1 什么是抽象基类?)
- [1.2 为什么需要抽象基类?](#1.2 为什么需要抽象基类?)
- [2. 抽象基类的基础](#2. 抽象基类的基础)
-
- [2.1 创建抽象基类](#2.1 创建抽象基类)
- [2.2 实现抽象基类](#2.2 实现抽象基类)
- [2.3 抽象属性的使用](#2.3 抽象属性的使用)
- [3. 高级抽象基类特性](#3. 高级抽象基类特性)
-
- [3.1 多重继承与抽象基类](#3.1 多重继承与抽象基类)
- [3.2 抽象类方法与静态方法](#3.2 抽象类方法与静态方法)
- [4. 注册机制与虚拟子类](#4. 注册机制与虚拟子类)
-
- [4.1 注册虚拟子类](#4.1 注册虚拟子类)
- [4.2 `subclasshook`方法](#4.2
__subclasshook__方法)
- [5. 标准库中的抽象基类](#5. 标准库中的抽象基类)
-
- [5.1 集合抽象基类](#5.1 集合抽象基类)
- [5.2 上下文管理器抽象基类](#5.2 上下文管理器抽象基类)
- [6. 类型检查与抽象基类](#6. 类型检查与抽象基类)
-
- [6.1 类型提示与抽象基类](#6.1 类型提示与抽象基类)
- [7. 设计模式与抽象基类](#7. 设计模式与抽象基类)
-
- [7.1 策略模式](#7.1 策略模式)
- [7.2 观察者模式](#7.2 观察者模式)
- [8. 完整代码示例:电子商务系统](#8. 完整代码示例:电子商务系统)
- [9. 最佳实践与注意事项](#9. 最佳实践与注意事项)
-
- [9.1 抽象基类设计原则](#9.1 抽象基类设计原则)
- [9.2 常见陷阱与解决方案](#9.2 常见陷阱与解决方案)
- [10. 总结](#10. 总结)
-
- [10.1 关键要点](#10.1 关键要点)
- [10.2 适用场景](#10.2 适用场景)
- [11. 代码自查](#11. 代码自查)
『宝藏代码胶囊开张啦!』------ 我的 CodeCapsule 来咯!✨写代码不再头疼!我的新站点 CodeCapsule 主打一个 "白菜价"+"量身定制 "!无论是卡脖子的毕设/课设/文献复现 ,需要灵光一现的算法改进 ,还是想给项目加个"外挂",这里都有便宜又好用的代码方案等你发现!低成本,高适配,助你轻松通关!速来围观 👉 CodeCapsule官网
Python的抽象基类(ABC):定义接口契约的艺术
1. 引言
在面向对象编程中,抽象基类(Abstract Base Classes,简称ABC)是一种强大的工具,用于定义接口契约和建立类之间的规范。Python通过abc模块提供了对抽象基类的原生支持,这使得开发者能够创建更加健壮、可维护的代码。
1.1 什么是抽象基类?
抽象基类是不能被实例化的类,其主要目的是为子类定义接口契约。它规定了子类必须实现的方法和属性,从而确保了一致性和可预测性。
1.2 为什么需要抽象基类?
- 强制接口实现:确保子类实现了必要的方法
- 提供明确的契约:定义清晰的API规范
- 增强代码可读性:明确表达设计意图
- 便于类型检查:支持静态类型检查和IDE提示
- 多态性支持:统一的接口,不同的实现
python
from abc import ABC, abstractmethod
class Shape(ABC):
"""形状抽象基类"""
@abstractmethod
def area(self):
"""计算面积"""
pass
@abstractmethod
def perimeter(self):
"""计算周长"""
pass
def description(self):
"""非抽象方法,可以有默认实现"""
return "这是一个形状"
2. 抽象基类的基础
2.1 创建抽象基类
Python中使用ABC元类和abstractmethod装饰器来创建抽象基类。
python
from abc import ABC, abstractmethod
import math
class Vehicle(ABC):
"""交通工具抽象基类"""
def __init__(self, name, max_speed):
self.name = name
self.max_speed = max_speed
self._current_speed = 0
@abstractmethod
def start_engine(self):
"""启动引擎"""
pass
@abstractmethod
def stop_engine(self):
"""停止引擎"""
pass
@abstractmethod
def accelerate(self, increment):
"""加速"""
pass
def get_current_speed(self):
"""获取当前速度 - 具体方法"""
return self._current_speed
def __str__(self):
return f"{self.name} (最大速度: {self.max_speed} km/h)"
2.2 实现抽象基类
子类必须实现所有抽象方法,否则在实例化时会抛出TypeError。
python
class Car(Vehicle):
"""汽车类 - 实现Vehicle抽象基类"""
def __init__(self, name, max_speed, fuel_type):
super().__init__(name, max_speed)
self.fuel_type = fuel_type
self._engine_started = False
def start_engine(self):
if not self._engine_started:
self._engine_started = True
return f"{self.name}引擎已启动"
return f"{self.name}引擎已经在运行"
def stop_engine(self):
if self._engine_started:
self._engine_started = False
self._current_speed = 0
return f"{self.name}引擎已停止"
return f"{self.name}引擎已经停止"
def accelerate(self, increment):
if not self._engine_started:
return "请先启动引擎"
new_speed = self._current_speed + increment
if new_speed <= self.max_speed:
self._current_speed = new_speed
return f"加速到 {self._current_speed} km/h"
else:
self._current_speed = self.max_speed
return f"已达到最大速度 {self.max_speed} km/h"
class Bicycle(Vehicle):
"""自行车类 - 实现Vehicle抽象基类"""
def __init__(self, name, max_speed, gear_count):
super().__init__(name, max_speed)
self.gear_count = gear_count
self._current_gear = 1
def start_engine(self):
return "自行车没有引擎"
def stop_engine(self):
return "自行车没有引擎"
def accelerate(self, increment):
new_speed = self._current_speed + increment
if new_speed <= self.max_speed:
self._current_speed = new_speed
return f"蹬车加速到 {self._current_speed} km/h"
else:
self._current_speed = self.max_speed
return f"已达到最大速度 {self.max_speed} km/h"
def change_gear(self, new_gear):
if 1 <= new_gear <= self.gear_count:
self._current_gear = new_gear
return f"切换到 {new_gear} 档"
return "无效的档位"
def demonstrate_vehicle_abc():
"""演示抽象基类的使用"""
# 创建汽车实例
car = Car("丰田卡罗拉", 180, "汽油")
print(car)
print(car.start_engine())
print(car.accelerate(50))
print(car.accelerate(100))
print(car.stop_engine())
print("\n" + "="*50 + "\n")
# 创建自行车实例
bike = Bicycle("山地自行车", 40, 21)
print(bike)
print(bike.start_engine())
print(bike.accelerate(15))
print(bike.change_gear(3))
print(bike.accelerate(10))
if __name__ == "__main__":
demonstrate_vehicle_abc()
2.3 抽象属性的使用
除了抽象方法,Python还支持抽象属性,使用abstractproperty或结合property与abstractmethod。
python
from abc import ABC, abstractmethod
class DatabaseConnection(ABC):
"""数据库连接抽象基类"""
@property
@abstractmethod
def is_connected(self):
"""连接状态"""
pass
@property
@abstractmethod
def connection_string(self):
"""连接字符串"""
pass
@abstractmethod
def connect(self):
"""建立连接"""
pass
@abstractmethod
def disconnect(self):
"""断开连接"""
pass
@abstractmethod
def execute_query(self, query):
"""执行查询"""
pass
class MySQLConnection(DatabaseConnection):
"""MySQL数据库连接实现"""
def __init__(self, host, user, password, database):
self._host = host
self._user = user
self._password = password
self._database = database
self._connected = False
@property
def is_connected(self):
return self._connected
@property
def connection_string(self):
return f"mysql://{self._user}@{self._host}/{self._database}"
def connect(self):
if not self._connected:
# 模拟连接建立
self._connected = True
return f"已连接到MySQL数据库: {self.connection_string}"
return "已经连接到数据库"
def disconnect(self):
if self._connected:
self._connected = False
return "MySQL连接已断开"
return "未连接到数据库"
def execute_query(self, query):
if not self._connected:
raise ConnectionError("请先连接到数据库")
return f"执行MySQL查询: {query}"
def demonstrate_abstract_properties():
"""演示抽象属性的使用"""
mysql_conn = MySQLConnection("localhost", "admin", "password", "myapp")
print(f"连接字符串: {mysql_conn.connection_string}")
print(f"连接状态: {mysql_conn.is_connected}")
print(mysql_conn.connect())
print(f"连接状态: {mysql_conn.is_connected}")
print(mysql_conn.execute_query("SELECT * FROM users"))
print(mysql_conn.disconnect())
if __name__ == "__main__":
demonstrate_abstract_properties()
3. 高级抽象基类特性
3.1 多重继承与抽象基类
抽象基类支持多重继承,可以创建复杂的接口层次结构。
python
from abc import ABC, abstractmethod
from typing import List
class Readable(ABC):
"""可读接口"""
@abstractmethod
def read(self) -> str:
"""读取内容"""
pass
@abstractmethod
def get_size(self) -> int:
"""获取大小"""
pass
class Writable(ABC):
"""可写接口"""
@abstractmethod
def write(self, content: str) -> bool:
"""写入内容"""
pass
@abstractmethod
def append(self, content: str) -> bool:
"""追加内容"""
pass
class ReadWriteFile(Readable, Writable):
"""读写文件实现"""
def __init__(self, filename: str):
self.filename = filename
self._content = ""
def read(self) -> str:
return self._content
def get_size(self) -> int:
return len(self._content)
def write(self, content: str) -> bool:
self._content = content
return True
def append(self, content: str) -> bool:
self._content += content
return True
def clear(self):
"""额外的方法"""
self._content = ""
class NetworkStream(Readable, Writable):
"""网络流实现"""
def __init__(self, buffer_size: int = 1024):
self.buffer: List[str] = []
self.buffer_size = buffer_size
def read(self) -> str:
if self.buffer:
return self.buffer.pop(0)
return ""
def get_size(self) -> int:
return sum(len(item) for item in self.buffer)
def write(self, content: str) -> bool:
if len(content) <= self.buffer_size:
self.buffer.append(content)
return True
return False
def append(self, content: str) -> bool:
return self.write(content) # 对于网络流,write和append行为相同
def demonstrate_multiple_inheritance():
"""演示多重继承的抽象基类"""
# 文件操作
file_processor = ReadWriteFile("example.txt")
file_processor.write("Hello, World!")
file_processor.append(" This is additional content.")
print("文件内容:", file_processor.read())
print("文件大小:", file_processor.get_size())
print("\n" + "="*50 + "\n")
# 网络流操作
stream = NetworkStream()
stream.write("数据包1")
stream.write("数据包2")
print("流内容:", stream.read())
print("流大小:", stream.get_size())
# 类型检查
print(f"\nfile_processor是Readable: {isinstance(file_processor, Readable)}")
print(f"file_processor是Writable: {isinstance(file_processor, Writable)}")
print(f"stream是Readable: {isinstance(stream, Readable)}")
print(f"stream是Writable: {isinstance(stream, Writable)}")
if __name__ == "__main__":
demonstrate_multiple_inheritance()
3.2 抽象类方法与静态方法
抽象基类也支持抽象类方法和抽象静态方法。
python
from abc import ABC, abstractmethod, abstractclassmethod, abstractstaticmethod
from datetime import datetime
import json
class Serializer(ABC):
"""序列化器抽象基类"""
@abstractclassmethod
def get_format_name(cls) -> str:
"""获取格式名称"""
pass
@abstractstaticmethod
def is_valid_content(content: str) -> bool:
"""验证内容格式"""
pass
@abstractmethod
def serialize(self, data: dict) -> str:
"""序列化数据"""
pass
@abstractmethod
def deserialize(self, content: str) -> dict:
"""反序列化内容"""
pass
@classmethod
def get_serializer_info(cls):
"""获取序列化器信息"""
return {
"format": cls.get_format_name(),
"timestamp": datetime.now().isoformat()
}
class JSONSerializer(Serializer):
"""JSON序列化器"""
@classmethod
def get_format_name(cls) -> str:
return "JSON"
@staticmethod
def is_valid_content(content: str) -> bool:
try:
json.loads(content)
return True
except (json.JSONDecodeError, TypeError):
return False
def serialize(self, data: dict) -> str:
return json.dumps(data, indent=2, ensure_ascii=False)
def deserialize(self, content: str) -> dict:
if not self.is_valid_content(content):
raise ValueError("无效的JSON内容")
return json.loads(content)
class XMLSerializer(Serializer):
"""XML序列化器(简化版)"""
@classmethod
def get_format_name(cls) -> str:
return "XML"
@staticmethod
def is_valid_content(content: str) -> bool:
# 简化的XML验证
return content.strip().startswith('<') and content.strip().endswith('>')
def serialize(self, data: dict) -> str:
xml_parts = ['<root>']
for key, value in data.items():
xml_parts.append(f' <{key}>{value}</{key}>')
xml_parts.append('</root>')
return '\n'.join(xml_parts)
def deserialize(self, content: str) -> dict:
if not self.is_valid_content(content):
raise ValueError("无效的XML内容")
# 简化的XML解析
lines = content.strip().split('\n')
result = {}
for line in lines:
line = line.strip()
if line.startswith('<') and not line.startswith('</') and not line.startswith('<root'):
key = line.split('>')[0][1:]
value = line.split('>')[1].split('<')[0]
result[key] = value
return result
def demonstrate_abstract_class_static_methods():
"""演示抽象类方法和静态方法"""
# 测试JSON序列化器
json_serializer = JSONSerializer()
print(f"序列化器格式: {JSONSerializer.get_format_name()}")
print(f"序列化器信息: {JSONSerializer.get_serializer_info()}")
test_data = {"name": "Alice", "age": 30, "city": "Beijing"}
json_content = json_serializer.serialize(test_data)
print(f"\nJSON序列化结果:\n{json_content}")
print(f"内容是否有效: {JSONSerializer.is_valid_content(json_content)}")
deserialized_data = json_serializer.deserialize(json_content)
print(f"反序列化数据: {deserialized_data}")
print("\n" + "="*50 + "\n")
# 测试XML序列化器
xml_serializer = XMLSerializer()
print(f"序列化器格式: {XMLSerializer.get_format_name()}")
xml_content = xml_serializer.serialize(test_data)
print(f"\nXML序列化结果:\n{xml_content}")
print(f"内容是否有效: {XMLSerializer.is_valid_content(xml_content)}")
deserialized_xml = xml_serializer.deserialize(xml_content)
print(f"反序列化数据: {deserialized_xml}")
if __name__ == "__main__":
demonstrate_abstract_class_static_methods()
4. 注册机制与虚拟子类
Python的抽象基类支持注册机制,可以将现有的类注册为抽象基类的虚拟子类。
4.1 注册虚拟子类
python
from abc import ABC, abstractmethod
from collections.abc import Sequence
class DataProcessor(ABC):
"""数据处理器抽象基类"""
@abstractmethod
def process(self, data):
"""处理数据"""
pass
@abstractmethod
def validate(self, data) -> bool:
"""验证数据"""
pass
def get_info(self):
"""获取处理器信息"""
return "通用数据处理器"
# 现有的类,没有继承DataProcessor
class ListProcessor:
"""列表处理器"""
def __init__(self, multiplier=1):
self.multiplier = multiplier
def process(self, data):
if isinstance(data, (list, tuple)):
return [item * self.multiplier for item in data]
raise TypeError("数据必须是列表或元组")
def validate(self, data) -> bool:
return isinstance(data, (list, tuple)) and all(
isinstance(item, (int, float)) for item in data
)
def get_info(self):
return f"列表处理器 (乘数: {self.multiplier})"
# 注册为虚拟子类
DataProcessor.register(ListProcessor)
def demonstrate_virtual_subclasses():
"""演示虚拟子类"""
# 创建列表处理器实例
processor = ListProcessor(multiplier=2)
# 检查类型
print(f"processor是DataProcessor: {isinstance(processor, DataProcessor)}")
print(f"ListProcessor是DataProcessor子类: {issubclass(ListProcessor, DataProcessor)}")
# 使用处理器
test_data = [1, 2, 3, 4, 5]
print(f"验证数据: {processor.validate(test_data)}")
print(f"处理结果: {processor.process(test_data)}")
print(f"处理器信息: {processor.get_info()}")
class DictProcessor:
"""字典处理器"""
def process(self, data):
if isinstance(data, dict):
return {k: v.upper() if isinstance(v, str) else v
for k, v in data.items()}
raise TypeError("数据必须是字典")
def validate(self, data) -> bool:
return isinstance(data, dict)
# 使用register方法注册
DataProcessor.register(DictProcessor)
def test_virtual_subclass_behavior():
"""测试虚拟子类行为"""
dict_processor = DictProcessor()
print(f"dict_processor是DataProcessor: {isinstance(dict_processor, DataProcessor)}")
print(f"DictProcessor是DataProcessor子类: {issubclass(DictProcessor, DataProcessor)}")
test_data = {"name": "alice", "age": 30}
if dict_processor.validate(test_data):
result = dict_processor.process(test_data)
print(f"字典处理结果: {result}")
if __name__ == "__main__":
demonstrate_virtual_subclasses()
print("\n" + "="*50 + "\n")
test_virtual_subclass_behavior()
4.2 __subclasshook__方法
__subclasshook__方法允许自定义子类检查逻辑。
python
from abc import ABC, abstractmethod, ABCMeta
class Container(ABC):
"""容器接口"""
@abstractmethod
def add(self, item):
"""添加元素"""
pass
@abstractmethod
def remove(self, item):
"""移除元素"""
pass
@abstractmethod
def contains(self, item) -> bool:
"""检查是否包含元素"""
pass
@classmethod
def __subclasshook__(cls, C):
"""自定义子类检查"""
if cls is Container:
# 检查类是否具有必要的方法
required_methods = {'add', 'remove', 'contains'}
if all(any(method in B.__dict__ for B in C.__mro__)
for method in required_methods):
return True
return NotImplemented
# 现有的类,没有显式继承Container
class CustomCollection:
"""自定义集合类"""
def __init__(self):
self._items = []
def add(self, item):
self._items.append(item)
return True
def remove(self, item):
if item in self._items:
self._items.remove(item)
return True
return False
def contains(self, item) -> bool:
return item in self._items
def __len__(self):
return len(self._items)
def __iter__(self):
return iter(self._items)
class IncompleteCollection:
"""不完整的集合类(缺少某些方法)"""
def __init__(self):
self._items = []
def add(self, item):
self._items.append(item)
def demonstrate_subclasshook():
"""演示__subclasshook__方法"""
custom_collection = CustomCollection()
incomplete_collection = IncompleteCollection()
# 测试子类关系
print(f"CustomCollection是Container子类: {issubclass(CustomCollection, Container)}")
print(f"IncompleteCollection是Container子类: {issubclass(IncompleteCollection, Container)}")
print(f"custom_collection是Container实例: {isinstance(custom_collection, Container)}")
print(f"incomplete_collection是Container实例: {isinstance(incomplete_collection, Container)}")
# 使用自定义集合
custom_collection.add("apple")
custom_collection.add("banana")
custom_collection.add("orange")
print(f"集合包含'apple': {custom_collection.contains('apple')}")
print(f"集合大小: {len(custom_collection)}")
custom_collection.remove("banana")
print(f"移除后集合大小: {len(custom_collection)}")
print("集合内容:")
for item in custom_collection:
print(f" - {item}")
if __name__ == "__main__":
demonstrate_subclasshook()
5. 标准库中的抽象基类
Python标准库在collections.abc模块中提供了许多有用的抽象基类。
5.1 集合抽象基类
python
from collections.abc import Container, Sized, Iterable, Iterator, Sequence, Mapping
import collections.abc
def demonstrate_stdlib_abc():
"""演示标准库中的抽象基类"""
# 测试各种内置类型
test_cases = [
(list, "列表"),
(tuple, "元组"),
(str, "字符串"),
(dict, "字典"),
(set, "集合"),
]
abc_types = [
(Container, "Container"),
(Sized, "Sized"),
(Iterable, "Iterable"),
(Iterator, "Iterator"),
(Sequence, "Sequence"),
(Mapping, "Mapping"),
]
print("标准库抽象基类兼容性检查:")
print("=" * 60)
for abc_class, abc_name in abc_types:
print(f"\n{abc_name}:")
for test_class, test_name in test_cases:
result = issubclass(test_class, abc_class)
print(f" {test_name:8} -> {result}")
class CustomSequence(collections.abc.Sequence):
"""自定义序列实现"""
def __init__(self, *args):
self._data = list(args)
def __getitem__(self, index):
return self._data[index]
def __len__(self):
return len(self._data)
def __repr__(self):
return f"CustomSequence({self._data})"
def count(self, value):
return self._data.count(value)
def index(self, value, start=0, stop=None):
if stop is None:
stop = len(self._data)
try:
return self._data.index(value, start, stop)
except ValueError:
return -1
class CustomMapping(collections.abc.Mapping):
"""自定义映射实现"""
def __init__(self, **kwargs):
self._data = kwargs
def __getitem__(self, key):
return self._data[key]
def __iter__(self):
return iter(self._data)
def __len__(self):
return len(self._data)
def __repr__(self):
return f"CustomMapping({self._data})"
def demonstrate_custom_abc_implementations():
"""演示自定义抽象基类实现"""
print("\n" + "=" * 60)
print("自定义抽象基类实现演示")
print("=" * 60)
# 自定义序列
custom_seq = CustomSequence(1, 2, 3, 4, 5, 2, 3)
print(f"自定义序列: {custom_seq}")
print(f"序列长度: {len(custom_seq)}")
print(f"索引2: {custom_seq[2]}")
print(f"切片[1:4]: {custom_seq[1:4]}")
print(f"计数2: {custom_seq.count(2)}")
print(f"索引3: {custom_seq.index(3)}")
print(f"\n自定义序列是Sequence: {isinstance(custom_seq, Sequence)}")
print(f"自定义序列是Iterable: {isinstance(custom_seq, Iterable)}")
print(f"自定义序列是Sized: {isinstance(custom_seq, Sized)}")
print("\n" + "-" * 40)
# 自定义映射
custom_map = CustomMapping(name="Alice", age=30, city="Beijing")
print(f"自定义映射: {custom_map}")
print(f"映射大小: {len(custom_map)}")
print(f"键'name': {custom_map['name']}")
print(f"所有键: {list(custom_map.keys())}")
print(f"所有值: {list(custom_map.values())}")
print(f"\n自定义映射是Mapping: {isinstance(custom_map, Mapping)}")
print(f"自定义映射是Iterable: {isinstance(custom_map, Iterable)}")
print(f"自定义映射是Sized: {isinstance(custom_map, Sized)}")
if __name__ == "__main__":
demonstrate_stdlib_abc()
demonstrate_custom_abc_implementations()
5.2 上下文管理器抽象基类
python
from abc import ABC, abstractmethod
from collections.abc import Generator
from contextlib import contextmanager, AbstractContextManager
import time
class DatabaseTransaction(AbstractContextManager):
"""数据库事务上下文管理器"""
def __init__(self, db_connection):
self.db_connection = db_connection
self._in_transaction = False
def __enter__(self):
self._in_transaction = True
print(f"开始数据库事务: {self.db_connection}")
# 模拟开始事务
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._in_transaction = False
if exc_type is None:
print(f"提交数据库事务: {self.db_connection}")
# 模拟提交事务
else:
print(f"回滚数据库事务: {self.db_connection} (错误: {exc_val})")
# 模拟回滚事务
return False # 不抑制异常
def execute(self, query):
"""在事务中执行查询"""
if not self._in_transaction:
raise RuntimeError("不在事务中")
print(f"执行查询: {query}")
return f"结果: {query}"
class FileProcessor(ABC):
"""文件处理器抽象基类"""
@abstractmethod
def process_file(self, filename):
"""处理文件"""
pass
@contextmanager
def file_context(self, filename):
"""文件处理上下文"""
print(f"打开文件: {filename}")
start_time = time.time()
try:
yield self
except Exception as e:
print(f"处理文件时出错: {e}")
raise
finally:
end_time = time.time()
print(f"关闭文件: {filename} (耗时: {end_time - start_time:.2f}秒)")
class TextFileProcessor(FileProcessor):
"""文本文件处理器"""
def process_file(self, filename):
with self.file_context(filename):
print(f"处理文本文件: {filename}")
# 模拟文件处理
time.sleep(0.1)
return f"处理完成: {filename}"
class BinaryFileProcessor(FileProcessor):
"""二进制文件处理器"""
def process_file(self, filename):
with self.file_context(filename):
print(f"处理二进制文件: {filename}")
# 模拟文件处理
time.sleep(0.2)
return f"处理完成: {filename}"
def demonstrate_context_managers():
"""演示上下文管理器抽象基类"""
print("数据库事务演示:")
print("-" * 40)
# 数据库事务演示
db_conn = "MySQL连接@localhost"
# 正常执行
with DatabaseTransaction(db_conn) as transaction:
result1 = transaction.execute("SELECT * FROM users")
result2 = transaction.execute("UPDATE users SET active = 1")
print(result1)
print(result2)
print("\n带异常的事务:")
try:
with DatabaseTransaction(db_conn) as transaction:
transaction.execute("BEGIN TRANSACTION")
transaction.execute("INSERT INTO users VALUES (...)")
raise ValueError("模拟的错误")
except ValueError as e:
print(f"捕获异常: {e}")
print("\n" + "=" * 50)
print("文件处理器演示:")
print("-" * 40)
# 文件处理器演示
text_processor = TextFileProcessor()
binary_processor = BinaryFileProcessor()
print(text_processor.process_file("document.txt"))
print()
print(binary_processor.process_file("image.jpg"))
if __name__ == "__main__":
demonstrate_context_managers()
6. 类型检查与抽象基类
抽象基类在类型检查和静态分析中非常有用。
6.1 类型提示与抽象基类
python
from abc import ABC, abstractmethod
from typing import List, Dict, Any, TypeVar, Generic, Optional
from dataclasses import dataclass
T = TypeVar('T')
U = TypeVar('U')
class Repository(ABC, Generic[T]):
"""泛型仓库抽象基类"""
@abstractmethod
def add(self, item: T) -> bool:
"""添加项目"""
pass
@abstractmethod
def get(self, identifier: Any) -> Optional[T]:
"""根据标识符获取项目"""
pass
@abstractmethod
def get_all(self) -> List[T]:
"""获取所有项目"""
pass
@abstractmethod
def update(self, item: T) -> bool:
"""更新项目"""
pass
@abstractmethod
def delete(self, identifier: Any) -> bool:
"""删除项目"""
pass
@dataclass
class User:
"""用户数据类"""
id: int
name: str
email: str
age: int
class UserRepository(Repository[User]):
"""用户仓库实现"""
def __init__(self):
self._storage: Dict[int, User] = {}
self._next_id = 1
def add(self, user: User) -> bool:
if user.id in self._storage:
return False
self._storage[user.id] = user
return True
def get(self, user_id: int) -> Optional[User]:
return self._storage.get(user_id)
def get_all(self) -> List[User]:
return list(self._storage.values())
def update(self, user: User) -> bool:
if user.id not in self._storage:
return False
self._storage[user.id] = user
return True
def delete(self, user_id: int) -> bool:
if user_id not in self._storage:
return False
del self._storage[user_id]
return True
def create_user(self, name: str, email: str, age: int) -> User:
"""创建新用户"""
user = User(id=self._next_id, name=name, email=email, age=age)
self._next_id += 1
self.add(user)
return user
def demonstrate_generic_repository():
"""演示泛型仓库模式"""
user_repo = UserRepository()
# 创建用户
users = [
user_repo.create_user("Alice", "alice@example.com", 30),
user_repo.create_user("Bob", "bob@example.com", 25),
user_repo.create_user("Charlie", "charlie@example.com", 35)
]
print("所有用户:")
for user in user_repo.get_all():
print(f" ID: {user.id}, 姓名: {user.name}, 邮箱: {user.email}, 年龄: {user.age}")
print(f"\n获取ID为2的用户: {user_repo.get(2)}")
# 更新用户
bob = user_repo.get(2)
if bob:
bob.age = 26
user_repo.update(bob)
print(f"更新后的Bob: {user_repo.get(2)}")
# 删除用户
user_repo.delete(3)
print(f"\n删除ID为3的用户后,总用户数: {len(user_repo.get_all())}")
class DataValidator(ABC):
"""数据验证器抽象基类"""
@abstractmethod
def validate(self, data: Any) -> bool:
"""验证数据"""
pass
@abstractmethod
def get_errors(self) -> List[str]:
"""获取错误信息"""
pass
class UserValidator(DataValidator):
"""用户数据验证器"""
def __init__(self):
self._errors: List[str] = []
def validate(self, user: Any) -> bool:
self._errors.clear()
if not isinstance(user, User):
self._errors.append("数据必须是User类型")
return False
if not user.name or len(user.name.strip()) == 0:
self._errors.append("姓名不能为空")
if not user.email or '@' not in user.email:
self._errors.append("邮箱格式不正确")
if user.age < 0 or user.age > 150:
self._errors.append("年龄必须在0-150之间")
return len(self._errors) == 0
def get_errors(self) -> List[str]:
return self._errors.copy()
def demonstrate_validator_pattern():
"""演示验证器模式"""
validator = UserValidator()
test_users = [
User(1, "Alice", "alice@example.com", 30), # 有效
User(2, "", "invalid-email", -5), # 无效
User(3, "Bob", "bob@example.com", 200), # 无效年龄
]
for user in test_users:
print(f"\n验证用户: {user.name}")
if validator.validate(user):
print(" ✓ 验证通过")
else:
print(" ✗ 验证失败:")
for error in validator.get_errors():
print(f" - {error}")
if __name__ == "__main__":
demonstrate_generic_repository()
print("\n" + "=" * 60)
demonstrate_validator_pattern()
7. 设计模式与抽象基类
抽象基类在设计模式实现中发挥着重要作用。
7.1 策略模式
python
from abc import ABC, abstractmethod
from typing import List
from dataclasses import dataclass
class CompressionStrategy(ABC):
"""压缩策略抽象基类"""
@abstractmethod
def compress(self, data: str) -> str:
"""压缩数据"""
pass
@abstractmethod
def decompress(self, compressed_data: str) -> str:
"""解压缩数据"""
pass
@abstractmethod
def get_compression_ratio(self, original_data: str, compressed_data: str) -> float:
"""获取压缩率"""
pass
class RLECompression(CompressionStrategy):
"""游程编码压缩策略"""
def compress(self, data: str) -> str:
if not data:
return ""
compressed = []
count = 1
current_char = data[0]
for char in data[1:]:
if char == current_char:
count += 1
else:
compressed.append(f"{count}{current_char}")
current_char = char
count = 1
compressed.append(f"{count}{current_char}")
return "".join(compressed)
def decompress(self, compressed_data: str) -> str:
decompressed = []
i = 0
while i < len(compressed_data):
count_str = ""
while i < len(compressed_data) and compressed_data[i].isdigit():
count_str += compressed_data[i]
i += 1
if i < len(compressed_data):
char = compressed_data[i]
count = int(count_str) if count_str else 1
decompressed.append(char * count)
i += 1
return "".join(decompressed)
def get_compression_ratio(self, original_data: str, compressed_data: str) -> float:
if not original_data:
return 0.0
return len(compressed_data) / len(original_data)
class DictionaryCompression(CompressionStrategy):
"""字典压缩策略"""
def __init__(self):
self.dictionary = {}
self.next_code = 0
def compress(self, data: str) -> str:
self.dictionary.clear()
self.next_code = 0
words = data.split()
compressed = []
for word in words:
if word not in self.dictionary:
self.dictionary[word] = self.next_code
self.next_code += 1
compressed.append(str(self.dictionary[word]))
return " ".join(compressed)
def decompress(self, compressed_data: str) -> str:
if not self.dictionary:
raise ValueError("字典未初始化")
# 反转字典
reverse_dict = {v: k for k, v in self.dictionary.items()}
codes = compressed_data.split()
decompressed = []
for code in codes:
decompressed.append(reverse_dict[int(code)])
return " ".join(decompressed)
def get_compression_ratio(self, original_data: str, compressed_data: str) -> float:
if not original_data:
return 0.0
return len(compressed_data) / len(original_data)
class CompressionContext:
"""压缩上下文"""
def __init__(self, strategy: CompressionStrategy):
self._strategy = strategy
def set_strategy(self, strategy: CompressionStrategy):
"""设置压缩策略"""
self._strategy = strategy
def compress_data(self, data: str) -> str:
"""压缩数据"""
return self._strategy.compress(data)
def decompress_data(self, compressed_data: str) -> str:
"""解压缩数据"""
return self._strategy.decompress(compressed_data)
def get_compression_info(self, original_data: str) -> dict:
"""获取压缩信息"""
compressed_data = self.compress_data(original_data)
ratio = self._strategy.get_compression_ratio(original_data, compressed_data)
return {
"original_size": len(original_data),
"compressed_size": len(compressed_data),
"compression_ratio": ratio,
"compressed_data": compressed_data
}
def demonstrate_strategy_pattern():
"""演示策略模式"""
test_data = "AAAABBBCCDAA"
print(f"测试数据: {test_data}")
print("=" * 50)
# RLE压缩
rle_strategy = RLECompression()
rle_context = CompressionContext(rle_strategy)
rle_info = rle_context.get_compression_info(test_data)
print("RLE压缩:")
print(f" 原始大小: {rle_info['original_size']}")
print(f" 压缩后大小: {rle_info['compressed_size']}")
print(f" 压缩率: {rle_info['compression_ratio']:.2%}")
print(f" 压缩数据: {rle_info['compressed_data']}")
decompressed = rle_context.decompress_data(rle_info['compressed_data'])
print(f" 解压数据: {decompressed}")
print(f" 解压成功: {decompressed == test_data}")
print("\n" + "-" * 50)
# 字典压缩
text_data = "hello world hello python world python"
print(f"文本数据: {text_data}")
dict_strategy = DictionaryCompression()
dict_context = CompressionContext(dict_strategy)
dict_info = dict_context.get_compression_info(text_data)
print("\n字典压缩:")
print(f" 原始大小: {dict_info['original_size']}")
print(f" 压缩后大小: {dict_info['compressed_size']}")
print(f" 压缩率: {dict_info['compression_ratio']:.2%}")
print(f" 压缩数据: {dict_info['compressed_data']}")
decompressed_text = dict_context.decompress_data(dict_info['compressed_data'])
print(f" 解压数据: {decompressed_text}")
print(f" 解压成功: {decompressed_text == text_data}")
if __name__ == "__main__":
demonstrate_strategy_pattern()
7.2 观察者模式
python
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from datetime import datetime
import time
class Observer(ABC):
"""观察者抽象基类"""
@abstractmethod
def update(self, subject: Any, event: str, data: Any):
"""接收更新通知"""
pass
class Subject(ABC):
"""主题抽象基类"""
def __init__(self):
self._observers: List[Observer] = []
def attach(self, observer: Observer):
"""附加观察者"""
if observer not in self._observers:
self._observers.append(observer)
def detach(self, observer: Observer):
"""分离观察者"""
if observer in self._observers:
self._observers.remove(observer)
def notify(self, event: str, data: Any = None):
"""通知所有观察者"""
for observer in self._observers:
observer.update(self, event, data)
class StockMarket(Subject):
"""股票市场主题"""
def __init__(self, name: str):
super().__init__()
self.name = name
self._stocks: Dict[str, float] = {}
self._price_history: Dict[str, List[float]] = {}
def set_stock_price(self, symbol: str, price: float):
"""设置股票价格"""
old_price = self._stocks.get(symbol)
self._stocks[symbol] = price
# 记录价格历史
if symbol not in self._price_history:
self._price_history[symbol] = []
self._price_history[symbol].append(price)
# 通知观察者
event_data = {
'symbol': symbol,
'old_price': old_price,
'new_price': price,
'timestamp': datetime.now(),
'change': price - old_price if old_price else 0
}
if old_price is not None:
if price > old_price:
self.notify('price_increase', event_data)
elif price < old_price:
self.notify('price_decrease', event_data)
else:
self.notify('price_initial', event_data)
def get_stock_price(self, symbol: str) -> float:
"""获取股票价格"""
return self._stocks.get(symbol, 0.0)
def get_price_history(self, symbol: str) -> List[float]:
"""获取价格历史"""
return self._price_history.get(symbol, [])
class StockTrader(Observer):
"""股票交易者观察者"""
def __init__(self, name: str, budget: float):
self.name = name
self.budget = budget
self.portfolio: Dict[str, int] = {}
self.transaction_history: List[Dict] = []
def update(self, subject: StockMarket, event: str, data: Any):
"""接收市场更新"""
if event == 'price_decrease':
self._consider_buying(subject, data)
elif event == 'price_increase':
self._consider_selling(subject, data)
def _consider_buying(self, market: StockMarket, data: Dict):
"""考虑买入"""
symbol = data['symbol']
price = data['new_price']
# 简单的买入策略:价格下降超过10%且预算足够
if data['change'] < -price * 0.1 and self.budget >= price:
shares_to_buy = min(10, int(self.budget // price))
cost = shares_to_buy * price
# 执行买入
self.portfolio[symbol] = self.portfolio.get(symbol, 0) + shares_to_buy
self.budget -= cost
transaction = {
'type': 'BUY',
'symbol': symbol,
'shares': shares_to_buy,
'price': price,
'total': cost,
'timestamp': datetime.now()
}
self.transaction_history.append(transaction)
print(f"{self.name} 买入 {shares_to_buy} 股 {symbol} @ ${price:.2f}")
def _consider_selling(self, market: StockMarket, data: Dict):
"""考虑卖出"""
symbol = data['symbol']
price = data['new_price']
# 简单的卖出策略:价格上涨超过15%且持有该股票
if symbol in self.portfolio and data['change'] > price * 0.15:
shares_owned = self.portfolio[symbol]
shares_to_sell = min(5, shares_owned) # 每次最多卖出5股
revenue = shares_to_sell * price
# 执行卖出
self.portfolio[symbol] -= shares_to_sell
if self.portfolio[symbol] == 0:
del self.portfolio[symbol]
self.budget += revenue
transaction = {
'type': 'SELL',
'symbol': symbol,
'shares': shares_to_sell,
'price': price,
'total': revenue,
'timestamp': datetime.now()
}
self.transaction_history.append(transaction)
print(f"{self.name} 卖出 {shares_to_sell} 股 {symbol} @ ${price:.2f}")
def get_portfolio_value(self, market: StockMarket) -> float:
"""获取投资组合价值"""
stock_value = sum(
shares * market.get_stock_price(symbol)
for symbol, shares in self.portfolio.items()
)
return stock_value + self.budget
def print_status(self, market: StockMarket):
"""打印状态"""
total_value = self.get_portfolio_value(market)
print(f"\n{self.name} 状态:")
print(f" 现金: ${self.budget:.2f}")
print(f" 投资组合:")
for symbol, shares in self.portfolio.items():
price = market.get_stock_price(symbol)
value = shares * price
print(f" {symbol}: {shares} 股 @ ${price:.2f} = ${value:.2f}")
print(f" 总资产: ${total_value:.2f}")
class MarketAnalyst(Observer):
"""市场分析师观察者"""
def __init__(self, name: str):
self.name = name
self.analysis_report: List[str] = []
def update(self, subject: StockMarket, event: str, data: Any):
"""分析市场变化"""
symbol = data['symbol']
timestamp = data['timestamp'].strftime("%Y-%m-%d %H:%M:%S")
if event == 'price_increase':
analysis = f"{timestamp} - {symbol} 上涨 ${data['change']:.2f} (新价格: ${data['new_price']:.2f})"
elif event == 'price_decrease':
analysis = f"{timestamp} - {symbol} 下跌 ${abs(data['change']):.2f} (新价格: ${data['new_price']:.2f})"
else:
analysis = f"{timestamp} - {symbol} 初始价格: ${data['new_price']:.2f}"
self.analysis_report.append(analysis)
print(f"分析师 {self.name}: {analysis}")
def print_report(self):
"""打印分析报告"""
print(f"\n{self.name} 的分析报告:")
for i, analysis in enumerate(self.analysis_report[-5:], 1): # 只显示最后5条
print(f" {i}. {analysis}")
def demonstrate_observer_pattern():
"""演示观察者模式"""
# 创建股票市场
market = StockMarket("NASDAQ")
# 创建交易者和分析师
trader1 = StockTrader("Alice", 10000.0)
trader2 = StockTrader("Bob", 8000.0)
analyst = MarketAnalyst("Dr. Smith")
# 注册观察者
market.attach(trader1)
market.attach(trader2)
market.attach(analyst)
print("股票市场模拟开始...")
print("=" * 50)
# 模拟股票价格变化
stocks = ['AAPL', 'GOOGL', 'MSFT', 'TSLA']
prices = {
'AAPL': 150.0,
'GOOGL': 2800.0,
'MSFT': 300.0,
'TSLA': 700.0
}
import random
for day in range(1, 6):
print(f"\n第 {day} 天:")
print("-" * 30)
for symbol in stocks:
# 随机价格变化 (-5% 到 +5%)
change_percent = random.uniform(-0.05, 0.05)
old_price = prices[symbol]
new_price = old_price * (1 + change_percent)
prices[symbol] = new_price
market.set_stock_price(symbol, new_price)
time.sleep(0.1) # 短暂暂停以便观察
# 显示最终状态
print("\n" + "=" * 50)
print("模拟结束")
print("=" * 50)
trader1.print_status(market)
trader2.print_status(market)
analyst.print_report()
if __name__ == "__main__":
demonstrate_observer_pattern()
8. 完整代码示例:电子商务系统
下面是一个完整的电子商务系统示例,展示抽象基类在实际项目中的应用。
python
"""
电子商务系统完整示例
演示抽象基类在实际项目中的应用
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union
from datetime import datetime
import uuid
from dataclasses import dataclass
from enum import Enum
class OrderStatus(Enum):
PENDING = "pending"
CONFIRMED = "confirmed"
SHIPPED = "shipped"
DELIVERED = "delivered"
CANCELLED = "cancelled"
class PaymentStatus(Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
REFUNDED = "refunded"
@dataclass
class Product:
id: str
name: str
price: float
description: str
stock_quantity: int
category: str
@dataclass
class Customer:
id: str
name: str
email: str
address: str
phone: str
@dataclass
class OrderItem:
product_id: str
product_name: str
quantity: int
unit_price: float
@property
def total_price(self) -> float:
return self.quantity * self.unit_price
class Repository(ABC):
"""泛型仓库抽象基类"""
@abstractmethod
def get_by_id(self, id: str):
pass
@abstractmethod
def get_all(self) -> List:
pass
@abstractmethod
def add(self, entity) -> bool:
pass
@abstractmethod
def update(self, entity) -> bool:
pass
@abstractmethod
def delete(self, id: str) -> bool:
pass
class ProductRepository(Repository):
"""产品仓库"""
def __init__(self):
self._products: Dict[str, Product] = {}
def get_by_id(self, id: str) -> Optional[Product]:
return self._products.get(id)
def get_all(self) -> List[Product]:
return list(self._products.values())
def add(self, product: Product) -> bool:
if product.id in self._products:
return False
self._products[product.id] = product
return True
def update(self, product: Product) -> bool:
if product.id not in self._products:
return False
self._products[product.id] = product
return True
def delete(self, id: str) -> bool:
if id not in self._products:
return False
del self._products[id]
return True
def get_by_category(self, category: str) -> List[Product]:
return [p for p in self._products.values() if p.category == category]
def update_stock(self, product_id: str, quantity: int) -> bool:
product = self.get_by_id(product_id)
if product and product.stock_quantity >= quantity:
product.stock_quantity -= quantity
return self.update(product)
return False
class CustomerRepository(Repository):
"""客户仓库"""
def __init__(self):
self._customers: Dict[str, Customer] = {}
def get_by_id(self, id: str) -> Optional[Customer]:
return self._customers.get(id)
def get_all(self) -> List[Customer]:
return list(self._customers.values())
def add(self, customer: Customer) -> bool:
if customer.id in self._customers:
return False
self._customers[customer.id] = customer
return True
def update(self, customer: Customer) -> bool:
if customer.id not in self._customers:
return False
self._customers[customer.id] = customer
return True
def delete(self, id: str) -> bool:
if id not in self._customers:
return False
del self._customers[id]
return True
class Order:
"""订单类"""
def __init__(self, customer_id: str):
self.id = str(uuid.uuid4())
self.customer_id = customer_id
self.items: List[OrderItem] = []
self.status = OrderStatus.PENDING
self.payment_status = PaymentStatus.PENDING
self.created_at = datetime.now()
self.updated_at = datetime.now()
self.shipping_address = ""
self.total_amount = 0.0
def add_item(self, product: Product, quantity: int) -> bool:
"""添加订单项"""
if quantity <= 0 or product.stock_quantity < quantity:
return False
# 检查是否已存在该商品
for item in self.items:
if item.product_id == product.id:
return False
order_item = OrderItem(
product_id=product.id,
product_name=product.name,
quantity=quantity,
unit_price=product.price
)
self.items.append(order_item)
self._update_totals()
return True
def remove_item(self, product_id: str) -> bool:
"""移除订单项"""
for i, item in enumerate(self.items):
if item.product_id == product_id:
self.items.pop(i)
self._update_totals()
return True
return False
def update_quantity(self, product_id: str, quantity: int) -> bool:
"""更新商品数量"""
for item in self.items:
if item.product_id == product_id:
if quantity <= 0:
return self.remove_item(product_id)
item.quantity = quantity
self._update_totals()
return True
return False
def _update_totals(self):
"""更新总金额"""
self.total_amount = sum(item.total_price for item in self.items)
self.updated_at = datetime.now()
def get_order_summary(self) -> Dict:
"""获取订单摘要"""
return {
'order_id': self.id,
'customer_id': self.customer_id,
'total_amount': self.total_amount,
'item_count': len(self.items),
'status': self.status.value,
'payment_status': self.payment_status.value,
'created_at': self.created_at.isoformat()
}
class OrderRepository(Repository):
"""订单仓库"""
def __init__(self):
self._orders: Dict[str, Order] = {}
def get_by_id(self, id: str) -> Optional[Order]:
return self._orders.get(id)
def get_all(self) -> List[Order]:
return list(self._orders.values())
def add(self, order: Order) -> bool:
if order.id in self._orders:
return False
self._orders[order.id] = order
return True
def update(self, order: Order) -> bool:
if order.id not in self._orders:
return False
self._orders[order.id] = order
return True
def delete(self, id: str) -> bool:
if id not in self._orders:
return False
del self._orders[id]
return True
def get_customer_orders(self, customer_id: str) -> List[Order]:
return [order for order in self._orders.values()
if order.customer_id == customer_id]
def get_orders_by_status(self, status: OrderStatus) -> List[Order]:
return [order for order in self._orders.values()
if order.status == status]
class PaymentProcessor(ABC):
"""支付处理器抽象基类"""
@abstractmethod
def process_payment(self, order: Order, payment_details: Dict) -> bool:
"""处理支付"""
pass
@abstractmethod
def refund_payment(self, order: Order, amount: float) -> bool:
"""退款"""
pass
@abstractmethod
def get_payment_status(self, order: Order) -> PaymentStatus:
"""获取支付状态"""
pass
class CreditCardProcessor(PaymentProcessor):
"""信用卡支付处理器"""
def process_payment(self, order: Order, payment_details: Dict) -> bool:
# 模拟信用卡支付处理
card_number = payment_details.get('card_number')
expiry_date = payment_details.get('expiry_date')
cvv = payment_details.get('cvv')
if not all([card_number, expiry_date, cvv]):
return False
# 模拟支付处理延迟
import time
time.sleep(0.1)
# 模拟支付成功(在实际应用中这里会有真实的支付逻辑)
order.payment_status = PaymentStatus.COMPLETED
return True
def refund_payment(self, order: Order, amount: float) -> bool:
if order.payment_status != PaymentStatus.COMPLETED:
return False
# 模拟退款处理
order.payment_status = PaymentStatus.REFUNDED
return True
def get_payment_status(self, order: Order) -> PaymentStatus:
return order.payment_status
class NotificationService(ABC):
"""通知服务抽象基类"""
@abstractmethod
def send_notification(self, recipient: str, message: str, subject: str = "") -> bool:
"""发送通知"""
pass
class EmailNotificationService(NotificationService):
"""邮件通知服务"""
def send_notification(self, recipient: str, message: str, subject: str = "") -> bool:
print(f"发送邮件到: {recipient}")
print(f"主题: {subject}")
print(f"内容: {message}")
print("-" * 50)
return True
class OrderService:
"""订单服务"""
def __init__(
self,
order_repository: OrderRepository,
product_repository: ProductRepository,
customer_repository: CustomerRepository,
payment_processor: PaymentProcessor,
notification_service: NotificationService
):
self.order_repository = order_repository
self.product_repository = product_repository
self.customer_repository = customer_repository
self.payment_processor = payment_processor
self.notification_service = notification_service
def create_order(self, customer_id: str) -> Optional[Order]:
"""创建新订单"""
customer = self.customer_repository.get_by_id(customer_id)
if not customer:
return None
order = Order(customer_id)
if self.order_repository.add(order):
return order
return None
def add_product_to_order(self, order_id: str, product_id: str, quantity: int) -> bool:
"""添加商品到订单"""
order = self.order_repository.get_by_id(order_id)
product = self.product_repository.get_by_id(product_id)
if not order or not product:
return False
if order.status != OrderStatus.PENDING:
return False
success = order.add_item(product, quantity)
if success:
self.order_repository.update(order)
return success
def place_order(self, order_id: str, payment_details: Dict, shipping_address: str) -> bool:
"""下单"""
order = self.order_repository.get_by_id(order_id)
if not order or order.status != OrderStatus.PENDING:
return False
# 设置配送地址
order.shipping_address = shipping_address
# 处理支付
if not self.payment_processor.process_payment(order, payment_details):
order.payment_status = PaymentStatus.FAILED
self.order_repository.update(order)
return False
# 更新库存
for item in order.items:
if not self.product_repository.update_stock(item.product_id, item.quantity):
# 库存不足,回滚
self.payment_processor.refund_payment(order, order.total_amount)
return False
# 更新订单状态
order.status = OrderStatus.CONFIRMED
self.order_repository.update(order)
# 发送确认邮件
customer = self.customer_repository.get_by_id(order.customer_id)
if customer:
message = f"您的订单 #{order.id} 已确认。总金额: ${order.total_amount:.2f}"
self.notification_service.send_notification(
customer.email,
message,
"订单确认"
)
return True
def get_order_status(self, order_id: str) -> Optional[Dict]:
"""获取订单状态"""
order = self.order_repository.get_by_id(order_id)
if not order:
return None
return order.get_order_summary()
class ECommerceSystem:
"""电子商务系统"""
def __init__(self):
# 初始化所有组件
self.product_repository = ProductRepository()
self.customer_repository = CustomerRepository()
self.order_repository = OrderRepository()
self.payment_processor = CreditCardProcessor()
self.notification_service = EmailNotificationService()
self.order_service = OrderService(
self.order_repository,
self.product_repository,
self.customer_repository,
self.payment_processor,
self.notification_service
)
self._initialize_sample_data()
def _initialize_sample_data(self):
"""初始化示例数据"""
# 添加示例产品
products = [
Product("1", "笔记本电脑", 999.99, "高性能笔记本电脑", 10, "电子产品"),
Product("2", "智能手机", 699.99, "最新款智能手机", 20, "电子产品"),
Product("3", "书籍", 29.99, "编程书籍", 50, "图书"),
Product("4", "耳机", 199.99, "无线降噪耳机", 15, "电子产品"),
]
for product in products:
self.product_repository.add(product)
# 添加示例客户
customers = [
Customer("1", "张三", "zhangsan@example.com", "北京市朝阳区", "13800138000"),
Customer("2", "李四", "lisi@example.com", "上海市浦东新区", "13900139000"),
]
for customer in customers:
self.customer_repository.add(customer)
def run_demo(self):
"""运行演示"""
print("电子商务系统演示")
print("=" * 50)
# 显示可用产品
print("\n可用产品:")
for product in self.product_repository.get_all():
print(f" {product.id}. {product.name} - ${product.price:.2f} (库存: {product.stock_quantity})")
# 创建订单
print("\n1. 创建订单...")
order = self.order_service.create_order("1")
if not order:
print("创建订单失败")
return
print(f"订单创建成功: {order.id}")
# 添加商品到订单
print("\n2. 添加商品到订单...")
self.order_service.add_product_to_order(order.id, "1", 1) # 笔记本电脑
self.order_service.add_product_to_order(order.id, "3", 2) # 书籍
# 显示订单摘要
order_summary = self.order_service.get_order_status(order.id)
print(f"订单摘要: {order_summary}")
# 下单
print("\n3. 下单...")
payment_details = {
'card_number': '4111111111111111',
'expiry_date': '12/25',
'cvv': '123'
}
shipping_address = "北京市朝阳区某某街道123号"
if self.order_service.place_order(order.id, payment_details, shipping_address):
print("下单成功!")
# 显示最终订单状态
final_summary = self.order_service.get_order_status(order.id)
print(f"最终订单状态: {final_summary}")
else:
print("下单失败")
# 显示库存变化
print("\n4. 库存状态:")
for product in self.product_repository.get_all():
print(f" {product.name}: {product.stock_quantity} 件")
def main():
"""主函数"""
ecommerce_system = ECommerceSystem()
ecommerce_system.run_demo()
if __name__ == "__main__":
main()
9. 最佳实践与注意事项
9.1 抽象基类设计原则
- 单一职责原则:每个抽象基类应该只有一个明确的职责
- 接口隔离原则:不要强迫客户端依赖它们不需要的方法
- 里氏替换原则:子类应该能够替换它们的父类
- 依赖倒置原则:依赖于抽象而不是具体实现
9.2 常见陷阱与解决方案
python
from abc import ABC, abstractmethod
class CommonMistakesDemo:
"""常见陷阱演示"""
class BadAbstractClass(ABC):
"""不好的抽象基类设计"""
@abstractmethod
def method1(self):
pass
@abstractmethod
def method2(self):
pass
def too_many_concrete_methods(self):
"""过多的具体方法"""
# 这违反了接口隔离原则
pass
def implementation_details(self):
"""包含实现细节"""
# 抽象基类应该关注接口,而不是实现
pass
class GoodAbstractClass(ABC):
"""良好的抽象基类设计"""
@abstractmethod
def essential_operation(self):
"""必要的操作"""
pass
def optional_operation(self):
"""可选的操作,提供默认实现"""
raise NotImplementedError("子类可以选择实现此方法")
def template_method(self):
"""模板方法模式"""
self.essential_operation()
self._hook_method()
def _hook_method(self):
"""钩子方法,子类可以重写"""
pass
def demonstrate_best_practices():
"""演示最佳实践"""
print("抽象基类最佳实践:")
print("1. 保持抽象基类简洁,只定义必要的抽象方法")
print("2. 使用模板方法模式提供算法骨架")
print("3. 为可选操作提供默认实现或抛出NotImplementedError")
print("4. 使用钩子方法允许子类扩展行为")
print("5. 避免在抽象基类中包含具体实现细节")
if __name__ == "__main__":
demonstrate_best_practices()
10. 总结
Python的抽象基类(ABC)是定义接口契约的强大工具,它提供了以下关键优势:
- 明确的接口定义:通过抽象方法强制子类实现特定接口
- 类型检查支持:增强代码的可靠性和可维护性
- 设计模式实现:支持策略模式、观察者模式等经典设计模式
- 代码组织:促进清晰的代码结构和架构设计
- 多态性:实现统一的接口,不同的行为
10.1 关键要点
- 使用
@abstractmethod定义抽象方法 - 抽象基类不能被实例化
- 子类必须实现所有抽象方法
- 可以使用
register方法创建虚拟子类 __subclasshook__方法允许自定义子类检查逻辑
10.2 适用场景
- 定义框架或库的扩展点
- 实现插件系统
- 创建清晰的API契约
- 实现设计模式
- 大型项目的架构设计
通过合理使用抽象基类,您可以创建更加健壮、可维护和可扩展的Python应用程序。抽象基类不仅是一种技术工具,更是一种设计哲学,它鼓励开发者思考接口设计和组件之间的关系。
11. 代码自查
在完成本文的所有代码示例后,我们进行了以下自查以确保代码质量:
- 语法正确性:所有代码都通过Python语法检查
- 抽象基类规范 :确保所有抽象方法都正确使用
@abstractmethod装饰器 - 类型提示:在适当的地方使用了类型提示
- 异常处理:关键操作都有适当的异常处理
- 代码注释:所有复杂逻辑都有清晰的注释说明
- 设计模式应用:正确实现了各种设计模式
- 实际应用场景:代码示例基于真实的应用场景
这些代码示例可以直接运行,并且包含了适当的设计模式和最佳实践,可以作为学习和实际项目参考。