第 60 章

缓存策略:KV Cache 与 Prompt Cache

第60章:缓存策略:KV Cache 与 Prompt Cache

推理成本中最大的浪费,往往来自重复计算。当 Hermes Agent 每次都从头处理相同的系统提示时,它就像一个每次上班都要重新记住公司规章制度的员工。缓存,就是让 Agent 拥有"职业记忆"的技术。


60.1 KV Cache 工作原理

Transformer 的计算瓶颈

Transformer 架构的核心是注意力机制(Attention),其计算复杂度为 O(n²)。在推理阶段,每生成一个新 token,都需要对之前所有 token 进行注意力计算:

Query (当前 token) × Key (所有历史 token) → 注意力权重
注意力权重 × Value (所有历史 token) → 当前 token 的输出

问题:历史 token 的 Key 和 Value 矩阵在每次生成新 token 时都需要重新计算,即使输入前缀完全相同。

KV Cache 解决方案:将已计算的 Key-Value 矩阵缓存起来,后续 token 生成时直接复用。

不使用 KV Cache:
  Token 1: 计算 KV → [KV₁]
  Token 2: 计算 KV → [KV₁, KV₂]  (KV₁ 重复计算)
  Token N: 计算 KV → [KV₁, ..., KVₙ]  (前N-1个KV重复计算)
  总计算量:O(n²)

使用 KV Cache:
  Token 1: 计算 KV → 缓存 [KV₁]
  Token 2: 从缓存读取 [KV₁],只计算 KV₂ → 缓存 [KV₁, KV₂]
  Token N: 从缓存读取,只计算 KVₙ
  总计算量:O(n)

KV Cache 的显存占用

KV Cache 需要大量显存,这是限制批处理大小和上下文长度的关键因素:

KV Cache 大小 = 2 × num_layers × num_heads × head_dim × seq_len × batch_size × dtype_bytes

以 Hermes 3 70B(BF16)为例:
  layers = 80
  heads = 64  
  head_dim = 128
  seq_len = 4096
  batch_size = 8
  dtype = BF16 (2 bytes)
  
  KV Cache = 2 × 80 × 64 × 128 × 4096 × 8 × 2 bytes
           ≈ 107 GB

这解释了为什么大批量推理需要多卡并行——仅 KV Cache 就可能超过单卡显存。


60.2 vLLM 的 Prefix Caching 配置

什么是 Prefix Caching

Prefix Caching(前缀缓存)是在 KV Cache 基础上的进一步优化:当多个请求共享相同的前缀(如系统提示),该前缀的 KV Cache 只计算一次,后续请求直接复用。

请求1: [系统提示 | 用户问题A]
请求2: [系统提示 | 用户问题B]  → 系统提示部分直接复用缓存
请求3: [系统提示 | 用户问题C]  → 系统提示部分直接复用缓存

vLLM Prefix Caching 配置

from vllm import LLM, SamplingParams

# 启用 Prefix Caching 的 vLLM 配置
llm = LLM(
    model="NousResearch/Hermes-3-Llama-3.1-70B",
    tensor_parallel_size=4,      # 4卡并行
    gpu_memory_utilization=0.85,  # GPU 显存使用率
    enable_prefix_caching=True,   # 启用前缀缓存(关键配置)
    max_model_len=128000,         # 最大上下文长度
    max_num_seqs=256,             # 最大并发序列数
    block_size=16,                # KV Cache 块大小(token数)
    swap_space=4,                  # CPU swap 空间(GB)
)

# 定义系统提示(所有请求共享)
SYSTEM_PROMPT = """You are Hermes, an autonomous AI agent by NousResearch.
[...系统提示内容...]"""

sampling_params = SamplingParams(
    temperature=0.7,
    max_tokens=2048,
    stop=["<|im_end|>"]
)

