单元测试
构建坚实的代码质量基础
单元测试原则
单元测试是对软件中最小可测试单元进行检查和验证的过程。优秀的单元测试应该遵循 FIRST 原则:
- Fast(快速):单元测试应该快速执行
- Independent(独立):测试之间相互独立
- Repeatable(可重复):在任何环境下都能重复执行
- Self-Validating(自我验证):测试应该有明确的通过/失败结果
- Timely(及时):测试应与被测代码同步编写,理想情况下先于实现编写(即 TDD)
单元测试结构
AAA 模式
java
@Test
public void testCalculateDiscount() {
    // Arrange: a 1000.0 product and the service under test
    Product product = new Product("laptop", 1000.0);
    DiscountService service = new DiscountService();

    // Act: apply a 10% discount rate.
    // The expected value below (900.0 for 1000.0 at 0.1) shows that the result is
    // the price AFTER the discount, not the discount amount, so name it that way.
    double discountedPrice = service.calculateDiscount(product, 0.1);

    // Assert: compare doubles with a tolerance to absorb floating-point rounding
    assertEquals(900.0, discountedPrice, 0.01);
}
Given-When-Then 模式
python
def test_user_registration():
# Given(给定条件)
user_data = {
"email": "test@example.com",
"password": "SecurePass123!"
}
service = UserService()
# When(当执行操作)
result = service.register(user_data)
# Then(那么期望结果)
assert result.id is not None
assert result.email == user_data["email"]Java 单元测试
JUnit 5 基础
java
import org.junit.jupiter.api.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
@DisplayName("订单服务测试")
class OrderServiceTest {

    private OrderService orderService;
    private PaymentService paymentService;
    private InventoryService inventoryService;

    /** Creates fresh mocks and a fresh service before every test so tests stay independent. */
    @BeforeEach
    void setUp() {
        paymentService = mock(PaymentService.class);
        inventoryService = mock(InventoryService.class);
        orderService = new OrderService(paymentService, inventoryService);
    }

    // @Nested test classes must be non-static inner classes; JUnit runs the outer
    // @BeforeEach before each of their tests.
    @Nested
    @DisplayName("创建订单")
    class CreateOrderTests {

        @Test
        @DisplayName("成功创建订单 - 库存充足且支付成功")
        void shouldCreateOrderSuccessfully() {
            // Given: stock is available and the payment succeeds
            OrderRequest request = OrderRequest.builder()
                    .productId("PROD-001")
                    .quantity(2)
                    .userId("USER-123")
                    .build();
            when(inventoryService.checkStock("PROD-001", 2)).thenReturn(true);
            when(paymentService.processPayment(any())).thenReturn(
                    PaymentResult.success("PAY-123")
            );

            // When
            Order order = orderService.createOrder(request);

            // Then: the order is confirmed, linked to the payment, and stock was reserved
            assertNotNull(order);
            assertEquals(OrderStatus.CONFIRMED, order.getStatus());
            assertEquals("PAY-123", order.getPaymentId());
            verify(inventoryService).reserveStock("PROD-001", 2);
            verify(paymentService).processPayment(any());
        }

        @Test
        @DisplayName("库存不足时抛出异常")
        void shouldThrowExceptionWhenOutOfStock() {
            // Given: the stock check fails
            OrderRequest request = OrderRequest.builder()
                    .productId("PROD-001")
                    .quantity(10)
                    .build();
            when(inventoryService.checkStock("PROD-001", 10)).thenReturn(false);

            // When & Then
            OutOfStockException exception = assertThrows(
                    OutOfStockException.class,
                    () -> orderService.createOrder(request)
            );
            assertEquals("Product PROD-001 is out of stock", exception.getMessage());
            // A failed stock check must have no side effects: nothing reserved, nothing charged.
            verify(inventoryService, never()).reserveStock(anyString(), anyInt());
            verify(paymentService, never()).processPayment(any());
        }
    }

    @Nested
    @DisplayName("取消订单")
    class CancelOrderTests {

