Python案例如何做单元测试

wen python案例 54

Python案例如何做单元测试:从入门到实战的完整指南

目录导读

  1. 为什么需要单元测试?
  2. 准备工作:安装与基础框架
  3. 测试一个简单的计算函数
  4. 测试包含外部依赖的类
  5. 使用Mock模拟网络请求
  6. 测试异常与边界条件
  7. 数据驱动测试与参数化
  8. 常见问题FAQ
  9. 总结与最佳实践

为什么需要单元测试?

问:单元测试到底是什么?为什么这么多开发者强调它?

Python案例如何做单元测试

答:单元测试是针对程序中的最小可测试单元(通常是函数或方法)进行的自动化验证,在Python中,单元测试意味着你写出独立的测试代码,来检查某个函数的输入输出是否符合预期,它的核心价值包括:

  • 早期发现bug:在代码合并到主分支前就能捕获逻辑错误
  • 防止回归:修改代码后,只需运行测试就能知道是否破坏了已有功能
  • 文档化行为:测试代码本身就是一种“活文档”,展示了函数应如何使用
  • 提升设计:为了可测试性,你自然会写出更解耦、更模块化的代码

根据Stack Overflow 2023年调查,超过65%的专业开发者会定期编写单元测试,这不再是“可选项”,而是生产级代码的标配。

准备工作:安装与基础框架

Python内置的unittest模块是官方标准库,无需安装,但更流行的选择是pytest——它语法更简洁、支持更丰富的插件生态。

pip install pytest pytest-cov  # pytest-cov用于覆盖率统计

关键概念

  • TestCase:一组相关测试的集合
  • 断言assert语句或self.assertEqual()
  • 覆盖率:被测试代码的执行比例,一般要求≥80%

案例一:测试一个简单的计算函数

假设你有一个函数calculate_discount(price, discount_rate),我们需要验证其输出。

# code.py
def calculate_discount(price: float, discount_rate: float) -> float:
    """应用折扣,返回折扣后的价格"""
    if price < 0:
        raise ValueError("Price cannot be negative")
    if discount_rate < 0 or discount_rate > 1:
        raise ValueError("Discount rate must be between 0 and 1")
    return price * (1 - discount_rate)

对应的测试代码:

# test_code.py
import pytest
from code import calculate_discount
def test_normal_case():
    assert calculate_discount(100, 0.2) == 80.0
def test_no_discount():
    assert calculate_discount(100, 0) == 100.0
def test_full_discount():
    assert calculate_discount(100, 1) == 0.0
def test_negative_price():
    with pytest.raises(ValueError, match="Price cannot be negative"):
        calculate_discount(-10, 0.2)
def test_invalid_rate():
    with pytest.raises(ValueError, match="Discount rate must be between 0 and 1"):
        calculate_discount(100, 1.5)

运行pytest test_code.py -v
你会看到所有测试通过,并且错误路径也被覆盖。

案例二:测试包含外部依赖的类

实际项目中,很多类会依赖数据库、文件系统或其他服务,直接测试会很慢且不稳定,解决方案是依赖注入

# service.py
class UserService:
    def __init__(self, db_connection):
        self.db = db_connection
    def get_user_email(self, user_id):
        user = self.db.find_one({"id": user_id})
        if user is None:
            raise ValueError("User not found")
        return user.get("email")

测试时,使用一个假数据库实例:

# test_service.py
import pytest
from service import UserService
class FakeDB:
    def __init__(self, data):
        self.data = data
    def find_one(self, query):
        return self.data.get(query.get("id"))
def test_get_user_email_found():
    fake_db = FakeDB({1: {"id": 1, "email": "test@example.com"}})
    service = UserService(fake_db)
    assert service.get_user_email(1) == "test@example.com"
def test_get_user_email_not_found():
    fake_db = FakeDB({})
    service = UserService(fake_db)
    with pytest.raises(ValueError, match="User not found"):
        service.get_user_email(99)

这种模式让测试完全在内存中运行,无需启动真实数据库。

案例三:使用Mock模拟网络请求

当代码调用外部API时,我们不想真正发起网络请求(会慢、可能失败、有配额限制)。unittest.mock库可以模拟对象。

# fetcher.py
import requests
def fetch_weather(city):
    resp = requests.get(f"https://api.weather.com/{city}")
    if resp.status_code == 200:
        return resp.json()
    else:
        raise ConnectionError("API request failed")

测试中我们模拟requests.get