def batch_inference(user_questions: list[str]) -> list[str]:
    """
    批量推理,系统提示部分会被自动缓存
    第一个请求计算系统提示 KV,后续请求直接复用
    """
    prompts = [
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{q}<|im_end|>\n"
        f"<|im_start|>assistant\n"
        for q in user_questions
    ]
    
    outputs = llm.generate(prompts, sampling_params)
    return [output.outputs[0].text for output in outputs]

Prefix Caching 命中率监控

import requests
import json

def get_vllm_cache_stats(api_url: str = "http://localhost:8000") -> dict:
    """
    获取 vLLM 的缓存命中率统计
    需要 vLLM 的 metrics API 端点
    """
    response = requests.get(f"{api_url}/metrics")
    
    # vLLM 通过 Prometheus 格式导出指标
    metrics_text = response.text
    
    stats = {}
    for line in metrics_text.split('\n'):
        if 'prefix_cache' in line.lower() and not line.startswith('#'):
            parts = line.split(' ')
            if len(parts) >= 2:
                metric_name = parts[0].split('{')[0]
                value = float(parts[-1])
                stats[metric_name] = value
    
    # 计算命中率
    hits = stats.get('vllm_cache_hits_total', 0)
    queries = stats.get('vllm_cache_queries_total', 1)
    
    return {
        'cache_hits': int(hits),
        'cache_queries': int(queries),
        'hit_rate': f"{hits/queries*100:.2f}%",
        'raw_stats': stats
    }

# 定期监控
import time
def monitor_cache_performance(interval: int = 60):
    while True:
        stats = get_vllm_cache_stats()
        print(f"[{time.strftime('%H:%M:%S')}] Cache Hit Rate: {stats['hit_rate']}")
        time.sleep(interval)

60.3 Anthropic Prompt Cache 在 Hermes 中的使用

Prompt Cache 工作原理

Anthropic 的 Prompt Cache(2024年推出)允许将对话的特定部分标记为可缓存,服务端缓存这些内容的计算结果,后续请求可以复用。

定价说明

在 Hermes 混合架构中使用 Prompt Cache

import anthropic
from typing import Optional

