【Python进阶】第2篇:单元测试

单元测试是专业开发的基石,它不仅能捕获bug,还能作为代码的活文档,并支持代码重构。今天,我们将深入探讨Python中的单元测试实践,让你的代码更加可靠和健壮。

一、为什么需要单元测试?​

单元测试的价值:​​

  • 早期发现bug​:在代码部署前发现问题
  • 支持重构​:修改代码时有安全网
  • 文档作用​:测试用例本身就是最好的使用文档
  • 设计指导:编写可测试的代码通常意味着更好的设计
  • 持续集成​:自动化测试是CI/CD流程的核心

没有测试的代价:​

python 复制代码
# 一个简单的计算器函数
def divide(a, b):
    return a / b

# 看起来没问题,但如果b=0呢?
result = divide(10, 0)  # ZeroDivisionError!

二、Python测试框架:unittest vs pytest​

Python有两个主流的测试框架:

  • unittest:Python标准库的一部分,基于Java的JUnit
  • pytest:第三方框架,更Pythonic,功能更强大

我们将重点介绍pytest,因为它更简洁、强大,是当前Python社区的主流选择。

三、安装和配置pytest

bash 复制代码
# 安装pytest
pip install pytest

# 安装常用插件
pip install pytest-cov    # 测试覆盖率
pip install pytest-mock   # Mock对象支持

四、第一个测试用例