        @Test
        @DisplayName("成功取消未发货订单")
        void shouldCancelUnshippedOrder() {
            // Given: a confirmed (not yet shipped) order whose refund succeeds
            Order order = Order.builder()
                    .id("ORDER-123")
                    .status(OrderStatus.CONFIRMED)
                    .paymentId("PAY-123")
                    .build();
            when(paymentService.refund("PAY-123")).thenReturn(true);

            // When
            boolean result = orderService.cancelOrder(order);

            // Then
            assertTrue(result);
            assertEquals(OrderStatus.CANCELLED, order.getStatus());
            verify(paymentService).refund("PAY-123");
        }

        @Test
        @DisplayName("已发货订单无法取消")
        void shouldNotCancelShippedOrder() {
            // Given: an order that has already shipped
            Order order = Order.builder()
                    .id("ORDER-123")
                    .status(OrderStatus.SHIPPED)
                    .build();

            // When & Then
            // The optional third argument of assertThrows is only the message reported
            // when the assertion FAILS; it does not check the exception's message.
            // Capture the exception and assert on its message explicitly instead.
            IllegalStateException exception = assertThrows(
                    IllegalStateException.class,
                    () -> orderService.cancelOrder(order)
            );
            assertEquals("Cannot cancel shipped order", exception.getMessage());
            // A rejected cancellation must not refund anything or change the status.
            verify(paymentService, never()).refund(any());
            assertEquals(OrderStatus.SHIPPED, order.getStatus());
        }
    }
}
参数化测试
java
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.util.stream.Stream;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;

// @ParameterizedTest is a method-level annotation and cannot be placed on a class;
// only @DisplayName belongs here.
@DisplayName("价格计算测试")
class PriceCalculatorTest {

    /** Each CSV row is one case: original price, discount rate, expected discounted price. */
    @ParameterizedTest(name = "原价 {0}, 折扣 {1}, 期望价格 {2}")
    @CsvSource({
        "100.0, 0.1, 90.0",
        "100.0, 0.2, 80.0",
        "100.0, 0.0, 100.0",
        "100.0, 1.0, 0.0"
    })
    void testCalculateDiscountedPrice(double originalPrice, double discount, double expectedPrice) {
        PriceCalculator calculator = new PriceCalculator();
        double result = calculator.calculateDiscountedPrice(originalPrice, discount);
        // Compare doubles with a tolerance rather than exact equality.
        assertEquals(expectedPrice, result, 0.01);
    }

    /** Every argument set from provideInvalidInputs must make the calculator throw. */
    @ParameterizedTest
    @MethodSource("provideInvalidInputs")
    @DisplayName("无效输入验证")
    void testInvalidInputs(double price, double discount, Class<? extends Exception> expectedException) {
        PriceCalculator calculator = new PriceCalculator();
        assertThrows(expectedException, () ->
                calculator.calculateDiscountedPrice(price, discount)
        );
    }