class HermesWithPromptCache:
    """
    在 Anthropic API 上运行的 Hermes Agent(带 Prompt Cache)
    注:当 Hermes 4 通过 Anthropic 兼容接口部署时可用
    """
    
    def __init__(self, api_key: str):
        self.client = anthropic.Anthropic(api_key=api_key)
        self.model = "claude-3-5-sonnet-20241022"  # 或兼容 Hermes 的端点
        
        # 构建可缓存的系统提示
        self.system_prompt_blocks = [
            {
                "type": "text",
                "text": self._get_core_system_prompt(),
                "cache_control": {"type": "ephemeral"}  # 标记为可缓存
            }
        ]
    
    def _get_core_system_prompt(self) -> str:
        return """You are Hermes, an autonomous AI agent developed by NousResearch.

## Core Capabilities
- Complex reasoning and multi-step planning
- Tool use and code execution
- Research and information synthesis
- Structured output generation

## Operational Rules
①  Always verify tool outputs before proceeding
②  Ask for clarification when task requirements are ambiguous
③  Log significant decisions in structured format
④  Prioritize accuracy over speed
⑤  Respect user privacy and data security

## Tool Usage Protocol
- Call only registered tools
- Parse all arguments strictly
- Handle errors gracefully with retry logic
- Report tool failures immediately

## Output Format
- Use markdown for reports and documentation
- Use JSON for structured data
- Use plain text for conversational responses
"""
    
    def chat(
        self, 
        user_message: str,
        conversation_history: Optional[list] = None,
        task_context: Optional[str] = None
    ) -> dict:
        """
        发送消息,系统提示自动走缓存路径
        
        Returns:
            包含响应内容和缓存统计的字典
        """
        messages = []
        
        # 添加对话历史
        if conversation_history:
            messages.extend(conversation_history)
        
        # 如果有任务上下文,作为可缓存块注入
        if task_context:
            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"Task Context:\n{task_context}",
                        "cache_control": {"type": "ephemeral"}
                    },
                    {
                        "type": "text",
                        "text": user_message
                    }
                ]
            })
        else:
            messages.append({"role": "user", "content": user_message})
        
        response = self.client.messages.create(
            model=self.model,
            max_tokens=4096,
            system=self.system_prompt_blocks,
            messages=messages
        )
        
        # 提取缓存统计
        usage = response.usage
        cache_stats = {
            "input_tokens": usage.input_tokens,
            "output_tokens": usage.output_tokens,
            "cache_creation_input_tokens": getattr(usage, 'cache_creation_input_tokens', 0),
            "cache_read_input_tokens": getattr(usage, 'cache_read_input_tokens', 0),
        }
        
        # 计算实际成本
        cache_stats["cost_estimate"] = self._calculate_cost(cache_stats)
        cache_stats["savings_vs_no_cache"] = self._calculate_savings(cache_stats)
        
        return {
            "content": response.content[0].text,
            "cache_stats": cache_stats
        }
    
    def _calculate_cost(self, usage: dict) -> float:
        """估算本次请求实际成本(USD)"""
        INPUT_PRICE = 3.00 / 1_000_000   # $3/M tokens
        OUTPUT_PRICE = 15.00 / 1_000_000  # $15/M tokens
        CACHE_WRITE_PRICE = INPUT_PRICE * 1.25
        CACHE_READ_PRICE = INPUT_PRICE * 0.10
        
        cost = (
            usage["input_tokens"] * INPUT_PRICE +
            usage["output_tokens"] * OUTPUT_PRICE +
            usage["cache_creation_input_tokens"] * CACHE_WRITE_PRICE +
            usage["cache_read_input_tokens"] * CACHE_READ_PRICE
        )
        return round(cost, 6)
    
    def _calculate_savings(self, usage: dict) -> float:
        """计算与不使用缓存相比节省的费用"""
        INPUT_PRICE = 3.00 / 1_000_000
        cache_read = usage["cache_read_input_tokens"]
        # 没有缓存时这些 token 按标准价格计费
        savings = cache_read * INPUT_PRICE * 0.90  # 节省 90%
        return round(savings, 6)


# 实战示例:跨多个请求复用系统提示缓存
agent = HermesWithPromptCache(api_key="your_api_key")

# 第一次请求:缓存写入(略贵)
result1 = agent.chat("分析 Hermes Agent 的技术架构")
print(f"第1次调用: 缓存写入 {result1['cache_stats']['cache_creation_input_tokens']} tokens")
print(f"实际成本: ${result1['cache_stats']['cost_estimate']}")

# 第二次请求:缓存命中(省 90%)
result2 = agent.chat("列出 Hermes Agent 支持的工具类型")
print(f"第2次调用: 缓存命中 {result2['cache_stats']['cache_read_input_tokens']} tokens")
print(f"节省费用: ${result2['cache_stats']['savings_vs_no_cache']}")

60.4 Redis 层语义缓存

语义缓存 vs KV 精确匹配缓存

传统缓存要求请求完全匹配才能命中。语义缓存(Semantic Cache)通过向量相似度,允许相似但不完全相同的请求复用缓存结果:

精确匹配缓存:
  缓存:"如何安装 Python 3.11?"
  查询:"如何安装 Python 3.11?" → ✅ 命中
  查询:"Python 3.11 怎么安装?"   → ❌ 未命中(文字不同)

语义缓存:
  缓存:"如何安装 Python 3.11?"
  查询:"Python 3.11 怎么安装?"   → ✅ 命中(语义相似)
  查询:"安装 Python 的步骤是什么?" → ✅ 命中(相似度超过阈值)

实现方案

import redis
import numpy as np
import json
import hashlib
from typing import Optional, Tuple
import time