待测试的代码(calculator.py:​

python 复制代码
# calculator.py
class Calculator:
    """简单的计算器类"""
    
    def add(self, a, b):
        """加法"""
        return a + b
    
    def subtract(self, a, b):
        """减法"""
        return a - b
    
    def multiply(self, a, b):
        """乘法"""
        return a * b
    
    def divide(self, a, b):
        """除法"""
        if b == 0:
            raise ValueError("除数不能为零")
        return a / b
    
    def power(self, base, exponent):
        """幂运算"""
        return base ** exponent
    
    def is_even(self, number):
        """判断是否为偶数"""
        return number % 2 == 0

测试代码(test_calculator.py):​

python 复制代码
# test_calculator.py
import pytest
from calculator import Calculator

class TestCalculator:
    """Calculator类的测试用例"""
    
    def setup_method(self):
        """每个测试方法前执行"""
        self.calc = Calculator()
    
    def test_add(self):
        """测试加法"""
        result = self.calc.add(2, 3)
        assert result == 5
    
    def test_subtract(self):
        """测试减法"""
        result = self.calc.subtract(10, 4)
        assert result == 6
    
    def test_multiply(self):
        """测试乘法"""
        result = self.calc.multiply(3, 4)
        assert result == 12
    
    def test_divide(self):
        """测试除法"""
        result = self.calc.divide(10, 2)
        assert result == 5
    
    def test_divide_by_zero(self):
        """测试除零异常"""
        with pytest.raises(ValueError) as exc_info:
            self.calc.divide(10, 0)
        
        assert "除数不能为零" in str(exc_info.value)
    
    def test_power(self):
        """测试幂运算"""
        result = self.calc.power(2, 3)
        assert result == 8
    
    def test_is_even(self):
        """测试偶数判断"""
        assert self.calc.is_even(4) is True
        assert self.calc.is_even(5) is False

运行测试:​

bash 复制代码
# 运行所有测试
pytest

# 运行特定文件
pytest test_calculator.py

# 运行特定类
pytest test_calculator.py::TestCalculator

# 运行特定方法
pytest test_calculator.py::TestCalculator::test_add

# 显示详细输出
pytest -v

# 遇到失败时停止
pytest -x

五、pytest的高级特性

5.1 参数化测试

python 复制代码
import pytest

class TestCalculatorParametrized:
    """参数化测试示例"""
    
    def setup_method(self):
        self.calc = Calculator()
    
    @pytest.mark.parametrize("a,b,expected", [
        (1, 1, 2),
        (0, 0, 0),
        (-1, 1, 0),
        (100, 200, 300),
    ])
    def test_add_parametrized(self, a, b, expected):
        """参数化测试加法"""
        result = self.calc.add(a, b)
        assert result == expected
    
    @pytest.mark.parametrize("a,b,expected,exception", [
        (10, 2, 5, None),
        (10, 0, None, ValueError),
        (0, 5, 0, None),
    ])
    def test_divide_parametrized(self, a, b, expected, exception):
        """参数化测试除法(包含正常和异常情况)"""
        if exception:
            with pytest.raises(exception):
                self.calc.divide(a, b)
        else:
            result = self.calc.divide(a, b)
            assert result == expected

5.2 夹具(Fixtures)​

夹具用于设置测试环境,可以在多个测试间共享。

python 复制代码
import pytest

@pytest.fixture
def calculator():
    """提供Calculator实例的夹具"""
    return Calculator()

@pytest.fixture
def sample_data():
    """提供测试数据的夹具"""
    return {
        'numbers': [1, 2, 3, 4, 5],
        'positive_numbers': [10, 20, 30],
        'negative_numbers': [-1, -2, -3]
    }

def test_with_fixtures(calculator, sample_data):
    """使用夹具的测试"""
    result = calculator.add(sample_data['numbers'][0], sample_data['numbers'][1])
    assert result == 3

@pytest.fixture
def database_connection():
    """模拟数据库连接的夹具"""
    print("建立数据库连接")
    connection = "模拟数据库连接"
    yield connection  # 测试执行时使用这个值
    print("关闭数据库连接")  # 清理操作

def test_database_operation(database_connection):
    """使用带清理的夹具"""
    assert database_connection == "模拟数据库连接"
    # 测试结束后会自动执行清理代码

5.3 标记(Markers)​

python 复制代码
class TestMarkedCalculator:
    """使用标记的测试"""
    
    @pytest.mark.slow
    def test_slow_operation(self):
        """标记为慢测试"""
        # 模拟耗时操作
        import time
        time.sleep(2)
        assert True
    
    @pytest.mark.skip(reason="功能尚未实现")
    def test_skipped(self):
        """跳过测试"""
        assert False
    
    @pytest.mark.skipif(True, reason="条件跳过")
    def test_conditional_skip(self):
        """条件跳过"""
        assert False
    
    @pytest.mark.xfail
    def test_expected_failure(self):
        """预期会失败的测试"""
        assert False  # 我们知道这个测试目前会失败

六、Mock对象:隔离测试

当测试一个模块时,我们经常需要模拟(mock)它的依赖项。
待测试的代码(payment.py

python 复制代码
# payment.py
class PaymentProcessor:
    """支付处理器"""
    
    def process_payment(self, amount, card_number):
        """处理支付(依赖外部API)"""
        # 这里会调用真实的支付网关
        # 在测试中我们不想真的扣款!
        response = self._call_payment_gateway(amount, card_number)
        return response['success']
    
    def _call_payment_gateway(self, amount, card_number):
        """调用支付网关(需要被mock)"""
        # 实际实现会调用外部API
        return {'success': True, 'transaction_id': '12345'}

class Order:
    """订单类"""
    
    def __init__(self, payment_processor):
        self.payment_processor = payment_processor
        self.paid = False
    
    def checkout(self, amount, card_number):
        """结账"""
        success = self.payment_processor.process_payment(amount, card_number)
        if success:
            self.paid = True
        return success

使用mock的测试:​

python 复制代码
# test_payment.py
import pytest
from unittest.mock import Mock, patch
from payment import PaymentProcessor, Order

class TestPayment:
    """支付相关测试"""
    
    def test_order_checkout_success(self):
        """测试订单支付成功"""
        # 创建mock支付处理器
        mock_processor = Mock(spec=PaymentProcessor)
        mock_processor.process_payment.return_value = True
        
        order = Order(mock_processor)
        result = order.checkout(100.0, "1234-5678-9012-3456")
        
        assert result is True
        assert order.paid is True
        mock_processor.process_payment.assert_called_once_with(100.0, "1234-5678-9012-3456")
    
    def test_order_checkout_failure(self):
        """测试订单支付失败"""
        mock_processor = Mock(spec=PaymentProcessor)
        mock_processor.process_payment.return_value = False
        
        order = Order(mock_processor)
        result = order.checkout(100.0, "1234-5678-9012-3456")
        
        assert result is False
        assert order.paid is False
    
    @patch('payment.PaymentProcessor._call_payment_gateway')
    def test_process_payment_with_patch(self, mock_gateway):
        """使用patch装饰器mock方法"""
        mock_gateway.return_value = {'success': True, 'transaction_id': 'mock123'}
        
        processor = PaymentProcessor()
        result = processor.process_payment(100.0, "1234-5678-9012-3456")
        
        assert result is True
        mock_gateway.assert_called_once_with(100.0, "1234-5678-9012-3456")

七、测试覆盖率​(pytest-cov)

测试覆盖率衡量有多少代码被测试覆盖。

运行覆盖率测试:​

bash 复制代码
# 基本覆盖率
pytest --cov=myproject

# 生成HTML报告
pytest --cov=myproject --cov-report=html

# 指定最小覆盖率阈值
pytest --cov=myproject --cov-fail-under=80

覆盖率配置文件(.coveragerc):​

ini 复制代码
[run]
source = myproject
omit = 
    */tests/*
    */migrations/*
    */__pycache__/*
    */venv/*

[report]
exclude_lines =
    pragma: no cover
    def __repr__
    if self\.debug
    raise AssertionError
    raise NotImplementedError
    if 0:
    if __name__ == .__main__.:

详解每个模式:​​

  1. pragma: no cover

    含义 :​​ 匹配包含注释 # pragma: no cover的行。

    作用:​​ 这是最常用、最推荐的方式。你可以在任何不想测量的代码行后面添加这个注释,coverage.py就会忽略该行。它提供了最精确的控制。

  2. def __repr__

    含义 :​​ 匹配定义 __repr__方法的行(如 def __repr__(self):)。
    作用 :​​ __repr__方法通常用于调试输出,其正确性通常由代码审查或开发过程中的使用来保证,而非专门的单元测试。排除它们可以提高覆盖率百分比,更关注核心逻辑。

  3. if self\.debug
    ​含义 :​​ 匹配包含 if self.debug:或类似写法的行(反斜杠 \ 用于转义点.,使其匹配字面点)。

    作用:​​ 排除调试分支。这些分支通常只在开发时启用,在生产或测试环境中不会执行。测试它们可能没有意义或很困难。

  4. raise AssertionError

    含义 :​​ 匹配显式抛出 AssertionError异常的行(如 raise AssertionError("message"))。

    作用 :​​ 这些通常是代码中的断言(assert语句在失败时也会抛出 AssertionError)。有时开发者会直接 raise AssertionError来表示"不应该到达这里"的逻辑。测试这些行通常意味着强制触发错误条件,可能比较麻烦或意义不大。

  5. raise NotImplementedError

    含义 :​​ 匹配显式抛出 NotImplementedError异常的行。

    作用:​​ 表示某个方法或功能在基类或抽象类中尚未实现,需要子类重写。测试基类时,这些行本身就不应该被执行(因为期望子类覆盖它们),排除它们是合理的。

  6. if 0:
    ​含义 :​​ 匹配包含 if 0:的行。
    ​作用 :​​ 这是一种常见的临时禁用代码块的方法(将 if 1:改为 if 0:)。这些代码块永远不会执行,排除它们避免拉低覆盖率。

  7. if __name__ == .__main__.:

    含义 :​​ 匹配包含 if __name__ == '__main__':的行(注意:配置文件中的单引号 '被转义或表示为 .可能是为了适应配置文件解析或正则表达式,但在标准 .coveragerc中,正则表达式需要用引号括起来或正确转义。这里 .__main__.可能是为了匹配 '__main__'字符串)。

    作用 :​​ 排除模块作为主程序运行时的入口代码块(if __name__ == '__main__':下面的代码)。这些代码通常不是模块核心功能的一部分,在导入模块时不会执行,通常也不需要单元测试来覆盖(它们可能包含简单的演示或命令行调用)。

相关推荐
唐叔在学习4 小时前
200kb能作甚?mss表示我给你整个截图程序
后端·python
智能化咨询4 小时前
Python 小工具实战:图片水印批量添加工具——从原理到实现的全流程指南
python
用户3721574261354 小时前
如何使用 Python 自动调整 Excel 行高和列宽
python
用户8356290780514 小时前
用Python自动化转换PowerPoint幻灯片为图片
后端·python
yugi9878384 小时前
基于Qt实现百度地图路径规划功能
开发语言·qt
小年糕是糕手4 小时前
【数据结构】队列“0”基础知识讲解 + 实战演练
c语言·开发语言·数据结构·c++·学习·算法
程序员爱钓鱼4 小时前
Python编程实战 · 基础入门篇 | 推导式(列表推导式 / 字典推导式)
后端·python
无限进步_4 小时前
【C语言】函数指针数组:从条件分支到转移表的优雅进化
c语言·开发语言·数据结构·后端·算法·visual studio