    /** Negative price, negative discount, and discount above 100%, each with its expected exception. */
    private static Stream<Arguments> provideInvalidInputs() {
        return Stream.of(
                Arguments.of(-100.0, 0.1, IllegalArgumentException.class),
                Arguments.of(100.0, -0.1, IllegalArgumentException.class),
                Arguments.of(100.0, 1.5, IllegalArgumentException.class)
        );
    }
}
Python 单元测试
pytest 测试框架
python
# test_user_service.py
import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta
from src.services.user_service import UserService
from src.models.user import User
from src.exceptions import ValidationError, AuthenticationError
class TestUserService:
"""用户服务单元测试"""
@pytest.fixture
def user_service(self):
"""创建用户服务实例"""
repository = Mock()
email_service = Mock()
cache_service = Mock()
return UserService(repository, email_service, cache_service)
@pytest.fixture
def valid_user(self):
"""有效用户数据"""
return User(
id="123",
email="test@example.com",
username="testuser",
created_at=datetime.now()
)
class TestAuthentication:
"""认证相关测试"""
def test_login_with_valid_credentials(self, user_service, valid_user):
"""测试使用有效凭据登录"""
# Arrange
user_service.repository.find_by_email.return_value = valid_user
user_service.repository.verify_password.return_value = True
# Act
token = user_service.login("test@example.com", "password123")
# Assert
assert token is not None
assert len(token) > 0
user_service.cache_service.set.assert_called_once()
def test_login_with_invalid_password(self, user_service, valid_user):
"""测试使用无效密码登录"""
# Arrange
user_service.repository.find_by_email.return_value = valid_user
user_service.repository.verify_password.return_value = False
# Act & Assert
with pytest.raises(AuthenticationError) as exc_info:
user_service.login("test@example.com", "wrongpassword")
assert str(exc_info.value) == "Invalid credentials"
@pytest.mark.parametrize("email,password,error_message", [
("", "password", "Email is required"),
("test@example.com", "", "Password is required"),
("invalid-email", "password", "Invalid email format"),
("test@example.com", "short", "Password too short")
])
def test_login_validation(self, user_service, email, password, error_message):
"""测试登录参数验证"""
with pytest.raises(ValidationError) as exc_info:
user_service.login(email, password)
assert error_message in str(exc_info.value)
class TestUserManagement:
"""用户管理相关测试"""
def test_create_user_success(self, user_service):
"""测试成功创建用户"""
# Arrange
user_data = {
"email": "new@example.com",
"username": "newuser",
"password": "SecurePass123!"
}
user_service.repository.exists_by_email.return_value = False
user_service.repository.save.return_value = User(
id="456",
**user_data,
created_at=datetime.now()
)
# Act
created_user = user_service.create_user(user_data)
# Assert
assert created_user.id == "456"
assert created_user.email == user_data["email"]
user_service.email_service.send_welcome_email.assert_called_once_with(
user_data["email"],
user_data["username"]
)
@patch('src.services.user_service.datetime')
def test_password_reset_token_generation(self, mock_datetime, user_service, valid_user):
"""测试密码重置令牌生成"""
# Arrange
mock_now = datetime(2024, 1, 1, 12, 0, 0)
mock_datetime.now.return_value = mock_now
user_service.repository.find_by_email.return_value = valid_user
# Act
token = user_service.generate_password_reset_token("test@example.com")
# Assert
assert token is not None
user_service.cache_service.set.assert_called_with(
f"password_reset:{token}",
valid_user.id,
ttl=3600 # 1 hour
)测试 Fixtures 和 Mocks
python
# conftest.py
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from src.database import Base
from src.models import User, Product, Order
@pytest.fixture(scope="session")
def test_db():
    """Session-wide in-memory SQLite engine with every ORM table created.

    The schema is created once for the whole test session and dropped again
    at teardown.
    """
    metadata = Base.metadata
    db_engine = create_engine("sqlite:///:memory:")
    metadata.create_all(db_engine)
    yield db_engine
    metadata.drop_all(db_engine)
@pytest.fixture
def db_session(test_db):
    """Per-test ORM session bound to the shared session-scoped engine.

    Teardown rolls back uncommitted work and then deletes every row from every
    mapped table. A plain rollback does not undo data the test (or the code
    under test) already committed, and that data would otherwise leak into
    later tests sharing the same engine.
    """
    Session = sessionmaker(bind=test_db)
    session = Session()
    try:
        yield session
    finally:
        try:
            session.rollback()
            # Delete children before parents so foreign-key constraints hold.
            for table in reversed(Base.metadata.sorted_tables):
                session.execute(table.delete())
            session.commit()
        finally:
            session.close()
@pytest.fixture
def mock_user(db_session):
    """Insert a real (not mocked) test user into the per-test session.

    flush() sends the INSERT so the user gets its primary key, but keeps it
    inside the session's open transaction. The rollback in db_session's
    teardown therefore removes it. commit() would make the row permanent in
    the shared engine and leak it into other tests.
    """
    user = User(
        email="test@example.com",
        username="testuser",
        password_hash="hashed_password",
    )
    db_session.add(user)
    db_session.flush()
    return user
