单元测试是专业开发的基石,它不仅能捕获bug,还能作为代码的活文档,并支持代码重构。今天,我们将深入探讨Python中的单元测试实践,让你的代码更加可靠和健壮。
一、为什么需要单元测试?
单元测试的价值:
- 早期发现bug:在代码部署前发现问题
- 支持重构:修改代码时有安全网
- 文档作用:测试用例本身就是最好的使用文档
- 设计指导:编写可测试的代码通常意味着更好的设计
- 持续集成:自动化测试是CI/CD流程的核心
没有测试的代价:
python
# 一个简单的计算器函数
def divide(a, b):
return a / b
# 看起来没问题,但如果b=0呢?
result = divide(10, 0) # ZeroDivisionError!
二、Python测试框架:unittest vs pytest
Python有两个主流的测试框架:
- unittest:Python标准库的一部分,基于Java的JUnit
- pytest:第三方框架,更Pythonic,功能更强大
我们将重点介绍pytest,因为它更简洁、强大,是当前Python社区的主流选择。
三、安装和配置pytest
bash
# 安装pytest
pip install pytest
# 安装常用插件
pip install pytest-cov # 测试覆盖率
pip install pytest-mock # Mock对象支持
四、第一个测试用例
待测试的代码(calculator.py):
python
# calculator.py
class Calculator:
"""简单的计算器类"""
def add(self, a, b):
"""加法"""
return a + b
def subtract(self, a, b):
"""减法"""
return a - b
def multiply(self, a, b):
"""乘法"""
return a * b
def divide(self, a, b):
"""除法"""
if b == 0:
raise ValueError("除数不能为零")
return a / b
def power(self, base, exponent):
"""幂运算"""
return base ** exponent
def is_even(self, number):
"""判断是否为偶数"""
return number % 2 == 0
测试代码(test_calculator.py):
python
# test_calculator.py
import pytest
from calculator import Calculator
class TestCalculator:
"""Calculator类的测试用例"""
def setup_method(self):
"""每个测试方法前执行"""
self.calc = Calculator()
def test_add(self):
"""测试加法"""
result = self.calc.add(2, 3)
assert result == 5
def test_subtract(self):
"""测试减法"""
result = self.calc.subtract(10, 4)
assert result == 6
def test_multiply(self):
"""测试乘法"""
result = self.calc.multiply(3, 4)
assert result == 12
def test_divide(self):
"""测试除法"""
result = self.calc.divide(10, 2)
assert result == 5
def test_divide_by_zero(self):
"""测试除零异常"""
with pytest.raises(ValueError) as exc_info:
self.calc.divide(10, 0)
assert "除数不能为零" in str(exc_info.value)
def test_power(self):
"""测试幂运算"""
result = self.calc.power(2, 3)
assert result == 8
def test_is_even(self):
"""测试偶数判断"""
assert self.calc.is_even(4) is True
assert self.calc.is_even(5) is False
运行测试:
bash
# 运行所有测试
pytest
# 运行特定文件
pytest test_calculator.py
# 运行特定类
pytest test_calculator.py::TestCalculator
# 运行特定方法
pytest test_calculator.py::TestCalculator::test_add
# 显示详细输出
pytest -v
# 遇到失败时停止
pytest -x
五、pytest的高级特性
5.1 参数化测试
python
import pytest
class TestCalculatorParametrized:
"""参数化测试示例"""
def setup_method(self):
self.calc = Calculator()
@pytest.mark.parametrize("a,b,expected", [
(1, 1, 2),
(0, 0, 0),
(-1, 1, 0),
(100, 200, 300),
])
def test_add_parametrized(self, a, b, expected):
"""参数化测试加法"""
result = self.calc.add(a, b)
assert result == expected
@pytest.mark.parametrize("a,b,expected,exception", [
(10, 2, 5, None),
(10, 0, None, ValueError),
(0, 5, 0, None),
])
def test_divide_parametrized(self, a, b, expected, exception):
"""参数化测试除法(包含正常和异常情况)"""
if exception:
with pytest.raises(exception):
self.calc.divide(a, b)
else:
result = self.calc.divide(a, b)
assert result == expected
5.2 夹具(Fixtures)
夹具用于设置测试环境,可以在多个测试间共享。
python
import pytest
@pytest.fixture
def calculator():
"""提供Calculator实例的夹具"""
return Calculator()
@pytest.fixture
def sample_data():
"""提供测试数据的夹具"""
return {
'numbers': [1, 2, 3, 4, 5],
'positive_numbers': [10, 20, 30],
'negative_numbers': [-1, -2, -3]
}
def test_with_fixtures(calculator, sample_data):
"""使用夹具的测试"""
result = calculator.add(sample_data['numbers'][0], sample_data['numbers'][1])
assert result == 3
@pytest.fixture
def database_connection():
"""模拟数据库连接的夹具"""
print("建立数据库连接")
connection = "模拟数据库连接"
yield connection # 测试执行时使用这个值
print("关闭数据库连接") # 清理操作
def test_database_operation(database_connection):
"""使用带清理的夹具"""
assert database_connection == "模拟数据库连接"
# 测试结束后会自动执行清理代码
5.3 标记(Markers)
python
class TestMarkedCalculator:
"""使用标记的测试"""
@pytest.mark.slow
def test_slow_operation(self):
"""标记为慢测试"""
# 模拟耗时操作
import time
time.sleep(2)
assert True
@pytest.mark.skip(reason="功能尚未实现")
def test_skipped(self):
"""跳过测试"""
assert False
@pytest.mark.skipif(True, reason="条件跳过")
def test_conditional_skip(self):
"""条件跳过"""
assert False
@pytest.mark.xfail
def test_expected_failure(self):
"""预期会失败的测试"""
assert False # 我们知道这个测试目前会失败
六、Mock对象:隔离测试
当测试一个模块时,我们经常需要模拟(mock)它的依赖项。
待测试的代码(payment.py):
python
# payment.py
class PaymentProcessor:
"""支付处理器"""
def process_payment(self, amount, card_number):
"""处理支付(依赖外部API)"""
# 这里会调用真实的支付网关
# 在测试中我们不想真的扣款!
response = self._call_payment_gateway(amount, card_number)
return response['success']
def _call_payment_gateway(self, amount, card_number):
"""调用支付网关(需要被mock)"""
# 实际实现会调用外部API
return {'success': True, 'transaction_id': '12345'}
class Order:
"""订单类"""
def __init__(self, payment_processor):
self.payment_processor = payment_processor
self.paid = False
def checkout(self, amount, card_number):
"""结账"""
success = self.payment_processor.process_payment(amount, card_number)
if success:
self.paid = True
return success
使用mock的测试:
python
# test_payment.py
import pytest
from unittest.mock import Mock, patch
from payment import PaymentProcessor, Order
class TestPayment:
"""支付相关测试"""
def test_order_checkout_success(self):
"""测试订单支付成功"""
# 创建mock支付处理器
mock_processor = Mock(spec=PaymentProcessor)
mock_processor.process_payment.return_value = True
order = Order(mock_processor)
result = order.checkout(100.0, "1234-5678-9012-3456")
assert result is True
assert order.paid is True
mock_processor.process_payment.assert_called_once_with(100.0, "1234-5678-9012-3456")
def test_order_checkout_failure(self):
"""测试订单支付失败"""
mock_processor = Mock(spec=PaymentProcessor)
mock_processor.process_payment.return_value = False
order = Order(mock_processor)
result = order.checkout(100.0, "1234-5678-9012-3456")
assert result is False
assert order.paid is False
@patch('payment.PaymentProcessor._call_payment_gateway')
def test_process_payment_with_patch(self, mock_gateway):
"""使用patch装饰器mock方法"""
mock_gateway.return_value = {'success': True, 'transaction_id': 'mock123'}
processor = PaymentProcessor()
result = processor.process_payment(100.0, "1234-5678-9012-3456")
assert result is True
mock_gateway.assert_called_once_with(100.0, "1234-5678-9012-3456")
七、测试覆盖率(pytest-cov)
测试覆盖率衡量有多少代码被测试覆盖。
运行覆盖率测试:
bash
# 基本覆盖率
pytest --cov=myproject
# 生成HTML报告
pytest --cov=myproject --cov-report=html
# 指定最小覆盖率阈值
pytest --cov=myproject --cov-fail-under=80
覆盖率配置文件(.coveragerc):
ini
[run]
source = myproject
omit =
*/tests/*
*/migrations/*
*/__pycache__/*
*/venv/*
[report]
exclude_lines =
pragma: no cover
def __repr__
if self\.debug
raise AssertionError
raise NotImplementedError
if 0:
if __name__ == .__main__.:
详解每个模式:
-
pragma: no cover含义 : 匹配包含注释
# pragma: no cover的行。作用: 这是最常用、最推荐的方式。你可以在任何不想测量的代码行后面添加这个注释,coverage.py就会忽略该行。它提供了最精确的控制。
-
def __repr__含义 : 匹配定义
__repr__方法的行(如def __repr__(self):)。
作用 :__repr__方法通常用于调试输出,其正确性通常由代码审查或开发过程中的使用来保证,而非专门的单元测试。排除它们可以提高覆盖率百分比,更关注核心逻辑。 -
if self\.debug
含义 : 匹配包含if self.debug:或类似写法的行(反斜杠\用于转义点.,使其匹配字面点)。作用: 排除调试分支。这些分支通常只在开发时启用,在生产或测试环境中不会执行。测试它们可能没有意义或很困难。
-
raise AssertionError含义 : 匹配显式抛出 AssertionError异常的行(如
raise AssertionError("message"))。作用 : 这些通常是代码中的断言(
assert语句在失败时也会抛出AssertionError)。有时开发者会直接raise AssertionError来表示"不应该到达这里"的逻辑。测试这些行通常意味着强制触发错误条件,可能比较麻烦或意义不大。 -
raise NotImplementedError含义 : 匹配显式抛出
NotImplementedError异常的行。作用: 表示某个方法或功能在基类或抽象类中尚未实现,需要子类重写。测试基类时,这些行本身就不应该被执行(因为期望子类覆盖它们),排除它们是合理的。
-
if 0:
含义 : 匹配包含if 0:的行。
作用 : 这是一种常见的临时禁用代码块的方法(将if 1:改为if 0:)。这些代码块永远不会执行,排除它们避免拉低覆盖率。 -
if __name__ == .__main__.:含义 : 匹配包含
if __name__ == '__main__':的行(注意:配置文件中的单引号 '被转义或表示为 .可能是为了适应配置文件解析或正则表达式,但在标准 .coveragerc中,正则表达式需要用引号括起来或正确转义。这里.__main__.可能是为了匹配'__main__'字符串)。作用 : 排除模块作为主程序运行时的入口代码块(
if __name__ == '__main__':下面的代码)。这些代码通常不是模块核心功能的一部分,在导入模块时不会执行,通常也不需要单元测试来覆盖(它们可能包含简单的演示或命令行调用)。