第 10 章

Streaming 流式输出:SSE 协议、断点续传与前端实时渲染完全实战

第十章:Batch API:异步大规模推理与成本优化

10.1 Batch API 的定位与价值

Anthropic 的 Batch API(Message Batches API)是专为大规模异步推理设计的接口。它与标准 Messages API 的核心区别在于:

典型适用场景:

不适用场景:任何需要实时响应的场景(聊天、API 网关等)。

10.2 API 基础:创建与查询 Batch

创建 Batch 请求

import anthropic

client = anthropic.Anthropic()

# 准备 batch 请求列表
requests = [
    {
        "custom_id": "review-001",
        "params": {
            "model": "claude-haiku-4-5-20251001",
            "max_tokens": 512,
            "messages": [
                {
                    "role": "user",
                    "content": "对以下商品评论进行情感分析,输出 positive/negative/neutral:\n\n这款耳机音质非常棒,低音有力,高音清晰,强烈推荐!"
                }
            ]
        }
    },
    {
        "custom_id": "review-002",
        "params": {
            "model": "claude-haiku-4-5-20251001",
            "max_tokens": 512,
            "messages": [
                {
                    "role": "user",
                    "content": "对以下商品评论进行情感分析,输出 positive/negative/neutral:\n\n快递很慢,包装也很差,产品和描述不符,非常失望。"
                }
            ]
        }
    }
]

# 提交 batch
batch = client.messages.batches.create(requests=requests)

print(f"Batch ID: {batch.id}")
print(f"状态: {batch.processing_status}")
print(f"创建时间: {batch.created_at}")
print(f"过期时间: {batch.expires_at}")

custom_id 是用户自定义的唯一标识符,用于在结果中匹配对应的请求。每个 batch 最多可包含 10,000 个请求,总大小不超过 32 MB。

查询 Batch 状态

# 轮询状态
import time

batch_id = batch.id

while True:
    batch = client.messages.batches.retrieve(batch_id)
    
    print(f"状态: {batch.processing_status}")
    print(f"已完成: {batch.request_counts.succeeded}")
    print(f"错误: {batch.request_counts.errored}")
    print(f"待处理: {batch.request_counts.processing}")
    
    if batch.processing_status == "ended":
        break
    
    time.sleep(60)  # 每分钟检查一次

print("Batch 处理完成!")

processing_status 有以下取值:

10.3 获取与处理结果

下载结果

# 处理完成后下载结果
results = {}

for result in client.messages.batches.results(batch_id):
    custom_id = result.custom_id
    
    if result.result.type == "succeeded":
        # 成功的结果
        message = result.result.message
        text = message.content[0].text
        results[custom_id] = {
            "status": "success",
            "text": text,
            "input_tokens": message.usage.input_tokens,
            "output_tokens": message.usage.output_tokens
        }
    elif result.result.type == "errored":
        # 失败的结果
        error = result.result.error
        results[custom_id] = {
            "status": "error",
            "error_type": error.type,
            "error_message": str(error)
        }

# 打印结果
for custom_id, result in results.items():
    print(f"\n{custom_id}:")
    if result["status"] == "success":
        print(f"  结果: {result['text']}")
        print(f"  Token 使用: {result['input_tokens']}+{result['output_tokens']}")
    else:
        print(f"  错误: {result['error_type']} - {result['error_message']}")

流式处理结果

对于大型 batch,结果可以以流的形式处理,避免内存溢出:

import json

# 将结果写入文件,逐条处理
with open("batch_results.jsonl", "w", encoding="utf-8") as f:
    for result in client.messages.batches.results(batch_id):
        record = {
            "custom_id": result.custom_id,
            "type": result.result.type
        }
        
        if result.result.type == "succeeded":
            record["text"] = result.result.message.content[0].text
            record["usage"] = {
                "input_tokens": result.result.message.usage.input_tokens,
                "output_tokens": result.result.message.usage.output_tokens
            }
        else:
            record["error"] = result.result.error.type
        
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

print("结果已写入 batch_results.jsonl")

10.4 大规模 Batch 工程模式

分片提交模式

当总请求数超过单次 batch 的 10,000 条上限时,需要分片处理:

import anthropic
from typing import List, Dict, Any
import math
import time

