模板方法模式全解析:用抽象基类定义算法骨架,让子类优雅填充细节
一、引言:你是否写过这样的重复代码?
在真实项目中,你是否遇到过这种情况:
python
# 数据处理流程 A
def process_csv_data():
data = read_csv_file() # 读取数据
data = clean_data(data) # 清洗
data = validate_csv(data) # 校验(CSV特有)
result = analyze(data) # 分析
save_to_database(result) # 保存
send_report_email(result) # 发送报告
# 数据处理流程 B(几乎一样,但某些步骤不同)
def process_json_data():
data = fetch_from_api() # 读取数据(不同)
data = clean_data(data) # 清洗(相同)
data = validate_json(data) # 校验(JSON特有)
result = analyze(data) # 分析(相同)
save_to_database(result) # 保存(相同)
# 不需要发送报告(不同)
两个流程的骨架完全相同,只有若干步骤有差异。如果用复制粘贴解决,当「分析」逻辑需要修改时,你要同时改两处;如果流程再增加到五种、十种,维护将成为噩梦。
这正是模板方法模式(Template Method Pattern) 要解决的核心问题。
模板方法模式是 GoF 设计模式中最朴素、最实用的模式之一,其核心思想一句话概括:在父类中定义算法的骨架(模板),将某些步骤延迟到子类中实现,子类可以在不改变算法结构的前提下,重新定义特定步骤的具体实现。
Python 的 abc 模块为这一模式提供了天然的语言支持。本文将带你从理论到实战,完整掌握模板方法模式在 Python 工程中的精髓应用。
二、核心概念:骨架与填充
2.1 三个关键元素
模板方法模式由三类方法构成,理解它们是掌握模式的关键:
模板方法(Template Method) :定义算法骨架的方法,通常在抽象基类中实现,不允许子类覆盖 (Python 中可通过约定或 __init_subclass__ 强制)。它按固定顺序调用其他步骤。
抽象步骤(Abstract Steps) :算法中必须由子类实现的步骤,用 @abstractmethod 标注,子类必须覆盖,否则无法实例化。
钩子方法(Hook Methods):提供默认实现(通常为空或返回默认值)的可选步骤,子类可以选择性覆盖,用于控制算法中的条件分支。
2.2 最小骨架示例
python
from abc import ABC, abstractmethod
class DataProcessor(ABC):
"""抽象基类:数据处理算法骨架"""
def process(self) -> None:
"""模板方法:定义处理流程,子类不应覆盖此方法"""
raw_data = self.read_data() # 抽象步骤
clean = self.clean_data(raw_data) # 抽象步骤
if self.should_validate(): # 钩子:是否需要校验
clean = self.validate(clean)
result = self.analyze(clean) # 抽象步骤
self.save_result(result) # 抽象步骤
if self.should_notify(): # 钩子:是否需要通知
self.send_notification(result)
# ===== 抽象步骤:子类必须实现 =====
@abstractmethod
def read_data(self) -> list:
pass
@abstractmethod
def clean_data(self, data: list) -> list:
pass
@abstractmethod
def analyze(self, data: list) -> dict:
pass
@abstractmethod
def save_result(self, result: dict) -> None:
pass
# ===== 钩子方法:子类可选择覆盖 =====
def validate(self, data: list) -> list:
"""默认校验:过滤 None 值"""
return [item for item in data if item is not None]
def should_validate(self) -> bool:
"""钩子:是否执行校验步骤,默认开启"""
return True
def should_notify(self) -> bool:
"""钩子:是否发送通知,默认关闭"""
return False
def send_notification(self, result: dict) -> None:
"""钩子:通知逻辑,子类可覆盖"""
pass
三、实战案例一:多格式数据处理管道
3.1 实现 CSV 和 API 两种数据处理器
python
import csv
import json
import io
from datetime import datetime
class CSVDataProcessor(DataProcessor):
"""CSV 文件数据处理器"""
def __init__(self, filepath: str, email: str = ''):
self.filepath = filepath
self.email = email
def read_data(self) -> list:
print(f"[CSV处理器] 从文件读取: {self.filepath}")
# 模拟 CSV 读取
sample_csv = "name,age,score\n张三,25,88\n李四,30,\n王五,22,95\n"
reader = csv.DictReader(io.StringIO(sample_csv))
return list(reader)
def clean_data(self, data: list) -> list:
print(f"[CSV处理器] 清洗数据: {len(data)} 条")
for row in data:
# 类型转换
row['age'] = int(row['age']) if row.get('age') else 0
row['score'] = float(row['score']) if row.get('score') else None
return data
def validate(self, data: list) -> list:
"""CSV 特有校验:过滤 score 为空的记录"""
valid = [row for row in data if row['score'] is not None]
print(f"[CSV处理器] 校验后: {len(valid)}/{len(data)} 条有效")
return valid
def analyze(self, data: list) -> dict:
scores = [row['score'] for row in data]
return {
'count': len(scores),
'avg_score': sum(scores) / len(scores) if scores else 0,
'max_score': max(scores) if scores else 0,
'min_score': min(scores) if scores else 0,
'processed_at': datetime.now().isoformat()
}
def save_result(self, result: dict) -> None:
print(f"[CSV处理器] 保存分析结果到数据库: {result}")
def should_notify(self) -> bool:
return bool(self.email) # 有邮箱才发通知
def send_notification(self, result: dict) -> None:
print(f"[CSV处理器] 发送报告邮件至 {self.email}: "
f"平均分 {result['avg_score']:.1f}")
class APIDataProcessor(DataProcessor):
"""API 接口数据处理器"""
def __init__(self, api_url: str):
self.api_url = api_url
def read_data(self) -> list:
print(f"[API处理器] 从接口获取数据: {self.api_url}")
# 模拟 API 返回
return [
{'id': 1, 'value': 42.5, 'tag': 'A'},
{'id': 2, 'value': None, 'tag': 'B'},
{'id': 3, 'value': 88.0, 'tag': 'A'},
{'id': 4, 'value': 15.3, 'tag': 'C'},
]
def clean_data(self, data: list) -> list:
print(f"[API处理器] 标准化字段格式")
for item in data:
item['value'] = float(item['value']) if item['value'] else None
return data
def analyze(self, data: list) -> dict:
valid_values = [item['value'] for item in data if item['value']]
tag_groups = {}
for item in data:
if item['value']:
tag_groups.setdefault(item['tag'], []).append(item['value'])
return {
'total': len(data),
'valid': len(valid_values),
'avg': sum(valid_values) / len(valid_values) if valid_values else 0,
'by_tag': {tag: sum(vals)/len(vals)
for tag, vals in tag_groups.items()}
}
def save_result(self, result: dict) -> None:
print(f"[API处理器] 写入缓存层: {json.dumps(result, ensure_ascii=False)}")
def should_validate(self) -> bool:
return True # API 数据必须校验
# ===== 运行演示 =====
print("=" * 50)
print("【CSV 数据处理(带邮件通知)】")
csv_processor = CSVDataProcessor('students.csv', email='admin@example.com')
csv_processor.process()
print("\n" + "=" * 50)
print("【API 数据处理(无通知)】")
api_processor = APIDataProcessor('https://api.example.com/data')
api_processor.process()
运行效果:
==================================================
【CSV 数据处理(带邮件通知)】
[CSV处理器] 从文件读取: students.csv
[CSV处理器] 清洗数据: 3 条
[CSV处理器] 校验后: 2/3 条有效
[CSV处理器] 保存分析结果到数据库: {'count': 2, 'avg_score': 91.5, ...}
[CSV处理器] 发送报告邮件至 admin@example.com: 平均分 91.5
==================================================
【API 数据处理(无通知)】
[API处理器] 从接口获取数据: https://api.example.com/data
[API处理器] 标准化字段格式
[API处理器] 写入缓存层: {"total": 4, "valid": 3, ...}
核心骨架 process() 方法从未修改,两个子类只是填充了自己负责的那几块「拼图」。
四、实战案例二:Web 爬虫框架
爬虫系统是模板方法的另一个绝佳应用场景------爬取流程(请求→解析→存储→限速)高度一致,但不同网站的解析逻辑完全不同:
python
import time
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class CrawlResult:
url: str
title: str
items: list = field(default_factory=list)
error: Optional[str] = None
crawled_at: str = field(
default_factory=lambda: datetime.now().strftime('%Y-%m-%d %H:%M:%S')
)
class BaseCrawler(ABC):
"""爬虫抽象基类:定义爬取算法骨架"""
def __init__(self, delay_range: tuple = (1.0, 3.0)):
self.delay_range = delay_range
self._results: list[CrawlResult] = []
def crawl(self, urls: list[str]) -> list[CrawlResult]:
"""模板方法:完整爬取流程"""
self.setup() # 钩子:初始化(如登录、设置 headers)
for i, url in enumerate(urls):
print(f"\n[{i+1}/{len(urls)}] 处理: {url}")
try:
html = self.fetch(url) # 抽象:发起请求
if not self.is_valid(html): # 钩子:校验响应
print(f" ⚠ 跳过无效响应")
continue
result = self.parse(html, url) # 抽象:解析内容
result = self.transform(result) # 钩子:数据转换
self.store(result) # 抽象:存储结果
self._results.append(result)
except Exception as e:
print(f" ✗ 抓取失败: {e}")
self._results.append(CrawlResult(url=url, title='', error=str(e)))
finally:
self._rate_limit() # 模板内置:限速(子类不可跳过)
self.teardown() # 钩子:清理资源
return self._results
# ===== 必须由子类实现 =====
@abstractmethod
def fetch(self, url: str) -> str:
"""发起 HTTP 请求,返回响应内容"""
pass
@abstractmethod
def parse(self, html: str, url: str) -> CrawlResult:
"""解析 HTML,提取目标数据"""
pass
@abstractmethod
def store(self, result: CrawlResult) -> None:
"""存储爬取结果"""
pass
# ===== 钩子方法:子类可选择覆盖 =====
def setup(self) -> None:
print(f"[{self.__class__.__name__}] 爬虫初始化完成")
def teardown(self) -> None:
print(f"\n[{self.__class__.__name__}] 爬取完成,共 {len(self._results)} 条结果")
def is_valid(self, html: str) -> bool:
"""默认校验:响应不为空"""
return bool(html and len(html) > 10)
def transform(self, result: CrawlResult) -> CrawlResult:
"""默认转换:不做任何处理"""
return result
def _rate_limit(self) -> None:
"""内置限速:子类无法绕过(封装在模板方法内)"""
delay = random.uniform(*self.delay_range)
print(f" → 限速等待 {delay:.1f}s")
time.sleep(delay * 0.01) # 演示时缩短延迟
class NewsArticleCrawler(BaseCrawler):
"""新闻文章爬虫"""
def __init__(self):
super().__init__(delay_range=(2.0, 5.0))
self._db: list = []
def setup(self) -> None:
print("[新闻爬虫] 加载 User-Agent 池,设置代理...")
super().setup()
def fetch(self, url: str) -> str:
# 模拟 HTTP 请求
print(f" ↓ GET {url}")
if 'error' in url:
raise ConnectionError(f"无法连接: {url}")
return f"<html><title>新闻:{url.split('/')[-1]}</title><p>文章内容...</p></html>"
def parse(self, html: str, url: str) -> CrawlResult:
# 模拟解析
import re
title_match = re.search(r'<title>(.*?)</title>', html)
title = title_match.group(1) if title_match else 'Unknown'
return CrawlResult(
url=url,
title=title,
items=[{'content': '文章正文段落...', 'word_count': 500}]
)
def transform(self, result: CrawlResult) -> CrawlResult:
"""新闻特有转换:标题去除前缀"""
result.title = result.title.replace('新闻:', '').strip()
return result
def store(self, result: CrawlResult) -> None:
self._db.append(result)
print(f" ✓ 存入数据库: 《{result.title}》")
class ProductCrawler(BaseCrawler):
"""商品信息爬虫"""
def __init__(self, output_file: str):
super().__init__(delay_range=(1.0, 2.0))
self.output_file = output_file
self._products: list = []
def fetch(self, url: str) -> str:
print(f" ↓ 请求商品页: {url}")
sku = url.split('/')[-1]
return json.dumps({
'sku': sku,
'name': f'商品{sku}',
'price': round(random.uniform(10, 999), 2),
'stock': random.randint(0, 1000),
'rating': round(random.uniform(3.0, 5.0), 1)
})
def is_valid(self, html: str) -> bool:
"""商品爬虫:校验 JSON 格式"""
try:
data = json.loads(html)
return 'sku' in data and data.get('stock', 0) > 0
except Exception:
return False
def parse(self, html: str, url: str) -> CrawlResult:
data = json.loads(html)
return CrawlResult(
url=url,
title=data['name'],
items=[{
'price': data['price'],
'stock': data['stock'],
'rating': data['rating']
}]
)
def store(self, result: CrawlResult) -> None:
if result.items:
item = result.items[0]
self._products.append({**item, 'name': result.title})
print(f" ✓ 记录商品: {result.title} "
f"¥{item['price']} 库存:{item['stock']}")
def teardown(self) -> None:
# 爬取结束后批量写文件
print(f"\n[商品爬虫] 导出 {len(self._products)} 条商品到 {self.output_file}")
super().teardown()
# ===== 运行演示 =====
news_urls = [
'https://news.example.com/tech/python-3-13',
'https://news.example.com/ai/gpt-update',
'https://news.example.com/error-page', # 模拟失败
]
product_urls = [f'https://shop.example.com/products/{i}' for i in [101, 102, 103]]
print("【新闻爬虫】")
news_crawler = NewsArticleCrawler()
news_results = news_crawler.crawl(news_urls)
print("\n" + "=" * 60)
print("【商品爬虫】")
product_crawler = ProductCrawler('products_export.json')
product_crawler.crawl(product_urls)
五、实战案例三:报表生成系统
报表生成是另一个模板方法的高频场景------数据查询→处理→渲染→输出,骨架固定:
python
from abc import ABC, abstractmethod
from typing import Any
class ReportGenerator(ABC):
"""报表生成抽象基类"""
def generate(self, params: dict) -> str:
"""模板方法:报表生成全流程"""
print(f"\n[{self.report_name}] 开始生成报表...")
# 1. 参数校验
self._validate_params(params)
# 2. 查询数据
raw_data = self.query_data(params)
print(f" 数据查询完成: {len(raw_data)} 条原始记录")
# 3. 数据聚合
aggregated = self.aggregate(raw_data)
# 4. 应用过滤(钩子)
if self.apply_filter(params):
aggregated = self.filter_data(aggregated, params)
# 5. 渲染输出
output = self.render(aggregated, params)
# 6. 后处理(钩子)
output = self.post_process(output)
print(f" ✓ 报表生成完成")
return output
@property
@abstractmethod
def report_name(self) -> str:
pass
@abstractmethod
def query_data(self, params: dict) -> list:
pass
@abstractmethod
def aggregate(self, data: list) -> dict:
pass
@abstractmethod
def render(self, data: dict, params: dict) -> str:
pass
# 钩子方法
def _validate_params(self, params: dict) -> None:
if 'start_date' not in params or 'end_date' not in params:
raise ValueError("缺少必要参数: start_date, end_date")
def apply_filter(self, params: dict) -> bool:
return 'filter' in params
def filter_data(self, data: dict, params: dict) -> dict:
return data # 默认不过滤
def post_process(self, output: str) -> str:
return output # 默认不处理
class SalesReport(ReportGenerator):
"""销售报表"""
@property
def report_name(self) -> str:
return "销售月报"
def query_data(self, params: dict) -> list:
# 模拟查询销售数据库
return [
{'product': 'A', 'region': '华北', 'amount': 15000, 'qty': 120},
{'product': 'B', 'region': '华南', 'amount': 28000, 'qty': 85},
{'product': 'A', 'region': '华南', 'amount': 9500, 'qty': 60},
{'product': 'C', 'region': '华北', 'amount': 42000, 'qty': 200},
{'product': 'B', 'region': '华东', 'amount': 18000, 'qty': 95},
]
def aggregate(self, data: list) -> dict:
total = sum(row['amount'] for row in data)
by_product = {}
by_region = {}
for row in data:
by_product[row['product']] = by_product.get(row['product'], 0) + row['amount']
by_region[row['region']] = by_region.get(row['region'], 0) + row['amount']
return {
'total_amount': total,
'by_product': by_product,
'by_region': by_region,
'records': len(data)
}
def filter_data(self, data: dict, params: dict) -> dict:
"""按区域过滤"""
target_region = params.get('filter')
# 实际项目中重新聚合,此处简化
print(f" 应用过滤: 区域 = {target_region}")
return data
def render(self, data: dict, params: dict) -> str:
lines = [
f"{'=' * 40}",
f" 销售月报 {params['start_date']} ~ {params['end_date']}",
f"{'=' * 40}",
f" 总销售额: ¥{data['total_amount']:,.0f}",
f" 记录条数: {data['records']} 笔",
f"\n 按产品分布:",
]
for product, amount in sorted(data['by_product'].items(),
key=lambda x: -x[1]):
pct = amount / data['total_amount'] * 100
lines.append(f" 产品{product}: ¥{amount:>8,.0f} ({pct:.1f}%)")
lines.append(f"\n 按区域分布:")
for region, amount in sorted(data['by_region'].items(),
key=lambda x: -x[1]):
lines.append(f" {region}: ¥{amount:>8,.0f}")
lines.append(f"{'=' * 40}")
return '\n'.join(lines)
def post_process(self, output: str) -> str:
"""添加生成时间戳"""
return output + f"\n 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 使用
report = SalesReport()
output = report.generate({
'start_date': '2025-01-01',
'end_date': '2025-01-31',
})
print(output)
六、Python 进阶技巧:让模板方法更 Pythonic
6.1 用 __init_subclass__ 防止模板方法被意外覆盖
python
class StrictTemplate(ABC):
"""严格模板:防止子类覆盖核心模板方法"""
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
# 检查子类是否错误地覆盖了模板方法
protected = ['process', 'execute', 'run']
for method in protected:
if method in cls.__dict__:
raise TypeError(
f"子类 {cls.__name__} 不允许覆盖模板方法 '{method}'!"
f"请覆盖对应的抽象步骤方法。"
)
class MyProcessor(StrictTemplate):
def process(self): # 这会在类定义时立即报错
pass
# TypeError: 子类 MyProcessor 不允许覆盖模板方法 'process'!
6.2 用 Mixin 组合多个「局部模板」
python
class LoggingMixin:
"""日志 Mixin:为任何处理器添加日志能力"""
def process(self):
print(f"[LOG] {self.__class__.__name__} 开始执行")
result = super().process()
print(f"[LOG] {self.__class__.__name__} 执行完成")
return result
class TimingMixin:
"""计时 Mixin:为任何处理器添加计时能力"""
def process(self):
start = time.perf_counter()
result = super().process()
elapsed = time.perf_counter() - start
print(f"[TIMING] 总耗时: {elapsed:.3f}s")
return result
# 多重继承组合:Mixin 增强 + 模板方法骨架
class EnhancedCSVProcessor(LoggingMixin, TimingMixin, CSVDataProcessor):
"""带日志和计时的 CSV 处理器"""
pass
# Python MRO 保证调用顺序:LoggingMixin → TimingMixin → CSVDataProcessor
processor = EnhancedCSVProcessor('data.csv', email='report@example.com')
processor.process()
6.3 结合 dataclass 简化配置传递
python
from dataclasses import dataclass
@dataclass
class ProcessorConfig:
source: str
output_dir: str = './output'
validate: bool = True
notify_email: str = ''
batch_size: int = 1000
class ConfigurableProcessor(DataProcessor):
"""通过配置对象控制模板方法行为"""
def __init__(self, config: ProcessorConfig):
self.config = config
def should_validate(self) -> bool:
return self.config.validate # 由配置决定
def should_notify(self) -> bool:
return bool(self.config.notify_email) # 由配置决定
def read_data(self) -> list:
print(f"从 {self.config.source} 批量读取, 批次大小: {self.config.batch_size}")
return []
def clean_data(self, data: list) -> list: return data
def analyze(self, data: list) -> dict: return {}
def save_result(self, result: dict) -> None:
print(f"保存到 {self.config.output_dir}")
6.4 单元测试:聚焦每个步骤
模板方法模式让测试极为简单------可以单独测试每个抽象步骤:
python
import pytest
from unittest.mock import patch, MagicMock
class TestCSVDataProcessor:
def setup_method(self):
self.processor = CSVDataProcessor('test.csv')
def test_clean_data_converts_types(self):
raw = [{'name': '张三', 'age': '25', 'score': '88.5'}]
clean = self.processor.clean_data(raw)
assert clean[0]['age'] == 25
assert clean[0]['score'] == 88.5
def test_validate_filters_null_scores(self):
data = [
{'name': 'A', 'score': 90.0},
{'name': 'B', 'score': None},
{'name': 'C', 'score': 75.0},
]
valid = self.processor.validate(data)
assert len(valid) == 2
assert all(row['score'] is not None for row in valid)
def test_analyze_computes_correct_stats(self):
data = [{'score': 80.0}, {'score': 90.0}, {'score': 100.0}]
result = self.processor.analyze(data)
assert result['count'] == 3
assert result['avg_score'] == 90.0
def test_should_notify_with_email(self):
processor_with_email = CSVDataProcessor('test.csv', email='a@b.com')
processor_no_email = CSVDataProcessor('test.csv')
assert processor_with_email.should_notify() is True
assert processor_no_email.should_notify() is False
def test_full_process_calls_steps_in_order(self):
"""测试模板方法调用顺序"""
call_log = []
self.processor.read_data = lambda: (call_log.append('read'), [])[1]
self.processor.clean_data = lambda d: (call_log.append('clean'), d)[1]
self.processor.analyze = lambda d: (call_log.append('analyze'), {})[1]
self.processor.save_result = lambda r: call_log.append('save')
self.processor.process()
assert call_log == ['read', 'clean', 'analyze', 'save']
七、最佳实践与常见陷阱
陷阱一:模板方法中的步骤太多。当模板方法调用超过 7~8 个步骤时,说明类的职责过重,应考虑将部分步骤提取为子模板或辅助方法。
陷阱二:钩子方法过于复杂。钩子的默认实现应简单(空操作或返回简单布尔值),复杂的默认行为应封装为独立方法,通过钩子调用。
陷阱三:抽象步骤设计过于细碎。不是每一行代码都需要成为抽象步骤,只有真正需要子类差异化实现的步骤才应抽象化。
💡 黄金原则: 好的模板方法应该像一份食谱------步骤清晰、顺序固定、关键环节留给厨师发挥,而不是每一克调料都规定死。
八、总结
模板方法模式是「好莱坞原则」(Don't call us, we'll call you)的完美体现------父类掌控整体流程,子类只需响应父类的「召唤」实现特定步骤。
回顾本文的核心要点:三类方法(模板方法、抽象步骤、钩子方法)各司其职;抽象步骤保证子类实现完整性,钩子方法提供灵活的条件分支;结合 Mixin 可实现模块化的行为组合;配合单元测试,每个步骤可独立验证;Python 的 abc 模块是实现此模式的天然利器。
模板方法模式最适合的场景正是那些「整体流程固定,局部步骤各异」的业务------数据处理管道、报表生成、Web 爬虫、测试框架、游戏关卡逻辑......在这些场景中,它能帮你消除重复、统一约束,让代码像一首乐曲:主旋律固定,每个乐器演奏自己的声部。
你在项目中是否用过模板方法模式?有没有遇到「骨架设计太死,子类施展不开」的困境?欢迎在评论区分享你的设计取舍和实战心得,让我们一起打磨更优雅的 Python 架构。
参考资料
- 《Design Patterns: Elements of Reusable Object-Oriented Software》- GoF,第 325 页模板方法
- 《Head First 设计模式》- Freeman,第八章「封装算法」
- 《流畅的Python》第二版 - Luciano Ramalho,第 11 章「接口:从协议到 ABC」
- Python
abc官方文档:https://docs.python.org/3/library/abc.html - Refactoring Guru 模板方法:https://refactoring.guru/design-patterns/template-method/python/example
- PEP 3119 -- Introducing Abstract Base Classes:https://peps.python.org/pep-3119/