class SemanticCache:
    """
    基于 Redis 的语义缓存
    使用向量相似度匹配相似请求
    """
    
    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        similarity_threshold: float = 0.92,
        ttl_seconds: int = 3600,  # 缓存有效期 1 小时
        embedding_model: str = "text-embedding-3-small"
    ):
        self.redis_client = redis.from_url(redis_url)
        self.similarity_threshold = similarity_threshold
        self.ttl_seconds = ttl_seconds
        self.embedding_model = embedding_model
        
        # 统计计数器
        self.hits = 0
        self.misses = 0
    
    def get_embedding(self, text: str) -> list[float]:
        """
        获取文本的向量嵌入
        实际使用时替换为真实的 embedding API 调用
        """
        # 示例:使用 OpenAI embedding API
        import openai
        response = openai.embeddings.create(
            model=self.embedding_model,
            input=text
        )
        return response.data[0].embedding
    
    def cosine_similarity(self, vec1: list, vec2: list) -> float:
        """计算余弦相似度"""
        v1 = np.array(vec1)
        v2 = np.array(vec2)
        return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))
    
    def _get_cache_key(self, query: str) -> str:
        """基于查询生成精确匹配缓存键"""
        return f"exact:{hashlib.md5(query.encode()).hexdigest()}"
    
    def get(self, query: str) -> Optional[Tuple[str, float]]:
        """
        查找缓存
        
        Returns:
            (缓存结果, 相似度分数) 如果命中
            None 如果未命中
        """
        # 先尝试精确匹配
        exact_key = self._get_cache_key(query)
        exact_result = self.redis_client.get(exact_key)
        if exact_result:
            self.hits += 1
            return json.loads(exact_result)["response"], 1.0
        
        # 语义相似度匹配
        query_embedding = self.get_embedding(query)
        
        # 获取所有缓存的向量键
        vector_keys = self.redis_client.keys("vec:*")
        
        best_match = None
        best_similarity = 0.0
        
        for key in vector_keys:
            cached_data = json.loads(self.redis_client.get(key))
            cached_embedding = cached_data["embedding"]
            
            similarity = self.cosine_similarity(query_embedding, cached_embedding)
            
            if similarity > best_similarity:
                best_similarity = similarity
                best_match = cached_data
        
        if best_similarity >= self.similarity_threshold and best_match:
            self.hits += 1
            return best_match["response"], best_similarity
        
        self.misses += 1
        return None
    
    def set(self, query: str, response: str, metadata: dict = None):
        """
        存储查询-响应对到缓存
        同时存储精确键和向量键
        """
        embedding = self.get_embedding(query)
        
        cache_entry = {
            "query": query,
            "response": response,
            "embedding": embedding,
            "timestamp": time.time(),
            "metadata": metadata or {}
        }
        
        # 存储向量数据(用于语义搜索)
        vec_key = f"vec:{hashlib.md5(query.encode()).hexdigest()}"
        self.redis_client.setex(
            vec_key,
            self.ttl_seconds,
            json.dumps(cache_entry)
        )
        
        # 存储精确匹配键
        exact_key = self._get_cache_key(query)
        self.redis_client.setex(
            exact_key,
            self.ttl_seconds,
            json.dumps({"response": response})
        )
    
    def get_stats(self) -> dict:
        """获取缓存统计信息"""
        total = self.hits + self.misses
        return {
            "total_queries": total,
            "cache_hits": self.hits,
            "cache_misses": self.misses,
            "hit_rate": f"{self.hits/total*100:.1f}%" if total > 0 else "0%",
            "cached_entries": len(self.redis_client.keys("vec:*"))
        }
    
    def invalidate_by_pattern(self, pattern: str):
        """按模式失效缓存(如知识更新时)"""
        keys = self.redis_client.keys(f"*{pattern}*")
        if keys:
            self.redis_client.delete(*keys)
            return len(keys)
        return 0


