
Unit Testing

Building a Solid Foundation for Code Quality

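Unit Testing Principles

Unit testing is the process of checking and verifying the smallest testable units of a piece of software. Good unit tests follow the FIRST principles:

  • Fast: unit tests should run quickly
  • Independent: tests should not depend on one another
  • Repeatable: tests should give the same result in any environment
  • Self-Validating: each test should have an unambiguous pass/fail outcome
  • Timely: tests should be written promptly, alongside the code they cover (as in TDD)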
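
As a rough illustration of the Independent, Repeatable and Self-Validating points, the Python sketch below passes the current time into the function under test instead of reading the system clock. The `is_expired` function and its parameters are hypothetical names introduced only for this example.

```python
from datetime import datetime, timedelta


def is_expired(created_at: datetime, ttl: timedelta, now: datetime) -> bool:
    """Return True once `ttl` has fully elapsed since `created_at`.

    `now` is passed in rather than read from the system clock, so a test
    gives the same answer on any machine at any time (Repeatable).
    """
    return now - created_at >= ttl


def test_token_expires_after_ttl():
    # Each test builds its own data, so the tests do not depend on each other.
    created = datetime(2024, 1, 1, 12, 0, 0)
    now = created + timedelta(hours=1)
    # A plain assert gives an unambiguous pass/fail result (Self-Validating).
    assert is_expired(created, timedelta(hours=1), now)


def test_token_is_valid_before_ttl():
    created = datetime(2024, 1, 1, 12, 0, 0)
    now = created + timedelta(minutes=59)
    assert not is_expired(created, timedelta(hours=1), now)
```

Injecting the clock is only one way to get repeatability; patching the time source, as some of the examples on this page do, is another.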

Unit Test Structure

AAA Pattern

java
@Test
public void testCalculateDiscountedPrice() {
    // Arrange
    Product product = new Product("laptop", 1000.0);
    DiscountService service = new DiscountService();
    
    // Act
    double discountedPrice = service.calculateDiscountedPrice(product, 0.1);
    
    // Assert: a 10% discount on 1000.0 leaves a price of 900.0
    assertEquals(900.0, discountedPrice, 0.01);
}

Given-When-Then Pattern

python
def test_user_registration():
    # Given (the preconditions)
    user_data = {
        "email": "test@example.com",
        "password": "SecurePass123!"
    }
    service = UserService()
    
    # When (the action is performed)
    result = service.register(user_data)
    
    # Then (the expected outcome)
    assert result.id is not None
    assert result.email == user_data["email"]

Java Unit Testing

JUnit 5 Basics

java
import org.junit.jupiter.api.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

@DisplayName("Order service tests")
class OrderServiceTest {
    
    private OrderService orderService;
    private PaymentService paymentService;
    private InventoryService inventoryService;
    
    @BeforeEach
    void setUp() {
        paymentService = mock(PaymentService.class);
        inventoryService = mock(InventoryService.class);
        orderService = new OrderService(paymentService, inventoryService);
    }
    
    @Nested
    @DisplayName("Creating an order")
    class CreateOrderTests {
        
        @Test
        @DisplayName("Creates the order successfully - stock is sufficient and payment succeeds")
        void shouldCreateOrderSuccessfully() {
            // Given
            OrderRequest request = OrderRequest.builder()
                .productId("PROD-001")
                .quantity(2)
                .userId("USER-123")
                .build();
                
            when(inventoryService.checkStock("PROD-001", 2)).thenReturn(true);
            when(paymentService.processPayment(any())).thenReturn(
                PaymentResult.success("PAY-123")
            );
            
            // When
            Order order = orderService.createOrder(request);
            
            // Then
            assertNotNull(order);
            assertEquals(OrderStatus.CONFIRMED, order.getStatus());
            assertEquals("PAY-123", order.getPaymentId());
            
            verify(inventoryService).reserveStock("PROD-001", 2);
            verify(paymentService).processPayment(any());
        }
        
        @Test
        @DisplayName("Throws an exception when out of stock")
        void shouldThrowExceptionWhenOutOfStock() {
            // Given
            OrderRequest request = OrderRequest.builder()
                .productId("PROD-001")
                .quantity(10)
                .build();
                
            when(inventoryService.checkStock("PROD-001", 10)).thenReturn(false);
            
            // When & Then
            OutOfStockException exception = assertThrows(
                OutOfStockException.class,
                () -> orderService.createOrder(request)
            );
            
            assertEquals("Product PROD-001 is out of stock", exception.getMessage());
            verify(paymentService, never()).processPayment(any());
        }
    }
    
    @Nested
    @DisplayName("Cancelling an order")
    class CancelOrderTests {
        
        @Test
        @DisplayName("Cancels an unshipped order successfully")
        void shouldCancelUnshippedOrder() {
            // Given
            Order order = Order.builder()
                .id("ORDER-123")
                .status(OrderStatus.CONFIRMED)
                .paymentId("PAY-123")
                .build();
                
            when(paymentService.refund("PAY-123")).thenReturn(true);
            
            // When
            boolean result = orderService.cancelOrder(order);
            
            // Then
            assertTrue(result);
            assertEquals(OrderStatus.CANCELLED, order.getStatus());
            verify(paymentService).refund("PAY-123");
        }
        
        @Test
        @DisplayName("A shipped order cannot be cancelled")
        void shouldNotCancelShippedOrder() {
            // Given
            Order order = Order.builder()
                .id("ORDER-123")
                .status(OrderStatus.SHIPPED)
                .build();
            
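            // When & Then
            // The third argument of assertThrows is only the failure message,
            // so the exception text is checked explicitly instead.
            IllegalStateException exception = assertThrows(
                IllegalStateException.class,
                () -> orderService.cancelOrder(order)
            );
            assertEquals("Cannot cancel shipped order", exception.getMessage());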
        }
    }
}

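Parameterized Tests

java
import static org.junit.jupiter.api.Assertions.*;

import java.util.stream.Stream;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;

// @ParameterizedTest applies to test methods only, not to the class.
@DisplayName("Price calculation tests")
class PriceCalculatorTest {
    
    @ParameterizedTest(name = "original price {0}, discount {1}, expected price {2}")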
    @CsvSource({
        "100.0, 0.1, 90.0",
        "100.0, 0.2, 80.0",
        "100.0, 0.0, 100.0",
        "100.0, 1.0, 0.0"
    })
    void testCalculateDiscountedPrice(double originalPrice, double discount, double expectedPrice) {
        PriceCalculator calculator = new PriceCalculator();
        double result = calculator.calculateDiscountedPrice(originalPrice, discount);
        assertEquals(expectedPrice, result, 0.01);
    }
    
    @ParameterizedTest
    @MethodSource("provideInvalidInputs")
    @DisplayName("Invalid input validation")
    void testInvalidInputs(double price, double discount, Class<? extends Exception> expectedException) {
        PriceCalculator calculator = new PriceCalculator();
        assertThrows(expectedException, () -> 
            calculator.calculateDiscountedPrice(price, discount)
        );
    }
    
    private static Stream<Arguments> provideInvalidInputs() {
        return Stream.of(
            Arguments.of(-100.0, 0.1, IllegalArgumentException.class),
            Arguments.of(100.0, -0.1, IllegalArgumentException.class),
            Arguments.of(100.0, 1.5, IllegalArgumentException.class)
        );
    }
}

Python Unit Testing

pytest Testing Framework

python
# test_user_service.py
import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta
from src.services.user_service import UserService
from src.models.user import User
from src.exceptions import ValidationError, AuthenticationError

class TestUserService:
    """Unit tests for the user service"""
    
    @pytest.fixture
    def user_service(self):
        """Create a user service instance"""
        repository = Mock()
        email_service = Mock()
        cache_service = Mock()
        return UserService(repository, email_service, cache_service)
    
    @pytest.fixture
    def valid_user(self):
        """Valid user data"""
        return User(
            id="123",
            email="test@example.com",
            username="testuser",
            created_at=datetime.now()
        )
    
    class TestAuthentication:
        """Authentication tests"""
        
        def test_login_with_valid_credentials(self, user_service, valid_user):
            """Test login with valid credentials"""
            # Arrange
            user_service.repository.find_by_email.return_value = valid_user
            user_service.repository.verify_password.return_value = True
            
            # Act
            token = user_service.login("test@example.com", "password123")
            
            # Assert
            assert token is not None
            assert len(token) > 0
            user_service.cache_service.set.assert_called_once()
        
        def test_login_with_invalid_password(self, user_service, valid_user):
            """Test login with an invalid password"""
            # Arrange
            user_service.repository.find_by_email.return_value = valid_user
            user_service.repository.verify_password.return_value = False
            
            # Act & Assert
            with pytest.raises(AuthenticationError) as exc_info:
                user_service.login("test@example.com", "wrongpassword")
            
            assert str(exc_info.value) == "Invalid credentials"
        
        @pytest.mark.parametrize("email,password,error_message", [
            ("", "password", "Email is required"),
            ("test@example.com", "", "Password is required"),
            ("invalid-email", "password", "Invalid email format"),
            ("test@example.com", "short", "Password too short")
        ])
        def test_login_validation(self, user_service, email, password, error_message):
            """Test validation of the login parameters"""
            with pytest.raises(ValidationError) as exc_info:
                user_service.login(email, password)
            
            assert error_message in str(exc_info.value)
    
    class TestUserManagement:
        """User management tests"""
        
        def test_create_user_success(self, user_service):
            """Test successful user creation"""
            # Arrange
            user_data = {
                "email": "new@example.com",
                "username": "newuser",
                "password": "SecurePass123!"
            }
            
            user_service.repository.exists_by_email.return_value = False
            user_service.repository.save.return_value = User(
                id="456",
                **user_data,
                created_at=datetime.now()
            )
            
            # Act
            created_user = user_service.create_user(user_data)
            
            # Assert
            assert created_user.id == "456"
            assert created_user.email == user_data["email"]
            user_service.email_service.send_welcome_email.assert_called_once_with(
                user_data["email"],
                user_data["username"]
            )
        
        @patch('src.services.user_service.datetime')
        def test_password_reset_token_generation(self, mock_datetime, user_service, valid_user):
            """Test password reset token generation"""
            # Arrange
            mock_now = datetime(2024, 1, 1, 12, 0, 0)
            mock_datetime.now.return_value = mock_now
            user_service.repository.find_by_email.return_value = valid_user
            
            # Act
            token = user_service.generate_password_reset_token("test@example.com")
            
            # Assert
            assert token is not None
            user_service.cache_service.set.assert_called_with(
                f"password_reset:{token}",
                valid_user.id,
                ttl=3600  # 1 hour
            )

Test Fixtures and Mocks

python
# conftest.py
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from src.database import Base
from src.models import User, Product, Order

@pytest.fixture(scope="session")
def test_db():
    """Create the test database"""
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    yield engine
    Base.metadata.drop_all(engine)

@pytest.fixture
def db_session(test_db):
    """Create a database session"""
    Session = sessionmaker(bind=test_db)
    session = Session()
    yield session
    session.rollback()
    session.close()

@pytest.fixture
def mock_user(db_session):
    """Create a test user"""
    user = User(
        email="test@example.com",
        username="testuser",
        password_hash="hashed_password"
    )
    db_session.add(user)
    db_session.commit()
    return user

@pytest.fixture
def mock_redis():
    """Simulate a Redis client"""
    import fakeredis
    return fakeredis.FakeStrictRedis()

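# Using the fixtures (this test belongs in a test module, not conftest.py,
# with OrderService imported from the application code)
def test_order_creation(db_session, mock_user, mock_redis):
    """Test order creation"""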
    order_service = OrderService(db_session, mock_redis)
    order = order_service.create_order(
        user_id=mock_user.id,
        products=[{"id": "PROD-1", "quantity": 2}]
    )
    
    assert order.user_id == mock_user.id
    assert order.status == "pending"

TypeScript/JavaScript Unit Testing

Jest Testing Framework

typescript
// userService.test.ts
import { UserService } from '../src/services/userService';
import { UserRepository } from '../src/repositories/userRepository';
import { EmailService } from '../src/services/emailService';
import { CacheService } from '../src/services/cacheService';
import { User } from '../src/models/user';

// Mock all dependencies
jest.mock('../src/repositories/userRepository');
jest.mock('../src/services/emailService');
jest.mock('../src/services/cacheService');

describe('UserService', () => {
  let userService: UserService;
  let mockUserRepository: jest.Mocked<UserRepository>;
  let mockEmailService: jest.Mocked<EmailService>;
  let mockCacheService: jest.Mocked<CacheService>;

  beforeEach(() => {
    // Clear all mocks
    jest.clearAllMocks();
    
    // Create the mock instances
    mockUserRepository = new UserRepository() as jest.Mocked<UserRepository>;
    mockEmailService = new EmailService() as jest.Mocked<EmailService>;
    mockCacheService = new CacheService() as jest.Mocked<CacheService>;
    
    // Create the service instance
    userService = new UserService(
      mockUserRepository,
      mockEmailService,
      mockCacheService
    );
  });

  describe('getUserById', () => {
    it('should return the user from the cache', async () => {
      // Arrange
      const cachedUser = {
        id: '123',
        email: 'cached@example.com',
        username: 'cacheduser'
      };
      mockCacheService.get.mockResolvedValue(cachedUser);

      // Act
      const result = await userService.getUserById('123');

      // Assert
      expect(result).toEqual(cachedUser);
      expect(mockCacheService.get).toHaveBeenCalledWith('user:123');
      expect(mockUserRepository.findById).not.toHaveBeenCalled();
    });

    it('should fetch from the database on a cache miss', async () => {
      // Arrange
      const dbUser = new User({
        id: '123',
        email: 'db@example.com',
        username: 'dbuser'
      });
      mockCacheService.get.mockResolvedValue(null);
      mockUserRepository.findById.mockResolvedValue(dbUser);

      // Act
      const result = await userService.getUserById('123');

      // Assert
      expect(result).toEqual(dbUser);
      expect(mockUserRepository.findById).toHaveBeenCalledWith('123');
      expect(mockCacheService.set).toHaveBeenCalledWith(
        'user:123',
        dbUser,
        { ttl: 3600 }
      );
    });

    it('should throw an error when the user does not exist', async () => {
      // Arrange
      mockCacheService.get.mockResolvedValue(null);
      mockUserRepository.findById.mockResolvedValue(null);

      // Act & Assert
      await expect(userService.getUserById('999'))
        .rejects
        .toThrow('User not found');
    });
  });

  describe('updateUser', () => {
    const existingUser = new User({
      id: '123',
      email: 'old@example.com',
      username: 'olduser'
    });

    it('should update the user information', async () => {
      // Arrange
      const updates = {
        username: 'newusername',
        bio: 'New bio'
      };
      
      mockUserRepository.findById.mockResolvedValue(existingUser);
      mockUserRepository.save.mockResolvedValue({
        ...existingUser,
        ...updates
      } as User);

      // Act
      const result = await userService.updateUser('123', updates);

      // Assert
      expect(result.username).toBe('newusername');
      expect(result.bio).toBe('New bio');
      expect(mockCacheService.delete).toHaveBeenCalledWith('user:123');
    });

    it('should not allow protected fields to be updated', async () => {
      // Arrange
      const updates = {
        id: '456',  // attempt to change the ID
        email: 'hacker@example.com'  // attempt to change the email
      };

      mockUserRepository.findById.mockResolvedValue(existingUser);

      // Act & Assert
      await expect(userService.updateUser('123', updates))
        .rejects
        .toThrow('Cannot update protected fields');
    });
  });
});

Testing Asynchronous Code

typescript
// asyncService.test.ts
describe('AsyncService', () => {
  describe('asynchronous operation tests', () => {
    it('using async/await', async () => {
      const service = new AsyncService();
      const result = await service.fetchData();
      expect(result).toBeDefined();
    });

    it('using a Promise', () => {
      const service = new AsyncService();
      return service.fetchData().then(result => {
        expect(result).toBeDefined();
      });
    });

    it('testing a Promise rejection', async () => {
      const service = new AsyncService();
      await expect(service.fetchDataWithError())
        .rejects
        .toThrow('Network error');
    });

    it('using fake timers', () => {
      jest.useFakeTimers();
      const callback = jest.fn();
      
      const service = new AsyncService();
      service.delayedOperation(callback, 1000);
      
      expect(callback).not.toHaveBeenCalled();
      
      jest.advanceTimersByTime(1000);
      
      expect(callback).toHaveBeenCalledTimes(1);
      jest.useRealTimers();
    });
  });
});

Test Doubles

Mock vs Stub vs Spy

javascript
// Examples of the different kinds of test doubles
describe('test double examples', () => {
  // Stub - returns predefined values
  it('using a Stub', () => {
    const userRepository = {
      findById: jest.fn().mockReturnValue({ id: '123', name: 'Test User' })
    };
    
    const service = new UserService(userRepository);
    const user = service.getUser('123');
    
    expect(user.name).toBe('Test User');
  });

  // Mock - verifies interactions
  it('using a Mock', () => {
    const emailService = {
      sendEmail: jest.fn()
    };
    
    const service = new NotificationService(emailService);
    service.notifyUser('user@example.com', 'Hello');
    
    expect(emailService.sendEmail).toHaveBeenCalledWith(
      'user@example.com',
      'Hello'
    );
  });

  // Spy - observes a real object
  it('using a Spy', () => {
    const realService = new CalculatorService();
    const spy = jest.spyOn(realService, 'add');
    
    const result = realService.add(2, 3);
    
    expect(spy).toHaveBeenCalledWith(2, 3);
    expect(result).toBe(5);
  });
});
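
The same three roles can be sketched in Python with the standard library's `unittest.mock`. This is a minimal illustration rather than a prescribed pattern: `Calculator`, `greet` and `notify` are hypothetical names made up for the example. A stub only supplies canned data, a mock is asserted on for how it was called, and a spy records calls while still delegating to the real object.

```python
from unittest.mock import Mock


class Calculator:
    """Hypothetical real collaborator used for the spy example."""

    def add(self, a: int, b: int) -> int:
        return a + b


def greet(repository, user_id: str) -> str:
    """Hypothetical code under test: reads a user and builds a greeting."""
    user = repository.find_by_id(user_id)
    return f"Hello, {user['name']}"


def notify(email_service, address: str, text: str) -> None:
    """Hypothetical code under test: delegates to an email service."""
    email_service.send_email(address, text)


def test_stub_supplies_canned_data():
    # Stub: only provides a predefined answer; the test checks the result.
    repository = Mock()
    repository.find_by_id.return_value = {"id": "123", "name": "Test User"}
    assert greet(repository, "123") == "Hello, Test User"


def test_mock_verifies_interaction():
    # Mock: the assertion is about how the collaborator was called.
    email_service = Mock()
    notify(email_service, "user@example.com", "Hello")
    email_service.send_email.assert_called_once_with("user@example.com", "Hello")


def test_spy_wraps_real_object():
    # Spy: calls pass through to the real implementation and are recorded.
    spy = Mock(wraps=Calculator())
    assert spy.add(2, 3) == 5
    spy.add.assert_called_once_with(2, 3)
```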

Test Coverage

Configuring Coverage Collection

The following Jest settings go in package.json:

json
{
  "jest": {
    "collectCoverage": true,
    "collectCoverageFrom": [
      "src/**/*.{js,ts}",
      "!src/**/*.d.ts",
      "!src/**/index.ts",
      "!src/**/*.test.{js,ts}"
    ],
    "coverageThreshold": {
      "global": {
        "branches": 80,
        "functions": 80,
        "lines": 80,
        "statements": 80
      }
    },
    "coverageReporters": ["text", "lcov", "html"]
  }
}

Analyzing the Coverage Report

bash
# Run the tests and generate a coverage report
npm test -- --coverage

# Sample output
----------------------|---------|----------|---------|---------|-------------------
File                  | % Stmts | % Branch | % Funcs | % Lines | Uncovered Line #s
----------------------|---------|----------|---------|---------|-------------------
All files             |   85.71 |    83.33 |   88.89 |   85.71 |
 services/            |   86.67 |    85.71 |     100 |   86.67 |
  userService.ts      |   86.67 |    85.71 |     100 |   86.67 | 45-47
 repositories/        |   83.33 |       75 |      75 |   83.33 |
  userRepository.ts   |   83.33 |       75 |      75 |   83.33 | 23,67
----------------------|---------|----------|---------|---------|-------------------
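
The report lists branch coverage separately from line and statement coverage because the two can diverge. In the hypothetical Python function below, the first test alone executes every line, yet the path where the `if` condition is false is never taken, so a branch-aware tool (for example coverage.py run with its `--branch` option) would still report a missed branch until the second test is added.

```python
def apply_discount(price: float, is_member: bool) -> float:
    """Hypothetical example: members get 10% off."""
    if is_member:
        price = price * 0.9
    return price


def test_member_discount():
    # Runs every line of apply_discount, but only the is_member=True path.
    assert abs(apply_discount(100.0, True) - 90.0) < 1e-9


def test_non_member_pays_full_price():
    # Adding this case exercises the other branch of the `if`.
    assert apply_discount(100.0, False) == 100.0
```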

Unit Testing Best Practices

1. Keep Tests Simple

typescript
// ❌ Bad example - the test is too complex
it('should handle complex user workflow', async () => {
  const user = await createUser();
  await verifyEmail(user);
  await updateProfile(user);
  await addFriends(user);
  await createPosts(user);
  // ... more operations
});

// ✅ Good example - each test checks one thing
it('should create user with valid data', async () => {
  const userData = { email: 'test@example.com', password: 'Pass123!' };
  const user = await userService.createUser(userData);
  expect(user.email).toBe(userData.email);
});

it('should send verification email after user creation', async () => {
  const userData = { email: 'test@example.com', password: 'Pass123!' };
  await userService.createUser(userData);
  expect(mockEmailService.sendVerificationEmail).toHaveBeenCalledWith(userData.email);
});

2. Use Descriptive Test Names

python
# ❌ Bad examples
def test_1():
    pass

def test_user():
    pass

# ✅ Good examples
def test_should_return_user_when_valid_id_provided():
    pass

def test_should_raise_not_found_error_when_user_does_not_exist():
    pass

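3. Avoid Testing Implementation Details

java
// ❌ Testing implementation details
@Test
public void testInternalMethod() {
    // Tests a private method or an internal implementation detail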
}

// ✅ Testing behavior
@Test
public void shouldCalculateTotalPriceWithTax() {
    Order order = new Order();
    order.addItem(new Item("Book", 10.0));
    
    double total = order.calculateTotal();
    
    assertEquals(10.5, total, 0.001); // includes 5% tax; the delta tolerates floating-point rounding
}

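Summary

Good unit tests share the following characteristics:

  • 🚀 Fast execution: they run in milliseconds
  • 🎯 Focused on a single function: each test verifies one behavior
  • 🔄 Repeatable: the results are the same whenever and wherever they run
  • 📝 Self-documenting: the tests serve as documentation
  • 🛡️ High coverage: they cover the critical business logic

Writing high-quality unit tests provides a solid safeguard for code quality.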
