第 60 章
缓存策略:KV Cache 与 Prompt Cache
第60章:缓存策略:KV Cache 与 Prompt Cache
推理成本中最大的浪费,往往来自重复计算。当 Hermes Agent 每次都从头处理相同的系统提示时,它就像一个每次上班都要重新记住公司规章制度的员工。缓存,就是让 Agent 拥有"职业记忆"的技术。
60.1 KV Cache 工作原理
Transformer 的计算瓶颈
Transformer 架构的核心是注意力机制(Attention),其计算复杂度为 O(n²)。在推理阶段,每生成一个新 token,都需要对之前所有 token 进行注意力计算:
Query (当前 token) × Key (所有历史 token) → 注意力权重
注意力权重 × Value (所有历史 token) → 当前 token 的输出
问题:历史 token 的 Key 和 Value 矩阵在每次生成新 token 时都需要重新计算,即使输入前缀完全相同。
KV Cache 解决方案:将已计算的 Key-Value 矩阵缓存起来,后续 token 生成时直接复用。
不使用 KV Cache:
Token 1: 计算 KV → [KV₁]
Token 2: 计算 KV → [KV₁, KV₂] (KV₁ 重复计算)
Token N: 计算 KV → [KV₁, ..., KVₙ] (前N-1个KV重复计算)
总计算量:O(n²)
使用 KV Cache:
Token 1: 计算 KV → 缓存 [KV₁]
Token 2: 从缓存读取 [KV₁],只计算 KV₂ → 缓存 [KV₁, KV₂]
Token N: 从缓存读取,只计算 KVₙ
总计算量:O(n)
KV Cache 的显存占用
KV Cache 需要大量显存,这是限制批处理大小和上下文长度的关键因素:
KV Cache 大小 = 2 × num_layers × num_heads × head_dim × seq_len × batch_size × dtype_bytes
以 Hermes 3 70B(BF16)为例:
layers = 80
heads = 64
head_dim = 128
seq_len = 4096
batch_size = 8
dtype = BF16 (2 bytes)
KV Cache = 2 × 80 × 64 × 128 × 4096 × 8 × 2 bytes
≈ 107 GB
这解释了为什么大批量推理需要多卡并行——仅 KV Cache 就可能超过单卡显存。
60.2 vLLM 的 Prefix Caching 配置
什么是 Prefix Caching
Prefix Caching(前缀缓存)是在 KV Cache 基础上的进一步优化:当多个请求共享相同的前缀(如系统提示),该前缀的 KV Cache 只计算一次,后续请求直接复用。
请求1: [系统提示 | 用户问题A]
请求2: [系统提示 | 用户问题B] → 系统提示部分直接复用缓存
请求3: [系统提示 | 用户问题C] → 系统提示部分直接复用缓存
vLLM Prefix Caching 配置
from vllm import LLM, SamplingParams
# 启用 Prefix Caching 的 vLLM 配置
llm = LLM(
model="NousResearch/Hermes-3-Llama-3.1-70B",
tensor_parallel_size=4, # 4卡并行
gpu_memory_utilization=0.85, # GPU 显存使用率
enable_prefix_caching=True, # 启用前缀缓存(关键配置)
max_model_len=128000, # 最大上下文长度
max_num_seqs=256, # 最大并发序列数
block_size=16, # KV Cache 块大小(token数)
swap_space=4, # CPU swap 空间(GB)
)
# 定义系统提示(所有请求共享)
SYSTEM_PROMPT = """You are Hermes, an autonomous AI agent by NousResearch.
[...系统提示内容...]"""
sampling_params = SamplingParams(
temperature=0.7,
max_tokens=2048,
stop=["<|im_end|>"]
)
def batch_inference(user_questions: list[str]) -> list[str]:
"""
批量推理,系统提示部分会被自动缓存
第一个请求计算系统提示 KV,后续请求直接复用
"""
prompts = [
f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
f"<|im_start|>user\n{q}<|im_end|>\n"
f"<|im_start|>assistant\n"
for q in user_questions
]
outputs = llm.generate(prompts, sampling_params)
return [output.outputs[0].text for output in outputs]
Prefix Caching 命中率监控
import requests
import json
def get_vllm_cache_stats(api_url: str = "http://localhost:8000") -> dict:
"""
获取 vLLM 的缓存命中率统计
需要 vLLM 的 metrics API 端点
"""
response = requests.get(f"{api_url}/metrics")
# vLLM 通过 Prometheus 格式导出指标
metrics_text = response.text
stats = {}
for line in metrics_text.split('\n'):
if 'prefix_cache' in line.lower() and not line.startswith('#'):
parts = line.split(' ')
if len(parts) >= 2:
metric_name = parts[0].split('{')[0]
value = float(parts[-1])
stats[metric_name] = value
# 计算命中率
hits = stats.get('vllm_cache_hits_total', 0)
queries = stats.get('vllm_cache_queries_total', 1)
return {
'cache_hits': int(hits),
'cache_queries': int(queries),
'hit_rate': f"{hits/queries*100:.2f}%",
'raw_stats': stats
}
# 定期监控
import time
def monitor_cache_performance(interval: int = 60):
while True:
stats = get_vllm_cache_stats()
print(f"[{time.strftime('%H:%M:%S')}] Cache Hit Rate: {stats['hit_rate']}")
time.sleep(interval)
60.3 Anthropic Prompt Cache 在 Hermes 中的使用
Prompt Cache 工作原理
Anthropic 的 Prompt Cache(2024年推出)允许将对话的特定部分标记为可缓存,服务端缓存这些内容的计算结果,后续请求可以复用。
定价说明:
- 缓存写入:标准输入价格 × 1.25(写入时略贵)
- 缓存命中:标准输入价格 × 0.10(读取时省 90%)
在 Hermes 混合架构中使用 Prompt Cache
import anthropic
from typing import Optional
class HermesWithPromptCache:
"""
在 Anthropic API 上运行的 Hermes Agent(带 Prompt Cache)
注:当 Hermes 4 通过 Anthropic 兼容接口部署时可用
"""
def __init__(self, api_key: str):
self.client = anthropic.Anthropic(api_key=api_key)
self.model = "claude-3-5-sonnet-20241022" # 或兼容 Hermes 的端点
# 构建可缓存的系统提示
self.system_prompt_blocks = [
{
"type": "text",
"text": self._get_core_system_prompt(),
"cache_control": {"type": "ephemeral"} # 标记为可缓存
}
]
def _get_core_system_prompt(self) -> str:
return """You are Hermes, an autonomous AI agent developed by NousResearch.
## Core Capabilities
- Complex reasoning and multi-step planning
- Tool use and code execution
- Research and information synthesis
- Structured output generation
## Operational Rules
① Always verify tool outputs before proceeding
② Ask for clarification when task requirements are ambiguous
③ Log significant decisions in structured format
④ Prioritize accuracy over speed
⑤ Respect user privacy and data security
## Tool Usage Protocol
- Call only registered tools
- Parse all arguments strictly
- Handle errors gracefully with retry logic
- Report tool failures immediately
## Output Format
- Use markdown for reports and documentation
- Use JSON for structured data
- Use plain text for conversational responses
"""
def chat(
self,
user_message: str,
conversation_history: Optional[list] = None,
task_context: Optional[str] = None
) -> dict:
"""
发送消息,系统提示自动走缓存路径
Returns:
包含响应内容和缓存统计的字典
"""
messages = []
# 添加对话历史
if conversation_history:
messages.extend(conversation_history)
# 如果有任务上下文,作为可缓存块注入
if task_context:
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": f"Task Context:\n{task_context}",
"cache_control": {"type": "ephemeral"}
},
{
"type": "text",
"text": user_message
}
]
})
else:
messages.append({"role": "user", "content": user_message})
response = self.client.messages.create(
model=self.model,
max_tokens=4096,
system=self.system_prompt_blocks,
messages=messages
)
# 提取缓存统计
usage = response.usage
cache_stats = {
"input_tokens": usage.input_tokens,
"output_tokens": usage.output_tokens,
"cache_creation_input_tokens": getattr(usage, 'cache_creation_input_tokens', 0),
"cache_read_input_tokens": getattr(usage, 'cache_read_input_tokens', 0),
}
# 计算实际成本
cache_stats["cost_estimate"] = self._calculate_cost(cache_stats)
cache_stats["savings_vs_no_cache"] = self._calculate_savings(cache_stats)
return {
"content": response.content[0].text,
"cache_stats": cache_stats
}
def _calculate_cost(self, usage: dict) -> float:
"""估算本次请求实际成本(USD)"""
INPUT_PRICE = 3.00 / 1_000_000 # $3/M tokens
OUTPUT_PRICE = 15.00 / 1_000_000 # $15/M tokens
CACHE_WRITE_PRICE = INPUT_PRICE * 1.25
CACHE_READ_PRICE = INPUT_PRICE * 0.10
cost = (
usage["input_tokens"] * INPUT_PRICE +
usage["output_tokens"] * OUTPUT_PRICE +
usage["cache_creation_input_tokens"] * CACHE_WRITE_PRICE +
usage["cache_read_input_tokens"] * CACHE_READ_PRICE
)
return round(cost, 6)
def _calculate_savings(self, usage: dict) -> float:
"""计算与不使用缓存相比节省的费用"""
INPUT_PRICE = 3.00 / 1_000_000
cache_read = usage["cache_read_input_tokens"]
# 没有缓存时这些 token 按标准价格计费
savings = cache_read * INPUT_PRICE * 0.90 # 节省 90%
return round(savings, 6)
# 实战示例:跨多个请求复用系统提示缓存
agent = HermesWithPromptCache(api_key="your_api_key")
# 第一次请求:缓存写入(略贵)
result1 = agent.chat("分析 Hermes Agent 的技术架构")
print(f"第1次调用: 缓存写入 {result1['cache_stats']['cache_creation_input_tokens']} tokens")
print(f"实际成本: ${result1['cache_stats']['cost_estimate']}")
# 第二次请求:缓存命中(省 90%)
result2 = agent.chat("列出 Hermes Agent 支持的工具类型")
print(f"第2次调用: 缓存命中 {result2['cache_stats']['cache_read_input_tokens']} tokens")
print(f"节省费用: ${result2['cache_stats']['savings_vs_no_cache']}")
60.4 Redis 层语义缓存
语义缓存 vs KV 精确匹配缓存
传统缓存要求请求完全匹配才能命中。语义缓存(Semantic Cache)通过向量相似度,允许相似但不完全相同的请求复用缓存结果:
精确匹配缓存:
缓存:"如何安装 Python 3.11?"
查询:"如何安装 Python 3.11?" → ✅ 命中
查询:"Python 3.11 怎么安装?" → ❌ 未命中(文字不同)
语义缓存:
缓存:"如何安装 Python 3.11?"
查询:"Python 3.11 怎么安装?" → ✅ 命中(语义相似)
查询:"安装 Python 的步骤是什么?" → ✅ 命中(相似度超过阈值)
实现方案
import redis
import numpy as np
import json
import hashlib
from typing import Optional, Tuple
import time
class SemanticCache:
"""
基于 Redis 的语义缓存
使用向量相似度匹配相似请求
"""
def __init__(
self,
redis_url: str = "redis://localhost:6379",
similarity_threshold: float = 0.92,
ttl_seconds: int = 3600, # 缓存有效期 1 小时
embedding_model: str = "text-embedding-3-small"
):
self.redis_client = redis.from_url(redis_url)
self.similarity_threshold = similarity_threshold
self.ttl_seconds = ttl_seconds
self.embedding_model = embedding_model
# 统计计数器
self.hits = 0
self.misses = 0
def get_embedding(self, text: str) -> list[float]:
"""
获取文本的向量嵌入
实际使用时替换为真实的 embedding API 调用
"""
# 示例:使用 OpenAI embedding API
import openai
response = openai.embeddings.create(
model=self.embedding_model,
input=text
)
return response.data[0].embedding
def cosine_similarity(self, vec1: list, vec2: list) -> float:
"""计算余弦相似度"""
v1 = np.array(vec1)
v2 = np.array(vec2)
return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))
def _get_cache_key(self, query: str) -> str:
"""基于查询生成精确匹配缓存键"""
return f"exact:{hashlib.md5(query.encode()).hexdigest()}"
def get(self, query: str) -> Optional[Tuple[str, float]]:
"""
查找缓存
Returns:
(缓存结果, 相似度分数) 如果命中
None 如果未命中
"""
# 先尝试精确匹配
exact_key = self._get_cache_key(query)
exact_result = self.redis_client.get(exact_key)
if exact_result:
self.hits += 1
return json.loads(exact_result)["response"], 1.0
# 语义相似度匹配
query_embedding = self.get_embedding(query)
# 获取所有缓存的向量键
vector_keys = self.redis_client.keys("vec:*")
best_match = None
best_similarity = 0.0
for key in vector_keys:
cached_data = json.loads(self.redis_client.get(key))
cached_embedding = cached_data["embedding"]
similarity = self.cosine_similarity(query_embedding, cached_embedding)
if similarity > best_similarity:
best_similarity = similarity
best_match = cached_data
if best_similarity >= self.similarity_threshold and best_match:
self.hits += 1
return best_match["response"], best_similarity
self.misses += 1
return None
def set(self, query: str, response: str, metadata: dict = None):
"""
存储查询-响应对到缓存
同时存储精确键和向量键
"""
embedding = self.get_embedding(query)
cache_entry = {
"query": query,
"response": response,
"embedding": embedding,
"timestamp": time.time(),
"metadata": metadata or {}
}
# 存储向量数据(用于语义搜索)
vec_key = f"vec:{hashlib.md5(query.encode()).hexdigest()}"
self.redis_client.setex(
vec_key,
self.ttl_seconds,
json.dumps(cache_entry)
)
# 存储精确匹配键
exact_key = self._get_cache_key(query)
self.redis_client.setex(
exact_key,
self.ttl_seconds,
json.dumps({"response": response})
)
def get_stats(self) -> dict:
"""获取缓存统计信息"""
total = self.hits + self.misses
return {
"total_queries": total,
"cache_hits": self.hits,
"cache_misses": self.misses,
"hit_rate": f"{self.hits/total*100:.1f}%" if total > 0 else "0%",
"cached_entries": len(self.redis_client.keys("vec:*"))
}
def invalidate_by_pattern(self, pattern: str):
"""按模式失效缓存(如知识更新时)"""
keys = self.redis_client.keys(f"*{pattern}*")
if keys:
self.redis_client.delete(*keys)
return len(keys)
return 0
# 集成到 Hermes Agent
class CachedHermesAgent:
def __init__(self, llm_client, cache: SemanticCache):
self.llm = llm_client
self.cache = cache
async def execute(self, query: str, use_cache: bool = True) -> str:
if use_cache:
cached = self.cache.get(query)
if cached:
result, similarity = cached
print(f"缓存命中!相似度: {similarity:.3f}")
return result
# 缓存未命中,调用 LLM
response = await self.llm.complete(query)
# 存储到缓存
if use_cache:
self.cache.set(query, response)
return response
60.5 缓存失效策略
失效策略对比
| 策略 | 触发条件 | 适用场景 | 优点 | 缺点 |
|---|---|---|---|---|
| TTL(时间过期) | 固定时间后过期 | 新闻、市场数据 | 简单,自动 | 可能在有效期内提前过时 |
| LRU(最近最少使用) | 缓存满时淘汰最久未用 | 通用场景 | 保留热点数据 | 不考虑数据时效性 |
| 事件驱动 | 知识库更新时触发 | 产品文档、规则库 | 精确控制 | 需要外部触发器 |
| 版本标签 | 模型或知识版本变化 | 模型升级场景 | 强一致性 | 实现复杂 |
| 相似度衰减 | 超过阈值时降低命中率 | 语义缓存 | 渐进式更新 | 需要定期重计算 |
class SmartCacheInvalidator:
"""
智能缓存失效管理器
支持多种失效策略组合
"""
def __init__(self, cache: SemanticCache):
self.cache = cache
self.version_tag = "v1.0"
self.invalidation_rules = []
def add_time_based_rule(self, ttl_seconds: int):
"""添加基于时间的失效规则"""
self.invalidation_rules.append({
"type": "ttl",
"ttl": ttl_seconds
})
def add_version_based_rule(self, current_version: str):
"""添加基于版本的失效规则"""
if current_version != self.version_tag:
# 版本变化时清空所有缓存
self.cache.redis_client.flushdb()
self.version_tag = current_version
print(f"版本升级到 {current_version},缓存已清空")
def add_keyword_invalidation(self, keywords: list[str]):
"""当知识库更新时,失效包含特定关键词的缓存"""
for keyword in keywords:
count = self.cache.invalidate_by_pattern(keyword)
if count > 0:
print(f"关键词 '{keyword}' 相关的 {count} 条缓存已失效")
def warm_up_cache(self, common_queries: list[str], llm_client):
"""
缓存预热:提前计算常见查询的结果
在服务启动时调用,减少冷启动延迟
"""
import asyncio
async def warm_query(query):
cached = self.cache.get(query)
if not cached:
response = await llm_client.complete(query)
self.cache.set(query, response)
print(f"预热: {query[:50]}...")
asyncio.run(asyncio.gather(*[warm_query(q) for q in common_queries]))
print(f"缓存预热完成,已预热 {len(common_queries)} 条查询")
本章小结
缓存是 Hermes Agent 性能优化的核心手段,分四层实施:
- KV Cache:框架层自动管理,消除单请求内的重复计算,将推理复杂度从 O(n²) 降至 O(n)
- Prefix Caching(vLLM):服务层优化,当多个请求共享系统提示时,KV 只计算一次,命中率超过 80% 时可节省约 40% 的推理计算量
- Prompt Cache(Anthropic):API 层优化,缓存命中时输入 Token 费用降低 90%,适合系统提示超过 1000 tokens 的场景
- 语义缓存(Redis):应用层优化,相似请求复用结果,对重复性高的任务(FAQ、固定报告模板)命中率可达 60-80%
思考题
- 在多租户场景下,不同用户的系统提示各不相同,如何设计 Prefix Cache 使不同租户之间共享公共部分?
- 当语义缓存命中率超过 90% 时,是否意味着任务过于同质化?如何平衡缓存效率和任务多样性?
- 对于包含实时数据(如股价、天气)的查询,如何设计"部分缓存"策略——缓存静态推理部分,动态注入实时数据?
- 在 A/B 测试场景下,缓存可能导致不同用户看到旧模型的结果,如何设计缓存隔离机制?