@pytest.fixture
def mock_redis():
    """In-memory fake Redis client (fakeredis), so tests need no Redis server."""
    from fakeredis import FakeStrictRedis

    client = FakeStrictRedis()
    return client
# 使用 fixtures
def test_order_creation(db_session, mock_user, mock_redis):
"""测试订单创建"""
order_service = OrderService(db_session, mock_redis)
order = order_service.create_order(
user_id=mock_user.id,
products=[{"id": "PROD-1", "quantity": 2}]
)
assert order.user_id == mock_user.id
assert order.status == "pending"TypeScript/JavaScript 单元测试
Jest 测试框架
typescript
// userService.test.ts
import { UserService } from '../src/services/userService';
import { UserRepository } from '../src/repositories/userRepository';
import { EmailService } from '../src/services/emailService';
import { CacheService } from '../src/services/cacheService';
import { User } from '../src/models/user';
// Replace every dependency with a Jest automock so UserService is tested in isolation.
jest.mock('../src/repositories/userRepository');
jest.mock('../src/services/emailService');
jest.mock('../src/services/cacheService');

describe('UserService', () => {
  let userService: UserService;
  let mockUserRepository: jest.Mocked<UserRepository>;
  let mockEmailService: jest.Mocked<EmailService>;
  let mockCacheService: jest.Mocked<CacheService>;

  beforeEach(() => {
    // Wipe call history recorded by the previous test.
    jest.clearAllMocks();

    // Fresh automocked collaborators for every test.
    mockUserRepository = new UserRepository() as jest.Mocked<UserRepository>;
    mockEmailService = new EmailService() as jest.Mocked<EmailService>;
    mockCacheService = new CacheService() as jest.Mocked<CacheService>;

    userService = new UserService(
      mockUserRepository,
      mockEmailService,
      mockCacheService
    );
  });

  describe('getUserById', () => {
    it('应该从缓存返回用户', async () => {
      // Arrange: cache hit
      const cachedUser = {
        id: '123',
        email: 'cached@example.com',
        username: 'cacheduser'
      };
      mockCacheService.get.mockResolvedValue(cachedUser);

      // Act
      const result = await userService.getUserById('123');

      // Assert: served from cache, database untouched
      expect(result).toEqual(cachedUser);
      expect(mockCacheService.get).toHaveBeenCalledWith('user:123');
      expect(mockUserRepository.findById).not.toHaveBeenCalled();
    });

    it('缓存未命中时从数据库获取', async () => {
      // Arrange: cache miss, user exists in the database
      const dbUser = new User({
        id: '123',
        email: 'db@example.com',
        username: 'dbuser'
      });
      mockCacheService.get.mockResolvedValue(null);
      mockUserRepository.findById.mockResolvedValue(dbUser);

      // Act
      const result = await userService.getUserById('123');

      // Assert: loaded from the database and written back to the cache
      expect(result).toEqual(dbUser);
      expect(mockUserRepository.findById).toHaveBeenCalledWith('123');
      expect(mockCacheService.set).toHaveBeenCalledWith(
        'user:123',
        dbUser,
        { ttl: 3600 }
      );
    });

    it('用户不存在时抛出错误', async () => {
      // Arrange: neither cache nor database knows the user
      mockCacheService.get.mockResolvedValue(null);
      mockUserRepository.findById.mockResolvedValue(null);

      // Act & Assert
      await expect(userService.getUserById('999'))
        .rejects
        .toThrow('User not found');
    });
  });

  describe('updateUser', () => {
    // Rebuilt before every test: a describe-level constant would be shared by all
    // tests, so an updateUser implementation that mutates the loaded user would
    // leak state from one test into the next.
    let existingUser: User;

    beforeEach(() => {
      existingUser = new User({
        id: '123',
        email: 'old@example.com',
        username: 'olduser'
      });
    });

    it('应该更新用户信息', async () => {
      // Arrange
      const updates = {
        username: 'newusername',
        bio: 'New bio'
      };
      mockUserRepository.findById.mockResolvedValue(existingUser);
      mockUserRepository.save.mockResolvedValue({
        ...existingUser,
        ...updates
      } as User);

      // Act
      const result = await userService.updateUser('123', updates);

      // Assert: saved once, updated fields returned, cache entry invalidated
      expect(result.username).toBe('newusername');
      expect(result.bio).toBe('New bio');
      expect(mockUserRepository.save).toHaveBeenCalledTimes(1);
      expect(mockCacheService.delete).toHaveBeenCalledWith('user:123');
    });

    it('不允许更新受保护字段', async () => {
      // Arrange: attempt to change the id and the email
      const updates = {
        id: '456',
        email: 'hacker@example.com'
      };
      mockUserRepository.findById.mockResolvedValue(existingUser);

      // Act & Assert: rejected, and nothing was persisted
      await expect(userService.updateUser('123', updates))
        .rejects
        .toThrow('Cannot update protected fields');
      expect(mockUserRepository.save).not.toHaveBeenCalled();
    });
  });
});
测试异步代码
typescript
// asyncService.test.ts
// TODO(review): import path assumed by analogy with userService.test.ts — adjust to the real module.
import { AsyncService } from '../src/services/asyncService';

