Chapter 60

Caching Strategies: KV Cache and Prompt Cache


A large share of wasted inference cost comes from repeated computation. When Hermes Agent reprocesses the same system prompt from scratch every time, it's like an employee who has to re-read the company handbook at the start of every shift. Caching is the technology that gives an Agent "professional memory."


60.1 How KV Cache Works

The Transformer Computation Bottleneck

The core of the Transformer architecture is the attention mechanism, whose cost grows quadratically (O(n²)) with sequence length n. During autoregressive inference, each new token attends over all previous tokens:

Query (current token) × Keyᵀ (all prior tokens), scaled by 1/√d (d = head dimension) and softmax-normalized → attention weights
Attention weights × Value (all prior tokens) → current token output

Problem: A naive implementation recomputes the Key and Value vectors of every earlier token at each generation step, even though the prefix has not changed and those vectors come out identical every time.

KV Cache solution: Cache the computed Key-Value matrices so subsequent token generation can reuse them directly.

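Without KV Cache:
  Token 1: compute KV → [KV₁]
  Token 2: compute KV → [KV₁, KV₂]  (KV₁ recomputed)
  Token n: compute KV → [KV₁, ..., KVₙ]  (all prior KV recomputed)
  Total K/V computations over n tokens: 1 + 2 + ... + n = O(n²)

With KV Cache:
  Token 1: compute KV → cache [KV₁]
  Token 2: read [KV₁] from cache, compute only KV₂ → cache [KV₁, KV₂]
  Token n: read [KV₁, ..., KVₙ₋₁] from cache, compute only KVₙ
  Total K/V computations over n tokens: O(n)
  (Attention still reads every cached entry, so each step costs O(n) and attention over the whole sequence remains O(n²).)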
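The equivalence can be checked numerically. The sketch below is a toy single-head causal attention layer in NumPy; the random weights, width 8, and the fixed input embeddings standing in for generated tokens are all assumptions for illustration. It runs generation once recomputing K/V for the whole prefix at every step and once with a growing cache, and counts how many K/V rows each version computes.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                                         # toy head dimension
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))
tokens = rng.standard_normal((6, d))          # stand-ins for 6 tokens processed one by one

def attend(q, K, V):
    """Scaled dot-product attention of one query over all keys/values."""
    scores = K @ q / np.sqrt(d)
    weights = np.exp(scores - scores.max())   # numerically stable softmax
    weights /= weights.sum()
    return weights @ V

def generate_without_cache(xs):
    outputs, kv_rows = [], 0
    for t in range(1, len(xs) + 1):
        prefix = xs[:t]
        K, V = prefix @ W_k, prefix @ W_v     # recompute K/V for the entire prefix
        kv_rows += t
        outputs.append(attend(xs[t - 1] @ W_q, K, V))
    return np.array(outputs), kv_rows

def generate_with_cache(xs):
    outputs, kv_rows = [], 0
    K_cache, V_cache = np.empty((0, d)), np.empty((0, d))
    for x in xs:
        # Compute K/V only for the newest token and append it to the cache
        K_cache = np.vstack([K_cache, x @ W_k])
        V_cache = np.vstack([V_cache, x @ W_v])
        kv_rows += 1
        outputs.append(attend(x @ W_q, K_cache, V_cache))
    return np.array(outputs), kv_rows

out_a, n_a = generate_without_cache(tokens)
out_b, n_b = generate_with_cache(tokens)
print(np.allclose(out_a, out_b), n_a, n_b)    # True 21 6
```

Both versions produce the same outputs; only the K/V work differs (1 + 2 + ... + 6 = 21 rows versus 6).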

KV Cache Memory Footprint

KV Cache requires substantial VRAM, which limits batch size and context length:

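KV Cache Size = 2 × num_layers × num_kv_heads × head_dim × seq_len × batch_size × dtype_bytes

(The factor 2 covers K and V. With grouped-query attention (GQA), num_kv_heads is smaller than the number of query heads; with classic multi-head attention the two are equal.)

Example: Hermes 3 70B (BF16, a fine-tune of Llama 3.1 70B)
  layers=80, query heads=64, kv_heads=8 (GQA), head_dim=128, seq_len=4096, batch_size=8, BF16=2 bytes
  KV Cache = 2 × 80 × 8 × 128 × 4096 × 8 × 2 bytes ≈ 10.7 GB
  (If all 64 heads stored K/V, as in classic multi-head attention, the same workload would need ≈ 85.9 GB.)

The cache grows linearly with seq_len × batch_size. At 320 KiB per token for this model, a single sequence at the 128K-token context length configured in the next section needs roughly 42 GB; together with about 140 GB of BF16 weights, this is why a 70B model is served across multiple GPUs (tensor_parallel_size=4 below), and why KV memory limits how many sequences can run concurrently.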
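A small calculator for the formula, written as a sketch. The shapes are assumptions taken from the published Llama 3.1 70B configuration (80 layers, 64 query heads, 8 key/value heads under grouped-query attention, head_dim 128), which Hermes 3 70B inherits as a fine-tune; note that the cache size depends on the number of KV heads, not query heads.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch_size: int, dtype_bytes: int = 2) -> int:
    """Bytes needed to store K and V for every layer, KV head, and token."""
    # Factor 2: one tensor for Keys, one for Values
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * dtype_bytes

# Llama-3.1-70B-style shapes in BF16 (2 bytes per value)
gqa = kv_cache_bytes(80, 8, 128, seq_len=4096, batch_size=8)    # 8 KV heads (GQA)
mha = kv_cache_bytes(80, 64, 128, seq_len=4096, batch_size=8)   # if all 64 heads kept K/V
per_token = kv_cache_bytes(80, 8, 128, seq_len=1, batch_size=1)

print(f"GQA (8 KV heads):  {gqa / 1e9:.1f} GB")         # 10.7 GB
print(f"MHA (64 KV heads): {mha / 1e9:.1f} GB")         # 85.9 GB
print(f"Per token:         {per_token / 1024:.0f} KiB")  # 320 KiB
```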


60.2 vLLM Prefix Caching Configuration

What Is Prefix Caching?

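Prefix Caching extends KV Cache across requests: when multiple requests share the same prefix (e.g., a system prompt), the KV Cache for that prefix is computed once and reused by subsequent requests for as long as the cached blocks remain in GPU memory. In vLLM, sharing happens at the granularity of full KV blocks (block_size tokens), so requests must match token for token from the very first token.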

Request 1: [System Prompt | User Question A]
Request 2: [System Prompt | User Question B]  → Reuse cached KV for system prompt
Request 3: [System Prompt | User Question C]  → Reuse cached KV for system prompt
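A simplified model of how a block-based prefix cache decides what to reuse, loosely following vLLM's automatic prefix caching, in which each full KV block is identified by a hash of its tokens chained with the hash of everything before it. The hash function, the toy token ids, and the helper names are illustrative assumptions; vLLM's actual implementation differs in detail.

```python
import hashlib

def block_hashes(token_ids: list[int], block_size: int = 16) -> list[str]:
    """Chain-hash each full block: a block's hash depends on every token before it."""
    hashes, parent = [], ""
    full = len(token_ids) // block_size * block_size   # a trailing partial block is not cached
    for start in range(0, full, block_size):
        block = token_ids[start:start + block_size]
        parent = hashlib.sha256((parent + repr(block)).encode()).hexdigest()
        hashes.append(parent)
    return hashes

def reusable_tokens(cached: list[int], new: list[int], block_size: int = 16) -> int:
    """Tokens of `new` whose KV blocks could be served from blocks computed for `cached`."""
    cached_set = set(block_hashes(cached, block_size))
    reused = 0
    for h in block_hashes(new, block_size):
        if h not in cached_set:
            break  # chained hashes: once one block differs, every later block differs
        reused += block_size
    return reused

system = list(range(1000, 1100))       # toy 100-token "system prompt"
req_a = system + [1, 2, 3, 4, 5]       # system prompt + user question A
req_b = system + [9, 8, 7]             # system prompt + user question B
print(reusable_tokens(req_a, req_b))   # 96: six full 16-token blocks
```

Reuse stops at the last full block lying entirely inside the shared prefix (the final 4 system-prompt tokens share a block with user text). For the same reason, anything that varies per request, such as a timestamp, belongs after the static system prompt rather than inside it.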

vLLM Configuration

from vllm import LLM, SamplingParams

llm = LLM(
    model="NousResearch/Hermes-3-Llama-3.1-70B",
    tensor_parallel_size=4,
    gpu_memory_utilization=0.85,
    enable_prefix_caching=True,    # Reuse KV blocks across requests that share a prefix
    max_model_len=128000,
    max_num_seqs=256,
    block_size=16,
    swap_space=4,
)

SYSTEM_PROMPT = """You are Hermes, an autonomous AI agent by NousResearch.
[...system prompt content...]"""

sampling_params = SamplingParams(temperature=0.7, max_tokens=2048, stop=["<|im_end|>"])

def batch_inference(user_questions: list[str]) -> list[str]:
    """
    Batch inference — system prompt KV is cached after first request,
    subsequent requests reuse it automatically.
    """
    prompts = [
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{q}<|im_end|>\n"
        f"<|im_start|>assistant\n"
        for q in user_questions
    ]
    outputs = llm.generate(prompts, sampling_params)
    return [o.outputs[0].text for o in outputs]

Cache Hit Rate Monitoring

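import requests

def get_vllm_cache_stats(api_url: str = "http://localhost:8000") -> dict:
    """Read prefix-cache counters from vLLM's Prometheus /metrics endpoint.

    Assumes the counters vllm:prefix_cache_hits and vllm:prefix_cache_queries
    (exposed with a _total suffix, counted in tokens); metric names can differ
    between vLLM versions, so check your server's /metrics output.
    """
    response = requests.get(f"{api_url}/metrics", timeout=5)
    response.raise_for_status()
    stats: dict[str, float] = {}
    for line in response.text.splitlines():
        if line.startswith('#') or 'prefix_cache' not in line:
            continue
        # Exposition format: metric_name{label="..."} value
        name = line.split('{')[0].split(' ')[0]
        value = float(line.rsplit(' ', 1)[-1])
        # Sum across label sets (e.g. one series per engine)
        stats[name] = stats.get(name, 0.0) + value

    hits = stats.get('vllm:prefix_cache_hits_total', 0.0)
    queries = stats.get('vllm:prefix_cache_queries_total', 0.0)
    hit_rate = hits / queries * 100 if queries else 0.0
    return {
        'cache_hits': int(hits),
        'cache_queries': int(queries),
        'hit_rate': f"{hit_rate:.2f}%"
    }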
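Prometheus counters are cumulative, so a hit rate computed from them averages over the server's whole lifetime and can hide a recent drop (for example, after someone edits the system prompt). A sketch that computes the rate over a window from two snapshots of the counters; CacheSnapshot and windowed_hit_rate are hypothetical helpers. In Prometheus itself the equivalent would be a ratio of rate() over the two counters, e.g. rate(vllm:prefix_cache_hits_total[5m]) / rate(vllm:prefix_cache_queries_total[5m]), assuming those metric names.

```python
from dataclasses import dataclass

@dataclass
class CacheSnapshot:
    hits: float      # cumulative prefix-cache hit counter
    queries: float   # cumulative prefix-cache query counter

def windowed_hit_rate(before: CacheSnapshot, after: CacheSnapshot) -> float:
    """Hit rate over the interval between two snapshots of cumulative counters."""
    d_queries = after.queries - before.queries
    if d_queries <= 0:       # no traffic in the window, or counters reset by a restart
        return 0.0
    return (after.hits - before.hits) / d_queries

t0 = CacheSnapshot(hits=9_000, queries=10_000)   # after a long, well-cached period
t1 = CacheSnapshot(hits=9_200, queries=11_000)   # 1,000 new query tokens, only 200 hits
print(f"lifetime: {t1.hits / t1.queries:.1%}, last window: {windowed_hit_rate(t0, t1):.1%}")
# lifetime: 83.6%, last window: 20.0%
```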

60.3 Anthropic Prompt Cache in Hermes

How Prompt Cache Works

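Anthropic's Prompt Cache (launched in 2024) caches a prompt prefix. You mark a content block with cache_control, and everything up to and including that block (tools, then the system prompt, then messages, in that order) becomes a cacheable prefix. The server stores the computed state for that prefix, and later requests that begin with the identical prefix reuse it. Entries of type "ephemeral" live for about five minutes, and the lifetime is refreshed each time the entry is read.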

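Pricing (Claude 3.5 Sonnet, per million tokens; these are the multipliers used in the code below):
  Base input:   $3.00
  Cache write:  $3.75 (1.25× base input)
  Cache read:   $0.30 (0.10× base input)
  Output:       $15.00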

Implementation

import anthropic
from typing import Optional

class HermesWithPromptCache:
    def __init__(self, api_key: str):
        self.client = anthropic.Anthropic(api_key=api_key)
        self.model = "claude-3-5-sonnet-20241022"
        
        self.system_prompt_blocks = [
            {
                "type": "text",
                "text": self._get_core_system_prompt(),
                "cache_control": {"type": "ephemeral"}  # Mark as cacheable
            }
        ]
    
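    def _get_core_system_prompt(self) -> str:
        # Anthropic only caches prefixes of at least 1,024 tokens on Claude 3.5 Sonnet;
        # shorter prompts marked with cache_control are processed without caching.
        # This abbreviated prompt is a placeholder: a full agent prompt with tool
        # schemas (substituted for {TOOLS_PLACEHOLDER}) would typically exceed that.
        return """You are Hermes, an autonomous AI agent by NousResearch.

Rules: ①Break tasks into steps ②Verify tool outputs ③Ask when unclear ④Log decisions
Tools: {TOOLS_PLACEHOLDER}
Format: JSON for structured data, markdown for reports.
"""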
    
    def chat(self, user_message: str, conversation_history: Optional[list] = None,
             task_context: Optional[str] = None) -> dict:
        messages = list(conversation_history or [])
        
        if task_context:
            messages.append({
                "role": "user",
                "content": [
                    {"type": "text", "text": f"Task Context:\n{task_context}",
                     "cache_control": {"type": "ephemeral"}},
                    {"type": "text", "text": user_message}
                ]
            })
        else:
            messages.append({"role": "user", "content": user_message})
        
        response = self.client.messages.create(
            model=self.model,
            max_tokens=4096,
            system=self.system_prompt_blocks,
            messages=messages
        )
        
        usage = response.usage
        cache_stats = {
            "input_tokens": usage.input_tokens,
            "output_tokens": usage.output_tokens,
            # These fields may be absent or None depending on SDK version; default to 0
            "cache_creation_tokens": getattr(usage, 'cache_creation_input_tokens', 0) or 0,
            "cache_read_tokens": getattr(usage, 'cache_read_input_tokens', 0) or 0,
        }
        cache_stats["cost_usd"] = self._calculate_cost(cache_stats)
        cache_stats["savings_usd"] = self._calculate_savings(cache_stats)
        
        return {"content": response.content[0].text, "cache_stats": cache_stats}
    
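    def _calculate_cost(self, u: dict) -> float:
        # Claude 3.5 Sonnet list prices: $3 per million input tokens, $15 per million output.
        # Cache writes are billed at 1.25x the input price, cache reads at 0.10x.
        input_price, output_price = 3.00 / 1e6, 15.00 / 1e6
        return round(
            u["input_tokens"] * input_price + u["output_tokens"] * output_price +
            u["cache_creation_tokens"] * input_price * 1.25 +
            u["cache_read_tokens"] * input_price * 0.10, 6
        )
    
    def _calculate_savings(self, u: dict) -> float:
        # Saving on cache reads versus the full input price (ignores the write premium)
        return round(u["cache_read_tokens"] * (3.00 / 1e6) * 0.90, 6)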

# Usage
agent = HermesWithPromptCache(api_key="your_api_key")

# First call: writes the cache (cached tokens billed at 1.25x the base input price)
r1 = agent.chat("Analyze Hermes Agent's architecture")
print(f"Call 1 — Cache written: {r1['cache_stats']['cache_creation_tokens']} tokens")
print(f"Cost: ${r1['cache_stats']['cost_usd']}")

# Second call within the cache lifetime: cache hit (cached tokens billed at 10% of the base input price)
r2 = agent.chat("List supported tool categories")
print(f"Call 2 — Cache hit: {r2['cache_stats']['cache_read_tokens']} tokens")
print(f"Savings: ${r2['cache_stats']['savings_usd']}")
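To see when caching pays off, here is a back-of-envelope sketch using the same multipliers as _calculate_cost (1.25× on the first write, 0.10× on reads) and the $3.00 per million input price. It assumes every call after the first arrives within the cache lifetime and sends an identical prefix; prompt_cost is a hypothetical helper, and only the cached prefix is costed, since output and uncached input tokens cost the same either way.

```python
def prompt_cost(prefix_tokens: int, calls: int, base_price_per_mtok: float = 3.00,
                cached: bool = True) -> float:
    """Input cost (USD) of sending the same prefix `calls` times."""
    base = base_price_per_mtok / 1e6
    if not cached or calls == 0:
        return prefix_tokens * base * calls
    write = prefix_tokens * base * 1.25               # first call writes the cache
    reads = prefix_tokens * base * 0.10 * (calls - 1)  # later calls read it
    return write + reads

for n in (1, 2, 10, 100):
    plain = prompt_cost(2000, n, cached=False)
    with_cache = prompt_cost(2000, n)
    print(f"{n:>3} calls: ${plain:.4f} uncached vs ${with_cache:.4f} cached")
```

Under these assumptions caching costs 25% more on a single call but is already cheaper on the second (1.35× versus 2× the base prefix cost), and over 100 calls the cached prefix costs about 11% of the uncached total.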

60.4 Redis Semantic Cache

Semantic Cache vs. Exact-Match Cache

Traditional caches require exact string matches. Semantic caching uses vector similarity to allow similar (but not identical) requests to reuse cached results:

Exact cache:
  Cached: "How do I install Python 3.11?"
  Query:  "How do I install Python 3.11?" → HIT
  Query:  "Steps to install Python?"       → MISS (different text)

Semantic cache:
  Cached: "How do I install Python 3.11?"
  Query:  "Steps to install Python?"       → HIT (semantically similar)
  Query:  "Python installation guide"      → HIT (above similarity threshold)
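To make the threshold mechanics concrete without Redis or an embedding API, here is a toy in-memory sketch. The bag-of-words "embedding" is an assumption for illustration only and is much weaker than the embedding model used in the implementation that follows; ToySemanticCache and its helpers are hypothetical names.

```python
import math
import re
from collections import Counter
from typing import Optional

TOKEN_RE = re.compile(r"[a-z0-9]+(?:\.[a-z0-9]+)*")

def toy_embed(text: str) -> Counter:
    """Bag-of-words 'embedding' (word counts): a stand-in for a real embedding model."""
    return Counter(TOKEN_RE.findall(text.lower()))

def cosine(a: Counter, b: Counter) -> float:
    dot = sum(count * b[word] for word, count in a.items())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

class ToySemanticCache:
    """In-memory semantic cache: return the best match only if it clears the threshold."""

    def __init__(self, threshold: float):
        self.threshold = threshold
        self.entries: list[tuple[Counter, str]] = []

    def set(self, query: str, response: str) -> None:
        self.entries.append((toy_embed(query), response))

    def get(self, query: str) -> Optional[tuple[str, float]]:
        q = toy_embed(query)
        best_sim, best_resp = 0.0, None
        for emb, resp in self.entries:
            sim = cosine(q, emb)
            if sim > best_sim:
                best_sim, best_resp = sim, resp
        if best_resp is not None and best_sim >= self.threshold:
            return best_resp, best_sim
        return None

cache = ToySemanticCache(threshold=0.4)
cache.set("How do I install Python 3.11?", "cached answer about Python 3.11")
for q in ["How do I install Python 3.11?", "Steps to install Python?", "Python installation guide"]:
    print(f"{q!r}: {cache.get(q)}")
```

With this crude embedding, "Steps to install Python?" scores about 0.41 and hits, while "Python installation guide" scores about 0.24 and misses because bag-of-words does not relate "install" to "installation"; a real embedding model generally scores such paraphrases higher. The first hit also shows the risk of semantic caching: a generic question receives an answer written for a specific version.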

Implementation

import redis
import numpy as np
import json
import hashlib
import time
from typing import Optional, Tuple

class SemanticCache:
    def __init__(self, redis_url: str = "redis://localhost:6379",
                 similarity_threshold: float = 0.92, ttl_seconds: int = 3600):
        self.redis_client = redis.from_url(redis_url)
        self.similarity_threshold = similarity_threshold
        self.ttl_seconds = ttl_seconds
        self.hits = 0
        self.misses = 0
    
    def get_embedding(self, text: str) -> list:
        import openai
        r = openai.embeddings.create(model="text-embedding-3-small", input=text)
        return r.data[0].embedding
    
    def cosine_similarity(self, v1: list, v2: list) -> float:
        a, b = np.array(v1), np.array(v2)
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
    
    def get(self, query: str) -> Optional[Tuple[str, float]]:
        # Try exact match first
        exact_key = f"exact:{hashlib.md5(query.encode()).hexdigest()}"
        exact = self.redis_client.get(exact_key)
        if exact:
            self.hits += 1
            return json.loads(exact)["response"], 1.0
        
        # Semantic similarity search
        query_emb = self.get_embedding(query)
        best_match, best_sim = None, 0.0
        
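        # Linear scan over every cached vector: O(N) per lookup, fine for a sketch.
        # SCAN avoids blocking Redis the way KEYS does on large keyspaces.
        for key in self.redis_client.scan_iter(match="vec:*"):
            raw = self.redis_client.get(key)
            if raw is None:  # entry expired between SCAN and GET
                continue
            data = json.loads(raw)
            sim = self.cosine_similarity(query_emb, data["embedding"])
            if sim > best_sim:
                best_sim = sim
                best_match = data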
        if best_sim >= self.similarity_threshold and best_match:
            self.hits += 1
            return best_match["response"], best_sim
        
        self.misses += 1
        return None
    
    def set(self, query: str, response: str, metadata: Optional[dict] = None):
        embedding = self.get_embedding(query)
        entry = {"query": query, "response": response, "embedding": embedding,
                 "timestamp": time.time(), "metadata": metadata or {}}
        
        key_hash = hashlib.md5(query.encode()).hexdigest()
        self.redis_client.setex(f"vec:{key_hash}", self.ttl_seconds, json.dumps(entry))
        self.redis_client.setex(f"exact:{key_hash}", self.ttl_seconds,
                                json.dumps({"response": response}))
    
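    def get_stats(self) -> dict:
        total = self.hits + self.misses
        return {
            "total_queries": total,
            "cache_hits": self.hits,
            "hit_rate": f"{self.hits/total*100:.1f}%" if total > 0 else "0%",
            "cached_entries": sum(1 for _ in self.redis_client.scan_iter(match="vec:*"))
        }
    
    def invalidate_by_pattern(self, pattern: str) -> int:
        """Delete entries whose original query text contains `pattern`.

        Key names are MD5 hashes, so a topic string can never match a key name;
        the match has to be made against the query stored inside each entry.
        """
        removed = 0
        for key in self.redis_client.scan_iter(match="vec:*"):
            raw = self.redis_client.get(key)
            if raw is None:
                continue
            if pattern.lower() in json.loads(raw)["query"].lower():
                key_hash = key.decode().split(":", 1)[1]
                # Remove both the vector entry and its exact-match twin
                self.redis_client.delete(key, f"exact:{key_hash}")
                removed += 1
        return removed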

60.5 Cache Invalidation Strategies

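| Strategy          | Trigger                                          | Use case            | Pros               | Cons                              |
|-------------------|--------------------------------------------------|---------------------|--------------------|-----------------------------------|
| TTL (time expiry) | Fixed time after write                           | News, market data   | Simple, automatic  | May expire still-valid data early |
| LRU               | Cache full; evict the least-recently-used entry  | General purpose     | Retains hot data   | Ignores data freshness            |
| Event-driven      | Knowledge base update                            | Product docs, rules | Precise control    | Requires an external trigger      |
| Version tag       | Model or data version change                     | Model upgrades      | Strong consistency | Complex implementation            |
| Similarity decay  | Match score of older entries decays over time    | Semantic caches     | Gradual refresh    | Requires periodic recomputation   |

SmartCacheInvalidator combines version-tag invalidation, event-driven invalidation, and cache warm-up:
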
class SmartCacheInvalidator:
    def __init__(self, cache: SemanticCache):
        self.cache = cache
        self.version_tag = "v1.0"
    
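    def on_version_update(self, new_version: str):
        if new_version != self.version_tag:
            # flushdb() wipes every key in the current Redis database, not just
            # cache entries, so keep the cache in a dedicated database
            self.cache.redis_client.flushdb()
            self.version_tag = new_version
            print(f"Version upgraded to {new_version}: cache cleared")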
    
    def on_knowledge_update(self, updated_topics: list):
        for topic in updated_topics:
            count = self.cache.invalidate_by_pattern(topic)
            if count > 0:
                print(f"Invalidated {count} cache entries related to '{topic}'")
    
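    def warm_up(self, common_queries: list, llm_client):
        """Pre-populate cache with common queries at service startup.

        llm_client is any object with a complete_sync(prompt) -> str method
        (an interface assumed by this example, not a specific SDK).
        """
        warmed = 0
        for query in common_queries:
            if not self.cache.get(query):
                # Cache miss: call the LLM and store the result
                response = llm_client.complete_sync(query)
                self.cache.set(query, response)
                warmed += 1
        print(f"Cache warm-up: {warmed} new entries from {len(common_queries)} queries")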
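Redis already provides the first two strategies in the table: per-key expiry (the SETEX calls in SemanticCache) and LRU eviction when memory fills up (the allkeys-lru or volatile-lru maxmemory policies). To show the mechanics without a server, here is a hypothetical in-memory TTLLRUCache in pure Python, with an injectable clock so expiry can be simulated rather than waited for.

```python
import time
from collections import OrderedDict
from typing import Any, Callable, Optional

class TTLLRUCache:
    """In-memory cache combining TTL expiry with LRU eviction (illustrative sketch)."""

    def __init__(self, max_entries: int, ttl_seconds: float,
                 clock: Callable[[], float] = time.monotonic):
        self.max_entries = max_entries
        self.ttl_seconds = ttl_seconds
        self.clock = clock                     # injectable so expiry can be tested
        self._data: OrderedDict = OrderedDict()  # key -> (stored_at, value)

    def get(self, key: str) -> Optional[Any]:
        item = self._data.get(key)
        if item is None:
            return None
        stored_at, value = item
        if self.clock() - stored_at >= self.ttl_seconds:  # TTL: entry too old
            del self._data[key]
            return None
        self._data.move_to_end(key)             # LRU: mark as most recently used
        return value

    def set(self, key: str, value: Any) -> None:
        self._data[key] = (self.clock(), value)
        self._data.move_to_end(key)
        while len(self._data) > self.max_entries:  # LRU: evict least recently used
            self._data.popitem(last=False)

# Usage with a fake clock
now = [0.0]
cache = TTLLRUCache(max_entries=2, ttl_seconds=60, clock=lambda: now[0])
cache.set("faq:install", "answer A")
cache.set("faq:upgrade", "answer B")
cache.get("faq:install")                 # touch: "faq:upgrade" is now least recently used
cache.set("faq:uninstall", "answer C")   # over capacity: evicts "faq:upgrade"
now[0] = 61.0
print(cache.get("faq:upgrade"), cache.get("faq:install"))  # None None (evicted, then expired)
```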

Chapter Summary

Caching is a core performance optimization for Hermes Agent, applied at four layers:

  1. KV Cache: Framework-level and managed automatically; it eliminates redundant Key/Value computation within a single request, reducing K/V work from O(n²) to O(n) in the number of tokens
  2. Prefix Caching (vLLM): Server-level; when multiple requests share a system prompt, its KV blocks are computed once and reused. The prefill work saved is roughly proportional to the share of prompt tokens served from cache; decode cost is unchanged
  3. Prompt Cache (Anthropic): API-level; cached input tokens are billed at 10% of the base input price (cache writes at 125%). The cached prefix must be at least 1,024 tokens on Claude 3.5 Sonnet, so it pays off for long, stable system prompts
  4. Semantic Cache (Redis): Application-level; similar requests reuse earlier results. For repetitive tasks (FAQs, fixed report templates) hit rates can reach 60–80%, depending on the workload and the similarity threshold

Review Questions

  1. In a multi-tenant scenario where each user has a different system prompt, how would you design Prefix Cache so tenants can share common sections?
  2. If a semantic cache hit rate exceeds 90%, does that indicate the task set is too homogeneous? How do you balance cache efficiency against task diversity?
  3. For queries containing real-time data (stock prices, weather), how would you design a "partial cache" strategy—caching static reasoning while injecting live data dynamically?
  4. In A/B test scenarios, caching could cause different users to receive results from the old model. How would you design cache isolation to prevent this?