use std::fs; use std::path::{Path, PathBuf}; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use serde::{Deserialize, Serialize}; use crate::types::{MessageRequest, MessageResponse, Usage}; const DEFAULT_COMPLETION_TTL_SECS: u64 = 30; const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60; const DEFAULT_BREAK_MIN_DROP: u32 = 2_000; const MAX_SANITIZED_LENGTH: usize = 80; const REQUEST_FINGERPRINT_VERSION: u32 = 1; const REQUEST_FINGERPRINT_PREFIX: &str = "v1"; const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325; const FNV_PRIME: u64 = 0x0000_0100_0000_01b3; #[derive(Debug, Clone)] pub struct PromptCacheConfig { pub session_id: String, pub completion_ttl: Duration, pub prompt_ttl: Duration, pub cache_break_min_drop: u32, } impl PromptCacheConfig { #[must_use] pub fn new(session_id: impl Into) -> Self { Self { session_id: session_id.into(), completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS), prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS), cache_break_min_drop: DEFAULT_BREAK_MIN_DROP, } } } impl Default for PromptCacheConfig { fn default() -> Self { Self::new("default") } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct PromptCachePaths { pub root: PathBuf, pub session_dir: PathBuf, pub completion_dir: PathBuf, pub session_state_path: PathBuf, pub stats_path: PathBuf, } impl PromptCachePaths { #[must_use] pub fn for_session(session_id: &str) -> Self { let root = base_cache_root(); let session_dir = root.join(sanitize_path_segment(session_id)); let completion_dir = session_dir.join("completions"); Self { root, session_state_path: session_dir.join("session-state.json"), stats_path: session_dir.join("stats.json"), session_dir, completion_dir, } } #[must_use] pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf { self.completion_dir.join(format!("{request_hash}.json")) } } #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] pub struct PromptCacheStats { pub tracked_requests: u64, pub completion_cache_hits: u64, pub completion_cache_misses: u64, pub completion_cache_writes: u64, pub expected_invalidations: u64, pub unexpected_cache_breaks: u64, pub total_cache_creation_input_tokens: u64, pub total_cache_read_input_tokens: u64, pub last_cache_creation_input_tokens: Option, pub last_cache_read_input_tokens: Option, pub last_request_hash: Option, pub last_completion_cache_key: Option, pub last_break_reason: Option, pub last_cache_source: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct CacheBreakEvent { pub unexpected: bool, pub reason: String, pub previous_cache_read_input_tokens: u32, pub current_cache_read_input_tokens: u32, pub token_drop: u32, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct PromptCacheRecord { pub cache_break: Option, pub stats: PromptCacheStats, } #[derive(Debug, Clone)] pub struct PromptCache { inner: Arc>, } impl PromptCache { #[must_use] pub fn new(session_id: impl Into) -> Self { Self::with_config(PromptCacheConfig::new(session_id)) } #[must_use] pub fn with_config(config: PromptCacheConfig) -> Self { let paths = PromptCachePaths::for_session(&config.session_id); let stats = read_json::(&paths.stats_path).unwrap_or_default(); let previous = read_json::(&paths.session_state_path); Self { inner: Arc::new(Mutex::new(PromptCacheInner { config, paths, stats, previous, })), } } #[must_use] pub fn paths(&self) -> PromptCachePaths { self.lock().paths.clone() } #[must_use] pub fn stats(&self) -> PromptCacheStats { self.lock().stats.clone() } #[must_use] pub fn lookup_completion(&self, request: &MessageRequest) -> Option { let request_hash = request_hash_hex(request); let (paths, ttl) = { let inner = self.lock(); (inner.paths.clone(), inner.config.completion_ttl) }; let entry_path = paths.completion_entry_path(&request_hash); let entry = read_json::(&entry_path); let Some(entry) = entry else { let mut inner = self.lock(); inner.stats.completion_cache_misses += 1; inner.stats.last_completion_cache_key = Some(request_hash); persist_state(&inner); return None; }; if entry.fingerprint_version != current_fingerprint_version() { let mut inner = self.lock(); inner.stats.completion_cache_misses += 1; inner.stats.last_completion_cache_key = Some(request_hash.clone()); let _ = fs::remove_file(entry_path); persist_state(&inner); return None; } let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs(); let mut inner = self.lock(); inner.stats.last_completion_cache_key = Some(request_hash.clone()); if expired { inner.stats.completion_cache_misses += 1; let _ = fs::remove_file(entry_path); persist_state(&inner); return None; } inner.stats.completion_cache_hits += 1; apply_usage_to_stats( &mut inner.stats, &entry.response.usage, &request_hash, "completion-cache", ); inner.previous = Some(TrackedPromptState::from_usage( request, &entry.response.usage, )); persist_state(&inner); Some(entry.response) } #[must_use] pub fn record_response( &self, request: &MessageRequest, response: &MessageResponse, ) -> PromptCacheRecord { self.record_usage_internal(request, &response.usage, Some(response)) } #[must_use] pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord { self.record_usage_internal(request, usage, None) } fn record_usage_internal( &self, request: &MessageRequest, usage: &Usage, response: Option<&MessageResponse>, ) -> PromptCacheRecord { let request_hash = request_hash_hex(request); let mut inner = self.lock(); let previous = inner.previous.clone(); let current = TrackedPromptState::from_usage(request, usage); let cache_break = detect_cache_break(&inner.config, previous.as_ref(), ¤t); inner.stats.tracked_requests += 1; apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response"); if let Some(event) = &cache_break { if event.unexpected { inner.stats.unexpected_cache_breaks += 1; } else { inner.stats.expected_invalidations += 1; } inner.stats.last_break_reason = Some(event.reason.clone()); } inner.previous = Some(current); if let Some(response) = response { write_completion_entry(&inner.paths, &request_hash, response); inner.stats.completion_cache_writes += 1; } persist_state(&inner); PromptCacheRecord { cache_break, stats: inner.stats.clone(), } } fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> { self.inner .lock() .unwrap_or_else(std::sync::PoisonError::into_inner) } } #[derive(Debug)] struct PromptCacheInner { config: PromptCacheConfig, paths: PromptCachePaths, stats: PromptCacheStats, previous: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] struct CompletionCacheEntry { cached_at_unix_secs: u64, #[serde(default = "current_fingerprint_version")] fingerprint_version: u32, response: MessageResponse, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] struct TrackedPromptState { observed_at_unix_secs: u64, #[serde(default = "current_fingerprint_version")] fingerprint_version: u32, model_hash: u64, system_hash: u64, tools_hash: u64, messages_hash: u64, cache_read_input_tokens: u32, } impl TrackedPromptState { fn from_usage(request: &MessageRequest, usage: &Usage) -> Self { let hashes = RequestFingerprints::from_request(request); Self { observed_at_unix_secs: now_unix_secs(), fingerprint_version: current_fingerprint_version(), model_hash: hashes.model, system_hash: hashes.system, tools_hash: hashes.tools, messages_hash: hashes.messages, cache_read_input_tokens: usage.cache_read_input_tokens, } } } #[derive(Debug, Clone, Copy)] struct RequestFingerprints { model: u64, system: u64, tools: u64, messages: u64, } impl RequestFingerprints { fn from_request(request: &MessageRequest) -> Self { Self { model: hash_serializable(&request.model), system: hash_serializable(&request.system), tools: hash_serializable(&request.tools), messages: hash_serializable(&request.messages), } } } fn detect_cache_break( config: &PromptCacheConfig, previous: Option<&TrackedPromptState>, current: &TrackedPromptState, ) -> Option { let previous = previous?; if previous.fingerprint_version != current.fingerprint_version { return Some(CacheBreakEvent { unexpected: false, reason: format!( "fingerprint version changed (v{} -> v{})", previous.fingerprint_version, current.fingerprint_version ), previous_cache_read_input_tokens: previous.cache_read_input_tokens, current_cache_read_input_tokens: current.cache_read_input_tokens, token_drop: previous .cache_read_input_tokens .saturating_sub(current.cache_read_input_tokens), }); } let token_drop = previous .cache_read_input_tokens .saturating_sub(current.cache_read_input_tokens); if token_drop < config.cache_break_min_drop { return None; } let mut reasons = Vec::new(); if previous.model_hash != current.model_hash { reasons.push("model changed"); } if previous.system_hash != current.system_hash { reasons.push("system prompt changed"); } if previous.tools_hash != current.tools_hash { reasons.push("tool definitions changed"); } if previous.messages_hash != current.messages_hash { reasons.push("message payload changed"); } let elapsed = current .observed_at_unix_secs .saturating_sub(previous.observed_at_unix_secs); let (unexpected, reason) = if reasons.is_empty() { if elapsed > config.prompt_ttl.as_secs() { ( false, format!("possible prompt cache TTL expiry after {elapsed}s"), ) } else { ( true, "cache read tokens dropped while prompt fingerprint remained stable".to_string(), ) } } else { (false, reasons.join(", ")) }; Some(CacheBreakEvent { unexpected, reason, previous_cache_read_input_tokens: previous.cache_read_input_tokens, current_cache_read_input_tokens: current.cache_read_input_tokens, token_drop, }) } fn apply_usage_to_stats( stats: &mut PromptCacheStats, usage: &Usage, request_hash: &str, source: &str, ) { stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens); stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens); stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens); stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens); stats.last_request_hash = Some(request_hash.to_string()); stats.last_cache_source = Some(source.to_string()); } fn persist_state(inner: &PromptCacheInner) { let _ = ensure_cache_dirs(&inner.paths); let _ = write_json(&inner.paths.stats_path, &inner.stats); if let Some(previous) = &inner.previous { let _ = write_json(&inner.paths.session_state_path, previous); } } fn write_completion_entry( paths: &PromptCachePaths, request_hash: &str, response: &MessageResponse, ) { let _ = ensure_cache_dirs(paths); let entry = CompletionCacheEntry { cached_at_unix_secs: now_unix_secs(), fingerprint_version: current_fingerprint_version(), response: response.clone(), }; let _ = write_json(&paths.completion_entry_path(request_hash), &entry); } fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> { fs::create_dir_all(&paths.completion_dir) } fn write_json(path: &Path, value: &T) -> std::io::Result<()> { let json = serde_json::to_vec_pretty(value) .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; fs::write(path, json) } fn read_json Deserialize<'de>>(path: &Path) -> Option { let bytes = fs::read(path).ok()?; serde_json::from_slice(&bytes).ok() } fn request_hash_hex(request: &MessageRequest) -> String { format!( "{REQUEST_FINGERPRINT_PREFIX}-{:016x}", hash_serializable(request) ) } fn hash_serializable(value: &T) -> u64 { let json = serde_json::to_vec(value).unwrap_or_default(); stable_hash_bytes(&json) } fn sanitize_path_segment(value: &str) -> String { let sanitized: String = value .chars() .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' }) .collect(); if sanitized.len() <= MAX_SANITIZED_LENGTH { return sanitized; } let suffix = format!("-{:x}", hash_string(value)); format!( "{}{}", &sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())], suffix ) } fn hash_string(value: &str) -> u64 { stable_hash_bytes(value.as_bytes()) } fn base_cache_root() -> PathBuf { if let Some(config_home) = std::env::var_os("CLAUDE_CONFIG_HOME") { return PathBuf::from(config_home) .join("cache") .join("prompt-cache"); } if let Some(home) = std::env::var_os("HOME") { return PathBuf::from(home) .join(".claude") .join("cache") .join("prompt-cache"); } std::env::temp_dir().join("claude-prompt-cache") } fn now_unix_secs() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) .map_or(0, |duration| duration.as_secs()) } const fn current_fingerprint_version() -> u32 { REQUEST_FINGERPRINT_VERSION } fn stable_hash_bytes(bytes: &[u8]) -> u64 { let mut hash = FNV_OFFSET_BASIS; for byte in bytes { hash ^= u64::from(*byte); hash = hash.wrapping_mul(FNV_PRIME); } hash } #[cfg(test)] mod tests { use std::sync::{Mutex, OnceLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use super::{ detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache, PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX, }; use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage}; fn test_env_lock() -> std::sync::MutexGuard<'static, ()> { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) .lock() .unwrap_or_else(std::sync::PoisonError::into_inner) } #[test] fn path_builder_sanitizes_session_identifier() { let paths = PromptCachePaths::for_session("session:/with spaces"); let session_dir = paths .session_dir .file_name() .and_then(|value| value.to_str()) .expect("session dir name"); assert_eq!(session_dir, "session--with-spaces"); assert!(paths.completion_dir.ends_with("completions")); assert!(paths.stats_path.ends_with("stats.json")); assert!(paths.session_state_path.ends_with("session-state.json")); } #[test] fn request_fingerprint_drives_unexpected_break_detection() { let request = sample_request("same"); let previous = TrackedPromptState::from_usage( &request, &Usage { input_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 6_000, output_tokens: 0, }, ); let current = TrackedPromptState::from_usage( &request, &Usage { input_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 1_000, output_tokens: 0, }, ); let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t) .expect("break should be detected"); assert!(event.unexpected); assert!(event.reason.contains("stable")); } #[test] fn changed_prompt_marks_break_as_expected() { let previous_request = sample_request("first"); let current_request = sample_request("second"); let previous = TrackedPromptState::from_usage( &previous_request, &Usage { input_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 6_000, output_tokens: 0, }, ); let current = TrackedPromptState::from_usage( ¤t_request, &Usage { input_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 1_000, output_tokens: 0, }, ); let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t) .expect("break should be detected"); assert!(!event.unexpected); assert!(event.reason.contains("message payload changed")); } #[test] fn completion_cache_round_trip_persists_recent_response() { let _guard = test_env_lock(); let temp_root = std::env::temp_dir().join(format!( "prompt-cache-test-{}-{}", std::process::id(), SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time") .as_nanos() )); std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); let cache = PromptCache::new("unit-test-session"); let request = sample_request("cache me"); let response = sample_response(42, 12, "cached"); assert!(cache.lookup_completion(&request).is_none()); let record = cache.record_response(&request, &response); assert!(record.cache_break.is_none()); let cached = cache .lookup_completion(&request) .expect("cached response should load"); assert_eq!(cached.content, response.content); let stats = cache.stats(); assert_eq!(stats.completion_cache_hits, 1); assert_eq!(stats.completion_cache_misses, 1); assert_eq!(stats.completion_cache_writes, 1); let persisted = read_json::(&cache.paths().stats_path) .expect("stats should persist"); assert_eq!(persisted.completion_cache_hits, 1); std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); std::env::remove_var("CLAUDE_CONFIG_HOME"); } #[test] fn distinct_requests_do_not_collide_in_completion_cache() { let _guard = test_env_lock(); let temp_root = std::env::temp_dir().join(format!( "prompt-cache-distinct-{}-{}", std::process::id(), SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time") .as_nanos() )); std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); let cache = PromptCache::new("distinct-request-session"); let first_request = sample_request("first"); let second_request = sample_request("second"); let response = sample_response(42, 12, "cached"); let _ = cache.record_response(&first_request, &response); assert!(cache.lookup_completion(&second_request).is_none()); std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); std::env::remove_var("CLAUDE_CONFIG_HOME"); } #[test] fn expired_completion_entries_are_not_reused() { let _guard = test_env_lock(); let temp_root = std::env::temp_dir().join(format!( "prompt-cache-expired-{}-{}", std::process::id(), SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time") .as_nanos() )); std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); let cache = PromptCache::with_config(PromptCacheConfig { session_id: "expired-session".to_string(), completion_ttl: Duration::ZERO, ..PromptCacheConfig::default() }); let request = sample_request("expire me"); let response = sample_response(7, 3, "stale"); let _ = cache.record_response(&request, &response); assert!(cache.lookup_completion(&request).is_none()); let stats = cache.stats(); assert_eq!(stats.completion_cache_hits, 0); assert_eq!(stats.completion_cache_misses, 1); std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); std::env::remove_var("CLAUDE_CONFIG_HOME"); } #[test] fn sanitize_path_caps_long_values() { let long_value = "x".repeat(200); let sanitized = sanitize_path_segment(&long_value); assert!(sanitized.len() <= 80); } #[test] fn request_hashes_are_versioned_and_stable() { let request = sample_request("stable"); let first = request_hash_hex(&request); let second = request_hash_hex(&request); assert_eq!(first, second); assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX)); } fn sample_request(text: &str) -> MessageRequest { MessageRequest { model: "claude-3-7-sonnet-latest".to_string(), max_tokens: 64, messages: vec![InputMessage::user_text(text)], system: Some("system".to_string()), tools: None, tool_choice: None, stream: false, } } fn sample_response( cache_read_input_tokens: u32, output_tokens: u32, text: &str, ) -> MessageResponse { MessageResponse { id: "msg_test".to_string(), kind: "message".to_string(), role: "assistant".to_string(), content: vec![OutputContentBlock::Text { text: text.to_string(), }], model: "claude-3-7-sonnet-latest".to_string(), stop_reason: Some("end_turn".to_string()), stop_sequence: None, usage: Usage { input_tokens: 10, cache_creation_input_tokens: 5, cache_read_input_tokens, output_tokens, }, request_id: Some("req_test".to_string()), } } }