第 9 章

AI 驱动的 TDD——让 Cursor 写测试,你负责提需求

第9章:AI 驱动的 TDD——让 Cursor 写测试,你负责提需求

传统 TDD 的瓶颈在于"写测试很慢"——开发者需要花大量时间构造测试数据、写断言、配置 mock。AI-TDD 重新分工:你负责描述期望的行为,AI 负责把这些描述变成可运行的测试。测试失败了,AI 再写实现代码直到通过。本章给出完整的 AI-TDD 工作流和两个真实案例。

三种开发方式对比

方式 写代码速度 代码质量 测试覆盖率 长期维护成本
传统开发(先写代码后补测试) 参差不齐
传统 TDD(先写测试后写代码)
AI-TDD(AI 写测试 + AI 写实现)

AI-TDD 的核心优势:保留了 TDD 的质量优势,同时解决了"写测试太慢"的痛点。

AI-TDD 的 7 步循环

  1. 描述需求:用自然语言描述函数/类应该做什么,包括边界情况
  2. AI 生成测试:让 AI 把需求描述变成测试用例
  3. 审查测试:你检查测试是否覆盖了所有重要场景,补充遗漏的
  4. 运行测试:此时测试应该全部失败(红色)
  5. AI 生成实现:给 AI 看测试文件,让它写通过测试的实现代码
  6. 运行测试:测试应该全部通过(绿色)
  7. 重构:在测试保护下,让 AI 优化代码结构

Python FastAPI 完整示例

需求描述(你提供):

实现一个用户注册接口,要求:
- 邮箱必须合法格式,且在系统中唯一
- 密码至少8位,包含字母和数字
- 注册成功返回用户ID和创建时间
- 邮箱重复返回 409 Conflict
- 其他验证失败返回 422 Unprocessable Entity

步骤一:让 AI 生成测试文件(conftest.py)

# tests/conftest.py — AI 生成
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.main import app
from app.database import Base, get_db

# 测试用内存数据库
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

@pytest.fixture(autouse=True)
def setup_database():
    Base.metadata.create_all(bind=engine)
    yield
    Base.metadata.drop_all(bind=engine)

@pytest.fixture
def client():
    def override_get_db():
        db = TestingSessionLocal()
        try:
            yield db
        finally:
            db.close()
    app.dependency_overrides[get_db] = override_get_db
    with TestClient(app) as c:
        yield c

步骤二:AI 生成测试用例

# tests/test_auth.py — AI 生成
import pytest

class TestUserRegistration:
    """用户注册接口测试"""

    def test_register_success(self, client):
        """正常注册场景"""
        response = client.post("/api/auth/register", json={
            "email": "[email protected]",
            "password": "Password123"
        })
        assert response.status_code == 201
        data = response.json()
        assert "id" in data
        assert "created_at" in data
        assert "password" not in data  # 密码不应出现在响应里

    def test_register_duplicate_email(self, client):
        """重复邮箱应返回 409"""
        payload = {"email": "[email protected]", "password": "Password123"}
        client.post("/api/auth/register", json=payload)
        response = client.post("/api/auth/register", json=payload)
        assert response.status_code == 409

    def test_register_invalid_email(self, client):
        """无效邮箱格式"""
        response = client.post("/api/auth/register", json={
            "email": "not-an-email",
            "password": "Password123"
        })
        assert response.status_code == 422

    def test_register_weak_password(self, client):
        """密码不满足强度要求"""
        cases = [
            "short",       # 太短
            "alllowercase", # 只有字母
            "12345678",    # 只有数字
        ]
        for password in cases:
            response = client.post("/api/auth/register", json={
                "email": "[email protected]",
                "password": password
            })
            assert response.status_code == 422, f"Failed for password: {password}"

    def test_register_returns_no_sensitive_data(self, client):
        """响应不包含密码哈希"""
        response = client.post("/api/auth/register", json={
            "email": "[email protected]",
            "password": "Password123"
        })
        assert "password" not in response.json()
        assert "hashed_password" not in response.json()

步骤三:AI 生成实现(看着测试写)

# app/routers/auth.py — AI 生成(通过以上测试为目标)
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from pydantic import BaseModel, EmailStr, validator
from datetime import datetime
import re
import bcrypt
from app.database import get_db
from app.models import User