def chunk_requests(requests: List[Dict], chunk_size: int = 9000):
    """将请求列表分割为多个分片"""
    for i in range(0, len(requests), chunk_size):
        yield requests[i:i + chunk_size]

def submit_batch_jobs(
    client: anthropic.Anthropic,
    all_requests: List[Dict],
    chunk_size: int = 9000
) -> List[str]:
    """提交多个 batch 并返回所有 batch ID"""
    batch_ids = []
    chunks = list(chunk_requests(all_requests, chunk_size))
    
    print(f"总请求数: {len(all_requests)}, 分为 {len(chunks)} 个 batch")
    
    for i, chunk in enumerate(chunks):
        batch = client.messages.batches.create(requests=chunk)
        batch_ids.append(batch.id)
        print(f"Batch {i+1}/{len(chunks)} 已提交: {batch.id}")
        
        # 避免创建请求的速率限制
        if i < len(chunks) - 1:
            time.sleep(1)
    
    return batch_ids

def wait_for_all_batches(
    client: anthropic.Anthropic,
    batch_ids: List[str],
    poll_interval: int = 60
) -> Dict[str, Any]:
    """等待所有 batch 完成并收集结果"""
    pending = set(batch_ids)
    all_results = {}
    
    while pending:
        completed_this_round = set()
        
        for batch_id in pending:
            batch = client.messages.batches.retrieve(batch_id)
            
            if batch.processing_status == "ended":
                print(f"Batch {batch_id} 完成: "
                      f"{batch.request_counts.succeeded} 成功, "
                      f"{batch.request_counts.errored} 失败")
                
                # 收集结果
                for result in client.messages.batches.results(batch_id):
                    if result.result.type == "succeeded":
                        all_results[result.custom_id] = (
                            result.result.message.content[0].text
                        )
                
                completed_this_round.add(batch_id)
        
        pending -= completed_this_round
        
        if pending:
            print(f"还有 {len(pending)} 个 batch 待完成,{poll_interval}s 后重新检查...")
            time.sleep(poll_interval)
    
    return all_results

# 使用示例:处理 50,000 条文本分类任务
def build_classification_requests(texts: List[str]) -> List[Dict]:
    return [
        {
            "custom_id": f"text-{i:06d}",
            "params": {
                "model": "claude-haiku-4-5-20251001",
                "max_tokens": 10,
                "messages": [{
                    "role": "user",
                    "content": f"分类(正面/负面/中性):{text}"
                }]
            }
        }
        for i, text in enumerate(texts)
    ]

带系统提示的批量处理

SYSTEM_PROMPT = """你是一个专业的电商评论分析师。
你的任务是分析用户评论并输出:
1. 情感倾向(positive/negative/neutral)
2. 关键词(最多3个)
3. 评分预测(1-5星)

输出格式为 JSON,不要包含任何其他文字。"""

requests = [
    {
        "custom_id": f"review-{i}",
        "params": {
            "model": "claude-haiku-4-5-20251001",
            "max_tokens": 100,
            "system": SYSTEM_PROMPT,
            "messages": [{"role": "user", "content": review_text}]
        }
    }
    for i, review_text in enumerate(reviews)
]

10.5 成本优化策略

模型选择的成本影响

Batch API 在各模型上均提供 50% 折扣,但模型间的绝对价格差异巨大:

模型 标准输入价格 Batch 输入价格 标准输出价格 Batch 输出价格
claude-opus-4-6 $15/MTok $7.5/MTok $75/MTok $37.5/MTok
claude-sonnet-4-6 $3/MTok $1.5/MTok $15/MTok $7.5/MTok
claude-haiku-4-5-20251001 $0.8/MTok $0.4/MTok $4/MTok $2/MTok

对于大规模批处理任务(如数据标注),使用 claude-haiku-4-5-20251001 + Batch API 的成本约为使用 claude-opus-4-6 标准 API 的 1/188

最小化 Token 消耗

# 不推荐:在 prompt 中包含过多上下文
bad_request = {
    "custom_id": "item-001",
    "params": {
        "model": "claude-haiku-4-5-20251001",
        "max_tokens": 512,  # 设置过大
        "messages": [{
            "role": "user",
            "content": """
            [大量背景信息...]
            
            请对以下文本进行情感分析:
            这款产品很好用。
            """
        }]
    }
}

