文章目录
-
- 一、运算符重载的工程动机
- 二、算术运算符:完整的反射与就地版本
-
- [2.1 二维向量的完整实现](#2.1 二维向量的完整实现)
- [三、比较运算符:`functools.total_ordering` 的工程价值](#三、比较运算符:
functools.total_ordering的工程价值) - 四、容器协议:`getitem`、`contains`与序列语义
-
- [4.1 最小可迭代实现](#4.1 最小可迭代实现)
- [4.2 有序列表:完整的可变序列实现](#4.2 有序列表:完整的可变序列实现)
- [五、逻辑运算符 DSL:查询构建器](#五、逻辑运算符 DSL:查询构建器)
- 六、`call`:让类实例像函数一样被调用
- 七、运算符重载的设计原则
- 八、与内建类型的互操作
- 小结
一、运算符重载的工程动机
运算符重载在 C++ 社区曾饱受争议------过度使用会让代码变成谜语。Python 采用了更保守的设计:运算符重载必须通过双下划线方法显式声明,而不是随意改写语法含义。
工程中真正需要运算符重载的场景,通常满足以下几个条件之一:
- 数学对象:向量、矩阵、多项式、分数------这些对象在数学中本来就有运算符语义,代码应当与数学符号一致
- 领域专用语言(DSL) :SQLAlchemy 的
User.age > 18、Django Q 对象的Q(a=1) | Q(b=2)------用运算符表达查询条件,远比方法链清晰 - 容器与集合 :自定义集合类实现
in、[]、len()操作,保持与内建容器一致的接口 - 配置与构建器模式 :
pipeline = step1 | step2 | step3通过|组合处理步骤
本篇以三个完整案例递进展示:二维向量 (算术运算符)、有序列表 (容器协议)、查询构建器(逻辑运算符 DSL)。
二、算术运算符:完整的反射与就地版本
Python 的算术运算符重载有三组方法,以加法为例:
| 方法组 | 对应语法 | 触发时机 |
|---|---|---|
__add__(self, other) |
a + b |
a 是左操作数 |
__radd__(self, other) |
b + a |
a 是右操作数,且 b.__add__ 返回 NotImplemented |
__iadd__(self, other) |
a += b |
就地修改(可变对象),不修改时回退到 __add__ |
这套分发机制可以用一张流程图来描述:
非 NotImplemented
NotImplemented 或 AttributeError
是
否
非 NotImplemented
NotImplemented
a + b
调用 type(a).add(a, b)
返回值
返回结果
type(b) 是 type(a) 的子类?
优先调用 type(b).radd(b, a)
调用 type(b).radd(b, a)
返回值
抛出 TypeError
2.1 二维向量的完整实现
python
from __future__ import annotations
import math
from typing import Union
class Vector2D:
"""
二维向量,展示算术运算符重载的完整写法
支持向量加减、标量乘除、点积、取模、比较
"""
__slots__ = ("x", "y") # 优化内存,同时防止意外添加属性
def __init__(self, x: float, y: float) -> None:
self.x = float(x)
self.y = float(y)
# ── 字符串表示
def __repr__(self) -> str:
return f"Vector2D({self.x}, {self.y})"
def __format__(self, fmt_spec: str) -> str:
if fmt_spec == "c":
# 笛卡尔格式
return f"({self.x:.4g}, {self.y:.4g})"
if fmt_spec == "p":
# 极坐标格式
return f"<r={abs(self):.4g}, θ={math.degrees(self.angle):.2f}°>"
return format(repr(self), fmt_spec)
# ── 向量运算 ─────────────────────────────────────────────────
def __add__(self, other: Vector2D) -> Vector2D:
if not isinstance(other, Vector2D):
return NotImplemented
return Vector2D(self.x + other.x, self.y + other.y)
def __radd__(self, other: Vector2D) -> Vector2D:
# 交换律:b + a 等价于 a + b
return self.__add__(other)
def __iadd__(self, other: Vector2D) -> Vector2D:
if not isinstance(other, Vector2D):
return NotImplemented
# __slots__ 使对象不可变结构更清晰,这里返回新对象也符合惯例
self.x += other.x
self.y += other.y
return self
def __sub__(self, other: Vector2D) -> Vector2D:
if not isinstance(other, Vector2D):
return NotImplemented
return Vector2D(self.x - other.x, self.y - other.y)
def __rsub__(self, other: Vector2D) -> Vector2D:
# 注意:减法不满足交换律!b - a ≠ a - b
if not isinstance(other, Vector2D):
return NotImplemented
return Vector2D(other.x - self.x, other.y - self.y)
def __mul__(self, scalar: float) -> Vector2D:
# 向量 × 标量
if isinstance(scalar, (int, float)):
return Vector2D(self.x * scalar, self.y * scalar)
return NotImplemented
def __rmul__(self, scalar: float) -> Vector2D:
# 标量 × 向量(利用交换律)
return self.__mul__(scalar)
def __truediv__(self, scalar: float) -> Vector2D:
if not isinstance(scalar, (int, float)):
return NotImplemented
if scalar == 0:
raise ZeroDivisionError("向量不能除以零")
return Vector2D(self.x / scalar, self.y / scalar)
def __neg__(self) -> Vector2D:
return Vector2D(-self.x, -self.y)
def __pos__(self) -> Vector2D:
return Vector2D(self.x, self.y)
def __abs__(self) -> float:
# abs(v) 返回向量模长
return math.hypot(self.x, self.y)
# ── 点积(使用 @ 运算符,Python 3.5+)
def __matmul__(self, other: Vector2D) -> float:
if not isinstance(other, Vector2D):
return NotImplemented
return self.x * other.x + self.y * other.y
# ── 比较
def __eq__(self, other: object) -> bool:
if not isinstance(other, Vector2D):
return NotImplemented
return math.isclose(self.x, other.x) and math.isclose(self.y, other.y)
def __hash__(self) -> int:
# 浮点数哈希使用 round 截断,避免精度问题
return hash((round(self.x, 10), round(self.y, 10)))
def __lt__(self, other: Vector2D) -> bool:
# 按模长比较
if not isinstance(other, Vector2D):
return NotImplemented
return abs(self) < abs(other)
# ── 属性
@property
def angle(self) -> float:
return math.atan2(self.y, self.x)
@property
def normalized(self) -> Vector2D:
mag = abs(self)
if mag == 0:
raise ValueError("零向量无法归一化")
return self / mag
验证代码:
python
v1 = Vector2D(3, 4)
v2 = Vector2D(1, 2)
print(repr(v1)) # Vector2D(3.0, 4.0)
print(f"{v1:c}") # (3, 4)
print(f"{v1:p}") # <r=5, θ=53.13°>
print(abs(v1)) # 5.0
print(v1 + v2) # Vector2D(4.0, 6.0)
print(v1 - v2) # Vector2D(2.0, 2.0)
print(v1 * 2) # Vector2D(6.0, 8.0)
print(3 * v1) # Vector2D(9.0, 12.0) --- 触发 __rmul__
print(v1 @ v2) # 11.0(点积:3×1 + 4×2)
print(v1 < v2) # False(模长 5 > √5)
print(v1.normalized) # Vector2D(0.6, 0.8)
三、比较运算符:functools.total_ordering 的工程价值
完整实现六个比较运算符(__eq__、__ne__、__lt__、__le__、__gt__、__ge__)是重复劳动。functools.total_ordering 允许只实现 __eq__ 和任意一个排序运算符,自动推导其余四个:
python
from functools import total_ordering
from dataclasses import dataclass
@total_ordering
@dataclass
class Version:
"""
语义化版本号(SemVer),演示 total_ordering 的工程用法
仅需定义 __eq__ 和 __lt__,其余比较运算符自动生成
"""
major: int
minor: int
patch: int
@classmethod
def parse(cls, version_str: str) -> "Version":
parts = version_str.strip("v").split(".")
if len(parts) != 3:
raise ValueError(f"版本号格式错误:{version_str!r}")
return cls(*map(int, parts))
def __str__(self) -> str:
return f"v{self.major}.{self.minor}.{self.patch}"
def __eq__(self, other: object) -> bool:
if not isinstance(other, Version):
return NotImplemented
return (self.major, self.minor, self.patch) == (
other.major, other.minor, other.patch
)
def __lt__(self, other: "Version") -> bool:
if not isinstance(other, Version):
return NotImplemented
return (self.major, self.minor, self.patch) < (
other.major, other.minor, other.patch
)
def __hash__(self) -> int:
return hash((self.major, self.minor, self.patch))
v1 = Version.parse("1.2.3")
v2 = Version.parse("1.10.0")
v3 = Version.parse("2.0.0")
print(v1 < v2) # True
print(v1 <= v2) # True(由 total_ordering 生成)
print(v2 > v1) # True(由 total_ordering 生成)
print(v3 >= v2) # True
print(sorted([v3, v1, v2])) # [v1.2.3, v1.10.0, v2.0.0]
total_ordering 的代价:生成的方法有一定性能开销(每次比较都需要额外的函数调用),对性能敏感的场景(如排序百万级对象),建议手动实现全部六个方法。
四、容器协议:__getitem__、__contains__与序列语义
4.1 最小可迭代实现
一个只实现了 __getitem__ 的对象,Python 可以自动对其进行迭代(通过依次传入 0, 1, 2... 直到 IndexError)------这是 Python 2 时代的遗留行为,现代代码推荐同时实现 __iter__:
python
class PaginatedResult:
"""
分页查询结果,演示最小序列协议
只实现 __getitem__ 和 __len__,Python 自动支持 for/in/切片
"""
def __init__(self, items: list, page_size: int = 10):
self._items = items
self.page_size = page_size
def __len__(self) -> int:
return len(self._items)
def __getitem__(self, index):
# 支持整数索引和切片
if isinstance(index, slice):
return PaginatedResult(self._items[index], self.page_size)
if isinstance(index, int):
if index < 0:
index += len(self) # 负索引支持
if not (0 <= index < len(self)):
raise IndexError(f"索引 {index} 超出范围 [0, {len(self)})")
return self._items[index]
raise TypeError(f"索引类型不支持:{type(index).__name__}")
def __contains__(self, item) -> bool:
# 显式实现 __contains__ 可以优化搜索(这里是线性,可根据业务优化)
return item in self._items
def __repr__(self) -> str:
return f"PaginatedResult({self._items!r})"
results = PaginatedResult([10, 20, 30, 40, 50])
print(results[2]) # 30
print(results[-1]) # 50
print(30 in results) # True
print(results[1:3]) # PaginatedResult([20, 30])
for item in results: # 自动迭代(通过 __getitem__)
print(item, end=" ") # 10 20 30 40 50
4.2 有序列表:完整的可变序列实现
python
from __future__ import annotations
import bisect
from typing import TypeVar, Generic, Iterator, Iterable
T = TypeVar("T")
class SortedList(Generic[T]):
"""
始终保持有序的列表,插入时自动排序
演示可变序列协议的完整实现
"""
def __init__(self, iterable: Iterable[T] = ()) -> None:
self._data: list[T] = sorted(iterable)
# ── 基础序列方法
def __len__(self) -> int:
return len(self._data)
def __getitem__(self, index):
return self._data[index]
def __iter__(self) -> Iterator[T]:
return iter(self._data)
def __reversed__(self) -> Iterator[T]:
return reversed(self._data)
def __contains__(self, item: object) -> bool:
# 利用二分查找,O(log n),比线性搜索更快
pos = bisect.bisect_left(self._data, item)
return pos < len(self._data) and self._data[pos] == item
# ── 修改操作
def add(self, item: T) -> None:
"""有序插入,O(log n) 查找 + O(n) 移位"""
bisect.insort(self._data, item)
def remove(self, item: T) -> None:
"""移除第一个匹配项,O(log n) 查找 + O(n) 移位"""
pos = bisect.bisect_left(self._data, item)
if pos >= len(self._data) or self._data[pos] != item:
raise ValueError(f"{item!r} 不在列表中")
del self._data[pos]
def __delitem__(self, index) -> None:
"""支持 del sl[index] 和 del sl[start:end]"""
del self._data[index]
# ── 集合操作(使用 | & - 运算符)
def __or__(self, other: SortedList[T]) -> SortedList[T]:
"""并集:合并两个有序列表(去重)"""
if not isinstance(other, SortedList):
return NotImplemented
result = SortedList()
# 归并两个已排序序列,O(m + n)
i = j = 0
while i < len(self) and j < len(other):
if self._data[i] < other._data[j]:
result._data.append(self._data[i]); i += 1
elif self._data[i] > other._data[j]:
result._data.append(other._data[j]); j += 1
else:
result._data.append(self._data[i]); i += 1; j += 1
result._data.extend(self._data[i:])
result._data.extend(other._data[j:])
return result
def __and__(self, other: SortedList[T]) -> SortedList[T]:
"""交集:两个有序列表的公共元素"""
if not isinstance(other, SortedList):
return NotImplemented
result = SortedList()
i = j = 0
while i < len(self) and j < len(other):
if self._data[i] < other._data[j]:
i += 1
elif self._data[i] > other._data[j]:
j += 1
else:
result._data.append(self._data[i]); i += 1; j += 1
return result
def __sub__(self, other: SortedList[T]) -> SortedList[T]:
"""差集:在 self 中但不在 other 中的元素"""
if not isinstance(other, SortedList):
return NotImplemented
return SortedList(x for x in self._data if x not in other)
# ── 字符串与比较
def __repr__(self) -> str:
return f"SortedList({self._data!r})"
def __eq__(self, other: object) -> bool:
if isinstance(other, SortedList):
return self._data == other._data
if isinstance(other, list):
return self._data == sorted(other)
return NotImplemented
# 验证
sl1 = SortedList([5, 3, 1, 4, 2])
sl2 = SortedList([4, 5, 6, 7])
print(sl1) # SortedList([1, 2, 3, 4, 5])
print(3 in sl1) # True(二分查找)
sl1.add(3.5)
print(sl1) # SortedList([1, 2, 3, 3.5, 4, 5])
print(sl1 | sl2) # SortedList([1, 2, 3, 3.5, 4, 5, 6, 7])
print(sl1 & sl2) # SortedList([4, 5])
print(sl1 - sl2) # SortedList([1, 2, 3, 3.5])
五、逻辑运算符 DSL:查询构建器
这是运算符重载最有工程价值的应用之一------用 &、|、~ 构建复合查询条件,而不必手写嵌套 and/or 逻辑:
python
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class Condition(ABC):
"""查询条件的抽象基类"""
@abstractmethod
def evaluate(self, record: dict) -> bool:
...
@abstractmethod
def __repr__(self) -> str:
...
# 运算符重载------构建复合条件
def __and__(self, other: Condition) -> AndCondition:
return AndCondition(self, other)
def __or__(self, other: Condition) -> OrCondition:
return OrCondition(self, other)
def __invert__(self) -> NotCondition:
# ~ 操作符(按位取反)用来表示"NOT"
return NotCondition(self)
class FieldCondition(Condition):
"""字段比较条件:field op value"""
_OPS = {
"eq": "=",
"ne": "≠",
"lt": "<",
"le": "≤",
"gt": ">",
"ge": "≥",
"contains": "∋",
"startswith": "startswith",
}
def __init__(self, field: str, op: str, value: Any) -> None:
if op not in self._OPS:
raise ValueError(f"不支持的操作符:{op!r}")
self.field = field
self.op = op
self.value = value
def evaluate(self, record: dict) -> bool:
record_value = record.get(self.field)
if record_value is None:
return False
match self.op:
case "eq": return record_value == self.value
case "ne": return record_value != self.value
case "lt": return record_value < self.value
case "le": return record_value <= self.value
case "gt": return record_value > self.value
case "ge": return record_value >= self.value
case "contains": return self.value in record_value
case "startswith": return str(record_value).startswith(str(self.value))
case _: return False
def __repr__(self) -> str:
return f"({self.field} {self._OPS[self.op]} {self.value!r})"
class AndCondition(Condition):
def __init__(self, left: Condition, right: Condition) -> None:
self.left = left
self.right = right
def evaluate(self, record: dict) -> bool:
return self.left.evaluate(record) and self.right.evaluate(record)
def __repr__(self) -> str:
return f"({self.left!r} AND {self.right!r})"
class OrCondition(Condition):
def __init__(self, left: Condition, right: Condition) -> None:
self.left = left
self.right = right
def evaluate(self, record: dict) -> bool:
return self.left.evaluate(record) or self.right.evaluate(record)
def __repr__(self) -> str:
return f"({self.left!r} OR {self.right!r})"
class NotCondition(Condition):
def __init__(self, condition: Condition) -> None:
self.condition = condition
def evaluate(self, record: dict) -> bool:
return not self.condition.evaluate(record)
def __repr__(self) -> str:
return f"(NOT {self.condition!r})"
# ── 便捷的字段构建器
class Field:
"""通过属性访问和运算符重载构建条件"""
def __init__(self, name: str):
self.name = name
def __eq__(self, other) -> FieldCondition: # type: ignore[override]
return FieldCondition(self.name, "eq", other)
def __ne__(self, other) -> FieldCondition: # type: ignore[override]
return FieldCondition(self.name, "ne", other)
def __lt__(self, other) -> FieldCondition:
return FieldCondition(self.name, "lt", other)
def __le__(self, other) -> FieldCondition:
return FieldCondition(self.name, "le", other)
def __gt__(self, other) -> FieldCondition:
return FieldCondition(self.name, "gt", other)
def __ge__(self, other) -> FieldCondition:
return FieldCondition(self.name, "ge", other)
def contains(self, value) -> FieldCondition:
return FieldCondition(self.name, "contains", value)
def startswith(self, value) -> FieldCondition:
return FieldCondition(self.name, "startswith", value)
# ── 使用演示
age = Field("age")
name = Field("name")
role = Field("role")
# 构建复合查询条件
query = (age >= 18) & (age < 60) & (role == "engineer")
print(repr(query))
# ((age ≥ 18) AND (age < 60)) AND (role = 'engineer')
# 也可以用 OR 和 NOT
query2 = (name.startswith("Li")) | ~(role == "intern")
print(repr(query2))
# ((name startswith 'Li') OR (NOT (role = 'intern')))
# 应用到数据集
records = [
{"name": "Li Wei", "age": 28, "role": "engineer"},
{"name": "Zhang San", "age": 17, "role": "intern"},
{"name": "Wang Fang", "age": 35, "role": "engineer"},
{"name": "Liu Yang", "age": 25, "role": "manager"},
]
results = [r for r in records if query.evaluate(r)]
print(results)
# [{'name': 'Li Wei', ...}, {'name': 'Wang Fang', ...}]
这种模式在 SQLAlchemy、Pandas query、Django ORM 中广泛使用,是运算符重载最有说服力的工程案例。
六、__call__:让类实例像函数一样被调用
__call__ 前文已在 #13 中介绍,这里补充一个更具工程深度的案例------有状态的函数替代品:
python
from __future__ import annotations
import time
from collections import deque
from typing import Callable, TypeVar
R = TypeVar("R")
class ThrottledFunction:
"""
节流函数:在 interval 秒内只允许执行一次
状态保存在实例中,而不是全局变量
支持嵌套装饰(每个被装饰的函数持有独立的 ThrottledFunction 实例)
"""
def __init__(self, func: Callable, interval: float) -> None:
self._func = func
self._interval = interval
self._last_call: float = 0.0
self._call_count = 0
# 保留函数元信息(functools.wraps 的效果)
self.__name__ = func.__name__
self.__doc__ = func.__doc__
self.__module__ = func.__module__
def __call__(self, *args, **kwargs):
now = time.monotonic()
if now - self._last_call < self._interval:
remaining = self._interval - (now - self._last_call)
raise RuntimeError(
f"{self.__name__} 节流中,"
f"还需等待 {remaining:.2f}s"
)
self._last_call = now
self._call_count += 1
return self._func(*args, **kwargs)
def __repr__(self) -> str:
return (
f"ThrottledFunction({self._func.__name__!r}, "
f"interval={self._interval}s, "
f"calls={self._call_count})"
)
# 让 ThrottledFunction 也可以作为描述符(支持方法绑定)
def __get__(self, obj, objtype=None):
if obj is None:
return self
# 绑定到实例,创建一个 bound method 式的可调用对象
import functools
return functools.partial(self.__call__)
def throttle(interval: float):
"""节流装饰器工厂"""
def decorator(func: Callable) -> ThrottledFunction:
return ThrottledFunction(func, interval)
return decorator
@throttle(interval=1.0)
def send_alert(message: str) -> None:
"""发送告警(每秒最多一次)"""
print(f"[ALERT] {message}")
send_alert("CPU 使用率过高") # 正常执行
try:
send_alert("内存不足") # 被节流,抛出 RuntimeError
except RuntimeError as e:
print(f"节流:{e}")
time.sleep(1.1)
send_alert("磁盘空间不足") # 冷却后正常执行
print(repr(send_alert)) # ThrottledFunction('send_alert', interval=1.0s, calls=2)
七、运算符重载的设计原则
在决定是否重载某个运算符之前,应当用以下几条原则来检验:
是
否
是
否
是
否
是
否
运算符在这个领域
有公认的语义?
重载后行为是否
满足代数性质?
(交换律/结合律等)
❌ 不建议重载
使用具名方法更清晰
返回 NotImplemented
当类型不兼容时?
⚠️ 需要文档说明
违反直觉的行为
是否考虑了
反射方法(radd 等)?
❌ 会导致混乱的
TypeError 信息
✅ 安全重载
⚠️ 可能无法与
其他类型协作
七条设计原则 (来源于 Fluent Python 及工程经验总结):
| 原则 | 说明 |
|---|---|
| 语义一致性 | 重载的运算符语义必须与数学或领域惯例一致,+ 不应当表示减法 |
| 类型安全 | 遇到不支持的类型时,返回 NotImplemented 而非抛出 TypeError |
| 对称性 | 实现 __add__ 的同时考虑 __radd__,让 a + b 和 b + a 在合理时等价 |
| 不可变性偏好 | 算术运算符应返回新对象,而非修改 self;就地操作用 __iadd__ 等 |
| 哈希一致性 | 定义 __eq__ 后必须同步考虑 __hash__,防止集合/字典行为异常 |
| total_ordering | 只需定义 __eq__ 和一个排序方法,@total_ordering 推导其余 |
| 类型注解 | 重载方法的返回值类型应精确标注(使用 from __future__ import annotations) |
八、与内建类型的互操作
运算符重载的一个常见需求是与 Python 内建类型(int、float、list)互操作。这时需要特别注意两点:
其一,子类优先 :如果 b 是 a 的子类,a + b 时 Python 会优先尝试 type(b).__radd__,让子类有机会覆盖父类的行为:
python
class ExtendedVector(Vector2D):
"""演示子类优先原则"""
def __radd__(self, other):
if isinstance(other, tuple) and len(other) == 2:
return ExtendedVector(other[0] + self.x, other[1] + self.y)
return super().__radd__(other)
v = ExtendedVector(1, 2)
result = (10, 20) + v # 触发 ExtendedVector.__radd__,因为 tuple.__add__ 返回 NotImplemented
# 如果 ExtendedVector 是 Vector2D 子类,Python 会优先调用 ExtendedVector.__radd__
其二,NumPy 的 __array_ufunc__ :与 NumPy 数组互操作时,需要实现 __array_ufunc__ 和 __array__ 方法,让 NumPy 知道如何处理自定义类型------这是数据模型在科学计算生态中的延伸。
小结
运算符重载不是语法糖,而是 Python 数据模型赋予自定义类型"平等公民权"的机制:
- 算术运算符三组 :正向(
__add__)、反射(__radd__)、就地(__iadd__)------三组共同保证类型安全的运算符分发 - 比较运算符 :
@total_ordering减少重复实现;__eq__与__hash__必须保持一致 - 容器协议 :
__getitem__、__setitem__、__delitem__、__contains__、__len__构成完整的序列/映射语义 - DSL 构建 :
&、|、~运算符重载是查询构建器的核心技术,SQLAlchemy/Pandas 均采用此模式 - 可调用对象 :
__call__让有状态的类实例充当函数,避免全局变量和闭包的限制 - 设计原则 :语义一致、返回
NotImplemented、考虑反射方法、不可变性偏好
如果这篇文章对运算符重载的理解有所帮助,欢迎点赞收藏支持------原创技术内容需要积累,每一个点赞都是持续输出的动力。
在实际项目中使用过运算符重载、或者踩过
__eq__忘记同步__hash__这类坑的,欢迎在评论区分享经验。关注账号不会错过 Python 进阶系列的后续内容。