describe('AsyncService', () => {
  // Restore real timers after every test, even one whose assertion failed.
  // Restoring only at the end of the test body is skipped when an expect()
  // throws, which would leave fake timers active for all later tests.
  afterEach(() => {
    jest.useRealTimers();
  });

  describe('异步操作测试', () => {
    it('使用 async/await', async () => {
      const service = new AsyncService();
      const result = await service.fetchData();
      expect(result).toBeDefined();
    });

    it('使用 Promise', () => {
      const service = new AsyncService();
      // Returning the promise makes Jest wait for it; without `return` the test
      // would finish before the assertion inside then() runs.
      return service.fetchData().then(result => {
        expect(result).toBeDefined();
      });
    });

    it('测试 Promise rejection', async () => {
      const service = new AsyncService();
      await expect(service.fetchDataWithError())
        .rejects
        .toThrow('Network error');
    });

    it('使用 fake timers', () => {
      jest.useFakeTimers();
      const callback = jest.fn();
      const service = new AsyncService();

      service.delayedOperation(callback, 1000);
      expect(callback).not.toHaveBeenCalled();

      // Move the fake clock forward instead of actually waiting.
      jest.advanceTimersByTime(1000);
      expect(callback).toHaveBeenCalledTimes(1);
    });
  });
});
测试替身(Test Doubles)
Mock vs Stub vs Spy
javascript
// Examples of the different kinds of test doubles.
// UserService and NotificationService stand for the code under test and are
// assumed to be imported from the application.
describe('测试替身示例', () => {
  // Stub: supplies canned answers so the code under test can run.
  it('使用 Stub', () => {
    const userRepository = {
      findById: jest.fn().mockReturnValue({ id: '123', name: 'Test User' })
    };
    const service = new UserService(userRepository);

    const user = service.getUser('123');

    expect(user.name).toBe('Test User');
  });

  // Mock: verifies the interaction — which call was made and with what arguments.
  it('使用 Mock', () => {
    const emailService = {
      sendEmail: jest.fn()
    };
    const service = new NotificationService(emailService);

    service.notifyUser('user@example.com', 'Hello');

    expect(emailService.sendEmail).toHaveBeenCalledWith(
      'user@example.com',
      'Hello'
    );
  });

  // Spy: wraps a REAL method, keeps its behavior, and records how it was called.
  // A spy is only informative when the spied method is called indirectly by the
  // code under test. Calling it directly from the test only proves that the test
  // called it.
  it('使用 Spy', () => {
    const calculator = {
      add(a, b) {
        return a + b;
      },
      sum(numbers) {
        return numbers.reduce((total, n) => this.add(total, n), 0);
      }
    };
    const spy = jest.spyOn(calculator, 'add');

    const result = calculator.sum([2, 3]);

    expect(result).toBe(5); // the real add() still ran
    expect(spy).toHaveBeenCalledTimes(2); // add(0, 2), then add(2, 3)
    expect(spy).toHaveBeenLastCalledWith(2, 3);
    spy.mockRestore(); // put the original method back
  });
});
测试覆盖率
配置覆盖率收集
json
// package.json
{
"jest": {
"collectCoverage": true,
"collectCoverageFrom": [
"src/**/*.{js,ts}",
"!src/**/*.d.ts",
"!src/**/index.ts",
"!src/**/*.test.{js,ts}"
],
"coverageThreshold": {
"global": {
"branches": 80,
"functions": 80,
"lines": 80,
"statements": 80
}
},
"coverageReporters": ["text", "lcov", "html"]
}
}
覆盖率报告分析
bash
# 运行测试并生成覆盖率报告
npm test -- --coverage
# 输出示例
----------------------|---------|----------|---------|---------|-------------------
File | % Stmts | % Branch | % Funcs | % Lines | Uncovered Line #s
----------------------|---------|----------|---------|---------|-------------------
All files | 85.71 | 83.33 | 88.89 | 85.71 |
services/ | 86.67 | 85.71 | 100 | 86.67 |
userService.ts | 86.67 | 85.71 | 100 | 86.67 | 45-47
repositories/ | 83.33 | 75 | 75 | 83.33 |
userRepository.ts | 83.33 | 75 | 75 | 83.33 | 23,67
----------------------|---------|----------|---------|---------|-------------------
单元测试最佳实践
1. 保持测试简单
typescript
// ❌ Bad example – a single test drives an entire multi-step workflow
it('should handle complex user workflow', async () => {
  const user = await createUser();
  await verifyEmail(user);
  await updateProfile(user);
  await addFriends(user);
  await createPosts(user);
  // ... more steps
});

