Python案例如何做单元测试:从入门到实战的完整指南
目录导读
为什么需要单元测试?
问:单元测试到底是什么?为什么这么多开发者强调它?

答:单元测试是针对程序中的最小可测试单元(通常是函数或方法)进行的自动化验证,在Python中,单元测试意味着你写出独立的测试代码,来检查某个函数的输入输出是否符合预期,它的核心价值包括:
- 早期发现bug:在代码合并到主分支前就能捕获逻辑错误
- 防止回归:修改代码后,只需运行测试就能知道是否破坏了已有功能
- 文档化行为:测试代码本身就是一种“活文档”,展示了函数应如何使用
- 提升设计:为了可测试性,你自然会写出更解耦、更模块化的代码
根据Stack Overflow 2023年调查,超过65%的专业开发者会定期编写单元测试,这不再是“可选项”,而是生产级代码的标配。
准备工作:安装与基础框架
Python内置的unittest模块是官方标准库,无需安装,但更流行的选择是pytest——它语法更简洁、支持更丰富的插件生态。
pip install pytest pytest-cov # pytest-cov用于覆盖率统计
关键概念:
- TestCase:一组相关测试的集合
- 断言:
assert语句或self.assertEqual()等 - 覆盖率:被测试代码的执行比例,一般要求≥80%
案例一:测试一个简单的计算函数
假设你有一个函数calculate_discount(price, discount_rate),我们需要验证其输出。
# code.py
def calculate_discount(price: float, discount_rate: float) -> float:
"""应用折扣,返回折扣后的价格"""
if price < 0:
raise ValueError("Price cannot be negative")
if discount_rate < 0 or discount_rate > 1:
raise ValueError("Discount rate must be between 0 and 1")
return price * (1 - discount_rate)
对应的测试代码:
# test_code.py
import pytest
from code import calculate_discount
def test_normal_case():
assert calculate_discount(100, 0.2) == 80.0
def test_no_discount():
assert calculate_discount(100, 0) == 100.0
def test_full_discount():
assert calculate_discount(100, 1) == 0.0
def test_negative_price():
with pytest.raises(ValueError, match="Price cannot be negative"):
calculate_discount(-10, 0.2)
def test_invalid_rate():
with pytest.raises(ValueError, match="Discount rate must be between 0 and 1"):
calculate_discount(100, 1.5)
运行:pytest test_code.py -v
你会看到所有测试通过,并且错误路径也被覆盖。
案例二:测试包含外部依赖的类
实际项目中,很多类会依赖数据库、文件系统或其他服务,直接测试会很慢且不稳定,解决方案是依赖注入。
# service.py
class UserService:
def __init__(self, db_connection):
self.db = db_connection
def get_user_email(self, user_id):
user = self.db.find_one({"id": user_id})
if user is None:
raise ValueError("User not found")
return user.get("email")
测试时,使用一个假数据库实例:
# test_service.py
import pytest
from service import UserService
class FakeDB:
def __init__(self, data):
self.data = data
def find_one(self, query):
return self.data.get(query.get("id"))
def test_get_user_email_found():
fake_db = FakeDB({1: {"id": 1, "email": "test@example.com"}})
service = UserService(fake_db)
assert service.get_user_email(1) == "test@example.com"
def test_get_user_email_not_found():
fake_db = FakeDB({})
service = UserService(fake_db)
with pytest.raises(ValueError, match="User not found"):
service.get_user_email(99)
这种模式让测试完全在内存中运行,无需启动真实数据库。
案例三:使用Mock模拟网络请求
当代码调用外部API时,我们不想真正发起网络请求(会慢、可能失败、有配额限制)。unittest.mock库可以模拟对象。
# fetcher.py
import requests
def fetch_weather(city):
resp = requests.get(f"https://api.weather.com/{city}")
if resp.status_code == 200:
return resp.json()
else:
raise ConnectionError("API request failed")
测试中我们模拟requests.get:
# test_fetcher.py
import pytest
from unittest.mock import patch, Mock
from fetcher import fetch_weather
@patch("fetcher.requests.get")
def test_fetch_weather_success(mock_get):
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"temp": 25, "city": "Beijing"}
mock_get.return_value = mock_response
result = fetch_weather("Beijing")
assert result == {"temp": 25, "city": "Beijing"}
mock_get.assert_called_once_with("https://api.weather.com/Beijing")
@patch("fetcher.requests.get")
def test_fetch_weather_failure(mock_get):
mock_response = Mock()
mock_response.status_code = 404
mock_get.return_value = mock_response
with pytest.raises(ConnectionError):
fetch_weather("Unknown")
关键点:
@patch装饰器临时替换fetcher.requests.get为Mock对象- Mock对象可以预设返回值、属性
- 使用
assert_called_once_with验证调用参数
案例四:测试异常与边界条件
好的单元测试不仅要测“正常路径”,还要测试异常和边界,比如处理用户输入的验证函数:
# validator.py
def validate_age(age: int) -> bool:
if not isinstance(age, int):
raise TypeError("Age must be an integer")
if age < 0 or age > 150:
raise ValueError("Age must be between 0 and 150")
return True
# 测试边界
def test_valid_ages():
assert validate_age(0) == True
assert validate_age(150) == True
assert validate_age(25) == True
def test_invalid_ages():
with pytest.raises(ValueError):
validate_age(-1)
with pytest.raises(ValueError):
validate_age(151)
def test_non_integer():
with pytest.raises(TypeError):
validate_age("25")
with pytest.raises(TypeError):
validate_age(25.5)
边界值测试(0, 150, -1, 151)帮你发现常见“栅栏错误”。
案例五:数据驱动测试与参数化
当测试逻辑相同但数据不同时,用参数化避免重复代码。
import pytest
from code import calculate_discount
@pytest.mark.parametrize("price,rate,expected", [
(100, 0.2, 80),
(100, 0, 100),
(100, 1, 0),
(50.5, 0.1, 45.45),
])
def test_calculate_discount_param(price, rate, expected):
assert abs(calculate_discount(price, rate) - expected) < 0.001
同时测试异常情况:
@pytest.mark.parametrize("price,rate,expected_exception", [
(-10, 0.2, ValueError),
(100, 1.5, ValueError),
])
def test_calculate_discount_invalid(price, rate, expected_exception):
with pytest.raises(expected_exception):
calculate_discount(price, rate)
参数化让测试数量轻松扩展,且每个参数组合都独立报告。
常见问题FAQ
Q1:单元测试覆盖率要达到100%吗?
A:理论上100%覆盖率意味着每行代码都被测试执行过,但实际中很难(比如异常处理分支过于复杂),一般目标是80%-90%,核心逻辑应100%覆盖,重点不是“数字”,而是“每个重要行为都有测试”。
Q2:测试代码应该放在哪里?
A:主流做法是在项目根目录下创建tests/文件夹,测试文件名为test_*.py,这样pytest会自动发现,保持测试与源码的模块结构对应。
Q3:如何测试私有方法(以开头的方法)?
A:不推荐测试私有方法,应通过公共接口间接验证,如果私有方法逻辑很复杂,说明它应当被提取成一个独立的函数或类。
Q4:测试会很慢,怎么办?
A:将测试分层:
- 快速单元测试(毫秒级):纯逻辑、无IO
- 慢速集成测试:需要数据库、外部服务(用标记
@pytest.mark.integration区分)
持续集成中只运行单元测试,集成测试单独触发。
Q5:如何处理临时文件或数据库?
A:使用tmp_path fixture(pytest内置)创建临时目录,测试结束后自动清理,数据库测试用内存数据库(如SQLite memory:)。
总结与最佳实践
通过以上五个案例,你已掌握Python单元测试的核心技术:
- 纯函数测试:最基础也最高效,覆盖正常值和边界
- 依赖注入:通过假对象替代真实依赖,保持测试独立性
- Mock外部调用:模拟网络请求、文件操作等副作用
- 异常验证:确保错误处理逻辑正确
- 参数化测试:减少重复,提高覆盖率
记住的关键原则:
- 每个测试只验证一个行为
- 测试应当是独立的,任意顺序运行都应通过
- 使用有意义的命名(
test_功能_场景_预期结果) - 优先测试公共接口,而非内部实现细节
- 持续集成中自动化运行,失败时立即修复
当你开始习惯先写测试再实现功能(TDD),你会发现代码质量显著提升,从现在开始,为你的下一个Python功能编写第一个单元测试吧。