# test_fetcher.py
import pytest
from unittest.mock import patch, Mock
from fetcher import fetch_weather
@patch("fetcher.requests.get")
def test_fetch_weather_success(mock_get):
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {"temp": 25, "city": "Beijing"}
    mock_get.return_value = mock_response
    result = fetch_weather("Beijing")
    assert result == {"temp": 25, "city": "Beijing"}
    mock_get.assert_called_once_with("https://api.weather.com/Beijing")
@patch("fetcher.requests.get")
def test_fetch_weather_failure(mock_get):
    mock_response = Mock()
    mock_response.status_code = 404
    mock_get.return_value = mock_response
    with pytest.raises(ConnectionError):
        fetch_weather("Unknown")

关键点:

  1. @patch装饰器临时替换fetcher.requests.get为Mock对象
  2. Mock对象可以预设返回值、属性
  3. 使用assert_called_once_with验证调用参数

案例四:测试异常与边界条件

好的单元测试不仅要测“正常路径”,还要测试异常和边界,比如处理用户输入的验证函数:

# validator.py
def validate_age(age: int) -> bool:
    if not isinstance(age, int):
        raise TypeError("Age must be an integer")
    if age < 0 or age > 150:
        raise ValueError("Age must be between 0 and 150")
    return True
# 测试边界
def test_valid_ages():
    assert validate_age(0) == True
    assert validate_age(150) == True
    assert validate_age(25) == True
def test_invalid_ages():
    with pytest.raises(ValueError):
        validate_age(-1)
    with pytest.raises(ValueError):
        validate_age(151)
def test_non_integer():
    with pytest.raises(TypeError):
        validate_age("25")
    with pytest.raises(TypeError):
        validate_age(25.5)

边界值测试(0, 150, -1, 151)帮你发现常见“栅栏错误”。

案例五:数据驱动测试与参数化

当测试逻辑相同但数据不同时,用参数化避免重复代码。

import pytest
from code import calculate_discount
@pytest.mark.parametrize("price,rate,expected", [
    (100, 0.2, 80),
    (100, 0, 100),
    (100, 1, 0),
    (50.5, 0.1, 45.45),
])
def test_calculate_discount_param(price, rate, expected):
    assert abs(calculate_discount(price, rate) - expected) < 0.001

同时测试异常情况:

@pytest.mark.parametrize("price,rate,expected_exception", [
    (-10, 0.2, ValueError),
    (100, 1.5, ValueError),
])
def test_calculate_discount_invalid(price, rate, expected_exception):
    with pytest.raises(expected_exception):
        calculate_discount(price, rate)

参数化让测试数量轻松扩展,且每个参数组合都独立报告。

常见问题FAQ

Q1:单元测试覆盖率要达到100%吗?
A:理论上100%覆盖率意味着每行代码都被测试执行过,但实际中很难(比如异常处理分支过于复杂),一般目标是80%-90%,核心逻辑应100%覆盖,重点不是“数字”,而是“每个重要行为都有测试”。

Q2:测试代码应该放在哪里?
A:主流做法是在项目根目录下创建tests/文件夹,测试文件名为test_*.py,这样pytest会自动发现,保持测试与源码的模块结构对应。

Q3:如何测试私有方法(以开头的方法)?
A:不推荐测试私有方法,应通过公共接口间接验证,如果私有方法逻辑很复杂,说明它应当被提取成一个独立的函数或类。

Q4:测试会很慢,怎么办?
A:将测试分层:

  • 快速单元测试(毫秒级):纯逻辑、无IO
  • 慢速集成测试:需要数据库、外部服务(用标记@pytest.mark.integration区分)
    持续集成中只运行单元测试,集成测试单独触发。

Q5:如何处理临时文件或数据库?
A:使用tmp_path fixture(pytest内置)创建临时目录,测试结束后自动清理,数据库测试用内存数据库(如SQLite memory:)。

总结与最佳实践

通过以上五个案例,你已掌握Python单元测试的核心技术:

  1. 纯函数测试:最基础也最高效,覆盖正常值和边界
  2. 依赖注入:通过假对象替代真实依赖,保持测试独立性
  3. Mock外部调用:模拟网络请求、文件操作等副作用
  4. 异常验证:确保错误处理逻辑正确
  5. 参数化测试:减少重复,提高覆盖率

记住的关键原则

  • 每个测试只验证一个行为
  • 测试应当是独立的,任意顺序运行都应通过
  • 使用有意义的命名(test_功能_场景_预期结果
  • 优先测试公共接口,而非内部实现细节
  • 持续集成中自动化运行,失败时立即修复

当你开始习惯先写测试再实现功能(TDD),你会发现代码质量显著提升,从现在开始,为你的下一个Python功能编写第一个单元测试吧。

抱歉,评论功能暂时关闭!