// ✅ Good example – each test verifies exactly one behavior
it('should create user with valid data', async () => {
  const input = { email: 'test@example.com', password: 'Pass123!' };

  const created = await userService.createUser(input);

  expect(created.email).toBe(input.email);
});

it('should send verification email after user creation', async () => {
  const input = { email: 'test@example.com', password: 'Pass123!' };

  await userService.createUser(input);

  expect(mockEmailService.sendVerificationEmail)
    .toHaveBeenCalledWith(input.email);
});
2. 使用描述性的测试名称
python
# ❌ 不好的例子
def test_1():
pass
def test_user():
pass
# ✅ 好的例子
def test_should_return_user_when_valid_id_provided():
pass
def test_should_raise_not_found_error_when_user_does_not_exist():
pass3. 避免测试实现细节
java
// ❌ Testing implementation details
@Test
public void testInternalMethod() {
    // Exercises a private method or internal implementation detail; such a test
    // breaks on every refactoring even when observable behavior is unchanged.
}
// ✅ Testing observable behavior
@Test
public void shouldCalculateTotalPriceWithTax() {
    Order order = new Order();
    order.addItem(new Item("Book", 10.0));

    double total = order.calculateTotal();

    // 10.0 plus 5% tax. Doubles are compared with a tolerance, never for exact
    // equality, because the tax computation may introduce rounding error.
    assertEquals(10.5, total, 0.001);
}
总结
优秀的单元测试具有以下特点:
- 🚀 快速执行: 毫秒级运行速度
- 🎯 聚焦单一功能: 每个测试只验证一个行为
- 🔄 可重复: 无论何时何地运行结果一致
- 📝 自文档化: 测试即文档
- 🛡️ 高覆盖率: 覆盖关键业务逻辑
通过编写高质量的单元测试,为代码质量提供坚实保障。