use crate::session::Session; const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0; const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0; const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75; const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5; #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct TokenUsage { pub input_tokens: u32, pub output_tokens: u32, pub cache_creation_input_tokens: u32, pub cache_read_input_tokens: u32, } #[derive(Debug, Clone, Copy, PartialEq)] pub struct UsageCostEstimate { pub input_cost_usd: f64, pub output_cost_usd: f64, pub cache_creation_cost_usd: f64, pub cache_read_cost_usd: f64, } impl UsageCostEstimate { #[must_use] pub fn total_cost_usd(self) -> f64 { self.input_cost_usd + self.output_cost_usd + self.cache_creation_cost_usd + self.cache_read_cost_usd } } impl TokenUsage { #[must_use] pub fn total_tokens(self) -> u32 { self.input_tokens + self.output_tokens + self.cache_creation_input_tokens + self.cache_read_input_tokens } #[must_use] pub fn estimate_cost_usd(self) -> UsageCostEstimate { UsageCostEstimate { input_cost_usd: cost_for_tokens(self.input_tokens, DEFAULT_INPUT_COST_PER_MILLION), output_cost_usd: cost_for_tokens(self.output_tokens, DEFAULT_OUTPUT_COST_PER_MILLION), cache_creation_cost_usd: cost_for_tokens( self.cache_creation_input_tokens, DEFAULT_CACHE_CREATION_COST_PER_MILLION, ), cache_read_cost_usd: cost_for_tokens( self.cache_read_input_tokens, DEFAULT_CACHE_READ_COST_PER_MILLION, ), } } #[must_use] pub fn summary_lines(self, label: &str) -> Vec { let cost = self.estimate_cost_usd(); vec![ format!( "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}", self.total_tokens(), self.input_tokens, self.output_tokens, self.cache_creation_input_tokens, self.cache_read_input_tokens, format_usd(cost.total_cost_usd()), ), format!( " cost breakdown: input={} output={} cache_write={} cache_read={}", format_usd(cost.input_cost_usd), format_usd(cost.output_cost_usd), format_usd(cost.cache_creation_cost_usd), format_usd(cost.cache_read_cost_usd), ), ] } } fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 { f64::from(tokens) / 1_000_000.0 * usd_per_million_tokens } #[must_use] pub fn format_usd(amount: f64) -> String { format!("${amount:.4}") } #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct UsageTracker { latest_turn: TokenUsage, cumulative: TokenUsage, turns: u32, } impl UsageTracker { #[must_use] pub fn new() -> Self { Self::default() } #[must_use] pub fn from_session(session: &Session) -> Self { let mut tracker = Self::new(); for message in &session.messages { if let Some(usage) = message.usage { tracker.record(usage); } } tracker } pub fn record(&mut self, usage: TokenUsage) { self.latest_turn = usage; self.cumulative.input_tokens += usage.input_tokens; self.cumulative.output_tokens += usage.output_tokens; self.cumulative.cache_creation_input_tokens += usage.cache_creation_input_tokens; self.cumulative.cache_read_input_tokens += usage.cache_read_input_tokens; self.turns += 1; } #[must_use] pub fn current_turn_usage(&self) -> TokenUsage { self.latest_turn } #[must_use] pub fn cumulative_usage(&self) -> TokenUsage { self.cumulative } #[must_use] pub fn turns(&self) -> u32 { self.turns } } #[cfg(test)] mod tests { use super::{format_usd, TokenUsage, UsageTracker}; use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session}; #[test] fn tracks_true_cumulative_usage() { let mut tracker = UsageTracker::new(); tracker.record(TokenUsage { input_tokens: 10, output_tokens: 4, cache_creation_input_tokens: 2, cache_read_input_tokens: 1, }); tracker.record(TokenUsage { input_tokens: 20, output_tokens: 6, cache_creation_input_tokens: 3, cache_read_input_tokens: 2, }); assert_eq!(tracker.turns(), 2); assert_eq!(tracker.current_turn_usage().input_tokens, 20); assert_eq!(tracker.current_turn_usage().output_tokens, 6); assert_eq!(tracker.cumulative_usage().output_tokens, 10); assert_eq!(tracker.cumulative_usage().input_tokens, 30); assert_eq!(tracker.cumulative_usage().total_tokens(), 48); } #[test] fn computes_cost_summary_lines() { let usage = TokenUsage { input_tokens: 1_000_000, output_tokens: 500_000, cache_creation_input_tokens: 100_000, cache_read_input_tokens: 200_000, }; let cost = usage.estimate_cost_usd(); assert_eq!(format_usd(cost.input_cost_usd), "$15.0000"); assert_eq!(format_usd(cost.output_cost_usd), "$37.5000"); let lines = usage.summary_lines("usage"); assert!(lines[0].contains("estimated_cost=$54.6750")); assert!(lines[1].contains("cache_read=$0.3000")); } #[test] fn reconstructs_usage_from_session_messages() { let session = Session { version: 1, messages: vec![ConversationMessage { role: MessageRole::Assistant, blocks: vec![ContentBlock::Text { text: "done".to_string(), }], usage: Some(TokenUsage { input_tokens: 5, output_tokens: 2, cache_creation_input_tokens: 1, cache_read_input_tokens: 0, }), }], }; let tracker = UsageTracker::from_session(&session); assert_eq!(tracker.turns(), 1); assert_eq!(tracker.cumulative_usage().total_tokens(), 8); } }