# 推荐:精简 prompt,system 提示放在 system 字段
good_request = {
    "custom_id": "item-001",
    "params": {
        "model": "claude-haiku-4-5-20251001",
        "max_tokens": 20,  # 仅需要简短输出
        "system": "情感分析:输出 positive/negative/neutral",
        "messages": [{
            "role": "user",
            "content": "这款产品很好用。"
        }]
    }
}

Token 预估与成本计算

import anthropic

client = anthropic.Anthropic()

def estimate_batch_cost(
    requests: list,
    model: str = "claude-haiku-4-5-20251001"
) -> dict:
    """估算 batch 成本(使用 token 计数 API)"""
    
    # 价格表(每百万 token,Batch 折扣后)
    BATCH_PRICES = {
        "claude-haiku-4-5-20251001": {"input": 0.40, "output": 2.00},
        "claude-sonnet-4-6": {"input": 1.50, "output": 7.50},
        "claude-opus-4-6": {"input": 7.50, "output": 37.50},
    }
    
    total_input_tokens = 0
    
    # 对前 10 条请求估算平均 token 数
    sample_size = min(10, len(requests))
    for req in requests[:sample_size]:
        # 使用 token counting API 估算
        count = client.messages.count_tokens(
            model=model,
            system=req["params"].get("system", ""),
            messages=req["params"]["messages"]
        )
        total_input_tokens += count.input_tokens
    
    avg_input = total_input_tokens / sample_size
    estimated_total_input = avg_input * len(requests)
    
    # 假设平均输出 token 数
    avg_output = sum(
        req["params"].get("max_tokens", 100) * 0.5  # 假设用到 50%
        for req in requests[:sample_size]
    ) / sample_size
    estimated_total_output = avg_output * len(requests)
    
    price = BATCH_PRICES.get(model, BATCH_PRICES["claude-haiku-4-5-20251001"])
    
    input_cost = (estimated_total_input / 1_000_000) * price["input"]
    output_cost = (estimated_total_output / 1_000_000) * price["output"]
    
    return {
        "estimated_input_tokens": int(estimated_total_input),
        "estimated_output_tokens": int(estimated_total_output),
        "estimated_input_cost_usd": round(input_cost, 4),
        "estimated_output_cost_usd": round(output_cost, 4),
        "estimated_total_cost_usd": round(input_cost + output_cost, 4)
    }

10.6 错误处理与部分失败

Batch 中的个别请求可能失败,整个 batch 仍会以 ended 状态完成。必须检查每条结果的 type 字段:

from collections import defaultdict

def process_batch_results(client, batch_id: str) -> dict:
    """处理 batch 结果,分类统计成功与失败"""
    
    succeeded = {}
    failed = defaultdict(list)
    
    for result in client.messages.batches.results(batch_id):
        cid = result.custom_id
        
        if result.result.type == "succeeded":
            succeeded[cid] = result.result.message.content[0].text
            
        elif result.result.type == "errored":
            error = result.result.error
            failed[error.type].append({
                "custom_id": cid,
                "error": str(error)
            })
    
    # 打印失败统计
    if failed:
        print(f"\n失败请求统计:")
        for error_type, items in failed.items():
            print(f"  {error_type}: {len(items)} 条")
            # 打印前 3 条样本
            for item in items[:3]:
                print(f"    - {item['custom_id']}: {item['error']}")
    
    print(f"\n成功: {len(succeeded)}, 失败: {sum(len(v) for v in failed.values())}")
    
    return {"succeeded": succeeded, "failed": dict(failed)}

# 重新提交失败的请求
def retry_failed_requests(
    original_requests: list,
    failed_ids: set,
    client
) -> str:
    """将失败的请求提取出来重新提交"""
    
    retry_requests = [
        req for req in original_requests
        if req["custom_id"] in failed_ids
    ]
    
    if not retry_requests:
        return None
    
    print(f"重新提交 {len(retry_requests)} 条失败请求...")
    retry_batch = client.messages.batches.create(requests=retry_requests)
    return retry_batch.id

10.7 Batch 生命周期管理

列出和取消 Batch

# 列出所有 batch
for batch in client.messages.batches.list():
    print(f"{batch.id}: {batch.processing_status} - "
          f"创建于 {batch.created_at}")

# 取消正在处理的 batch
batch = client.messages.batches.cancel(batch_id)
print(f"取消状态: {batch.processing_status}")  # 'canceling' 或 'ended'