# 集成到 Hermes Agent
class CachedHermesAgent:
    def __init__(self, llm_client, cache: SemanticCache):
        self.llm = llm_client
        self.cache = cache
    
    async def execute(self, query: str, use_cache: bool = True) -> str:
        if use_cache:
            cached = self.cache.get(query)
            if cached:
                result, similarity = cached
                print(f"缓存命中!相似度: {similarity:.3f}")
                return result
        
        # 缓存未命中,调用 LLM
        response = await self.llm.complete(query)
        
        # 存储到缓存
        if use_cache:
            self.cache.set(query, response)
        
        return response

60.5 缓存失效策略

失效策略对比

策略 触发条件 适用场景 优点 缺点
TTL(时间过期) 固定时间后过期 新闻、市场数据 简单,自动 可能在有效期内提前过时
LRU(最近最少使用) 缓存满时淘汰最久未用 通用场景 保留热点数据 不考虑数据时效性
事件驱动 知识库更新时触发 产品文档、规则库 精确控制 需要外部触发器
版本标签 模型或知识版本变化 模型升级场景 强一致性 实现复杂
相似度衰减 超过阈值时降低命中率 语义缓存 渐进式更新 需要定期重计算
class SmartCacheInvalidator:
    """
    智能缓存失效管理器
    支持多种失效策略组合
    """
    
    def __init__(self, cache: SemanticCache):
        self.cache = cache
        self.version_tag = "v1.0"
        self.invalidation_rules = []
    
    def add_time_based_rule(self, ttl_seconds: int):
        """添加基于时间的失效规则"""
        self.invalidation_rules.append({
            "type": "ttl",
            "ttl": ttl_seconds
        })
    
    def add_version_based_rule(self, current_version: str):
        """添加基于版本的失效规则"""
        if current_version != self.version_tag:
            # 版本变化时清空所有缓存
            self.cache.redis_client.flushdb()
            self.version_tag = current_version
            print(f"版本升级到 {current_version},缓存已清空")
    
    def add_keyword_invalidation(self, keywords: list[str]):
        """当知识库更新时,失效包含特定关键词的缓存"""
        for keyword in keywords:
            count = self.cache.invalidate_by_pattern(keyword)
            if count > 0:
                print(f"关键词 '{keyword}' 相关的 {count} 条缓存已失效")
    
    def warm_up_cache(self, common_queries: list[str], llm_client):
        """
        缓存预热:提前计算常见查询的结果
        在服务启动时调用,减少冷启动延迟
        """
        import asyncio
        
        async def warm_query(query):
            cached = self.cache.get(query)
            if not cached:
                response = await llm_client.complete(query)
                self.cache.set(query, response)
                print(f"预热: {query[:50]}...")
        
        asyncio.run(asyncio.gather(*[warm_query(q) for q in common_queries]))
        print(f"缓存预热完成,已预热 {len(common_queries)} 条查询")

本章小结

缓存是 Hermes Agent 性能优化的核心手段,分四层实施:

  1. KV Cache:框架层自动管理,消除单请求内的重复计算,将推理复杂度从 O(n²) 降至 O(n)
  2. Prefix Caching(vLLM):服务层优化,当多个请求共享系统提示时,KV 只计算一次,命中率超过 80% 时可节省约 40% 的推理计算量
  3. Prompt Cache(Anthropic):API 层优化,缓存命中时输入 Token 费用降低 90%,适合系统提示超过 1000 tokens 的场景
  4. 语义缓存(Redis):应用层优化,相似请求复用结果,对重复性高的任务(FAQ、固定报告模板)命中率可达 60-80%

思考题

  1. 在多租户场景下,不同用户的系统提示各不相同,如何设计 Prefix Cache 使不同租户之间共享公共部分?
  2. 当语义缓存命中率超过 90% 时,是否意味着任务过于同质化?如何平衡缓存效率和任务多样性?
  3. 对于包含实时数据(如股价、天气)的查询,如何设计"部分缓存"策略——缓存静态推理部分,动态注入实时数据?
  4. 在 A/B 测试场景下,缓存可能导致不同用户看到旧模型的结果,如何设计缓存隔离机制?
本章评分
4.6  / 5  (3 评分)

💬 留言讨论