router = APIRouter(prefix="/api/auth")

class RegisterRequest(BaseModel):
    email: EmailStr
    password: str

    @validator("password")
    def validate_password(cls, v):
        if len(v) < 8:
            raise ValueError("密码至少需要8位")
        if not re.search(r"[a-zA-Z]", v):
            raise ValueError("密码必须包含字母")
        if not re.search(r"\d", v):
            raise ValueError("密码必须包含数字")
        return v

class RegisterResponse(BaseModel):
    id: str
    created_at: datetime

@router.post("/register", response_model=RegisterResponse, status_code=201)
def register(request: RegisterRequest, db: Session = Depends(get_db)):
    existing = db.query(User).filter(User.email == request.email).first()
    if existing:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="邮箱已被注册"
        )

    hashed_password = bcrypt.hashpw(
        request.password.encode(),
        bcrypt.gensalt()
    ).decode()

    user = User(email=request.email, hashed_password=hashed_password)
    db.add(user)
    db.commit()
    db.refresh(user)

    return RegisterResponse(id=str(user.id), created_at=user.created_at)

TypeScript + Vitest 示例

需求:实现价格计算函数(含折扣和税费)

// tests/priceCalculator.test.ts — AI 生成
import { describe, it, expect } from 'vitest'
import { calculatePrice } from '../src/priceCalculator'

describe('calculatePrice', () => {
  it('正常价格计算(无折扣无税)', () => {
    expect(calculatePrice({ amount: 100, discountPercent: 0, taxRate: 0 })).toBe(100)
  })

  it('折扣计算', () => {
    expect(calculatePrice({ amount: 100, discountPercent: 20, taxRate: 0 })).toBe(80)
  })

  it('税费计算', () => {
    expect(calculatePrice({ amount: 100, discountPercent: 0, taxRate: 10 })).toBe(110)
  })

  it('折扣和税费同时存在(折扣先于税费)', () => {
    // 100 - 20% = 80, 80 * 1.1 = 88
    expect(calculatePrice({ amount: 100, discountPercent: 20, taxRate: 10 })).toBe(88)
  })

  it('金额为0', () => {
    expect(calculatePrice({ amount: 0, discountPercent: 50, taxRate: 10 })).toBe(0)
  })

  it('折扣不能超过100%', () => {
    expect(() => calculatePrice({ amount: 100, discountPercent: 110, taxRate: 0 }))
      .toThrow('折扣不能超过100%')
  })

  it('金额不能为负数', () => {
    expect(() => calculatePrice({ amount: -1, discountPercent: 0, taxRate: 0 }))
      .toThrow('金额不能为负数')
  })

  it('计算结果精度(避免浮点数问题)', () => {
    // 使用整数分避免浮点精度问题
    const result = calculatePrice({ amount: 99.9, discountPercent: 33.3, taxRate: 8 })
    expect(typeof result).toBe('number')
    expect(result).toBeCloseTo(71.89, 2)
  })
})

覆盖率驱动迭代

AI-TDD 的另一个用法:把覆盖率报告给 AI,让它补充测试:

运行测试后覆盖率报告显示:
- src/services/PaymentService.ts: 62% 覆盖率
- 未覆盖行:45-67(退款逻辑)、89-102(超时处理)

@src/services/PaymentService.ts @tests/PaymentService.test.ts
请分析未覆盖的代码,补充缺失的测试用例。
重点覆盖退款失败和超时的场景。

AI-TDD 质量检查清单

本章要点

  1. AI-TDD 的分工:你提供需求描述和边界情况,AI 把描述变成可运行的测试
  2. 7步循环是关键:描述需求 → AI 生成测试 → 审查 → 红色 → AI 写实现 → 绿色 → 重构
  3. 先看测试是否正确:AI 生成的测试不一定覆盖所有重要场景,你需要审查并补充
  4. 覆盖率报告是迭代工具:把覆盖率报告给 AI,让它精准补充缺失的测试
  5. 测试质量和代码质量同样重要:单一职责、名称自说明、边界值覆盖——这些原则对 AI 生成的测试也适用
本章评分
4.7  / 5  (35 评分)

💬 留言讨论