Batch 过期

未完成的 batch 在 24 小时后过期,结果在完成后保留 29 天。应在此时间窗口内下载结果。

from datetime import datetime, timezone

def check_batch_expiry(batch) -> None:
    """检查 batch 是否即将过期"""
    now = datetime.now(timezone.utc)
    expires_at = batch.expires_at
    
    time_left = expires_at - now
    hours_left = time_left.total_seconds() / 3600
    
    if hours_left < 2:
        print(f"警告:Batch {batch.id} 将在 {hours_left:.1f} 小时内过期!")
    elif hours_left < 6:
        print(f"注意:Batch {batch.id} 还有 {hours_left:.1f} 小时到期")

10.8 与 Prompt Caching 结合使用

Batch API 与 Prompt Caching 可以叠加使用,进一步降低成本。当多个请求共享相同的系统提示时,缓存命中可以额外节省 90% 的 prompt 输入费用:

# 使用 prompt caching + batch API 的最优模式
SHARED_SYSTEM = "你是专业文档分析助手,负责提取文档中的关键信息。"

requests = [
    {
        "custom_id": f"doc-{i}",
        "params": {
            "model": "claude-haiku-4-5-20251001",
            "max_tokens": 200,
            "system": [
                {
                    "type": "text",
                    "text": SHARED_SYSTEM,
                    "cache_control": {"type": "ephemeral"}  # 启用缓存
                }
            ],
            "messages": [{"role": "user", "content": doc_text}]
        }
    }
    for i, doc_text in enumerate(documents)
]

注意:Prompt Caching 在 Batch 模式下按缓存读取价格(约标准输入价格的 10%)计算,进一步叠加 Batch 折扣。

10.9 完整生产示例:大规模产品描述生成

import anthropic
import json
import time
from pathlib import Path

def batch_generate_descriptions(
    products: list[dict],
    output_file: str = "descriptions.jsonl"
) -> None:
    """批量生成产品描述"""
    
    client = anthropic.Anthropic()
    
    # 构建请求
    requests = [
        {
            "custom_id": f"product-{p['id']}",
            "params": {
                "model": "claude-haiku-4-5-20251001",
                "max_tokens": 300,
                "system": (
                    "你是电商文案专家,用吸引人的中文写作风格"
                    "为产品生成50-100字的销售描述。"
                ),
                "messages": [{
                    "role": "user",
                    "content": (
                        f"产品名称:{p['name']}\n"
                        f"类别:{p['category']}\n"
                        f"特点:{', '.join(p['features'])}"
                    )
                }]
            }
        }
        for p in products
    ]
    
    # 分片提交
    batch_ids = []
    for i in range(0, len(requests), 9000):
        chunk = requests[i:i+9000]
        batch = client.messages.batches.create(requests=chunk)
        batch_ids.append(batch.id)
        print(f"提交 Batch: {batch.id} ({len(chunk)} 条请求)")
    
    # 等待完成
    pending = set(batch_ids)
    while pending:
        done = set()
        for bid in pending:
            b = client.messages.batches.retrieve(bid)
            if b.processing_status == "ended":
                done.add(bid)
        pending -= done
        if pending:
            print(f"等待 {len(pending)} 个 batch...")
            time.sleep(30)
    
    # 收集结果
    with open(output_file, "w", encoding="utf-8") as f:
        for bid in batch_ids:
            for result in client.messages.batches.results(bid):
                if result.result.type == "succeeded":
                    product_id = result.custom_id.replace("product-", "")
                    record = {
                        "id": product_id,
                        "description": result.result.message.content[0].text
                    }
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
    
    print(f"完成!结果写入 {output_file}")

小结

Batch API 是高成本效益大规模推理的首选方案。核心要点:

  1. 相比标准 API 节省 50% 费用,是数据处理、标注、评估任务的最优选择
  2. 每次 batch 最多 10,000 条请求,超过需分片提交
  3. custom_id 用于匹配请求与结果,务必保证唯一性
  4. 结果迭代器按流式方式读取,避免大 batch 内存溢出
  5. 结合 Prompt Caching 可进一步节省共享系统提示的开销
  6. 部分失败是正常情况,需要针对失败条目实现重试逻辑
本章评分
4.5  / 5  (50 评分)

💬 留言讨论