diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index fe5cc74..bcf3e1b 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -21,7 +21,8 @@ pub use prompt_cache::{ pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource}; pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig}; pub use providers::{ - detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind, + detect_provider_kind, max_tokens_for_model, max_tokens_for_model_with_override, + resolve_model_alias, ProviderKind, }; pub use sse::{parse_frame, SseParser}; pub use types::{ diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs index c5ed567..36c01fd 100644 --- a/rust/crates/api/src/providers/mod.rs +++ b/rust/crates/api/src/providers/mod.rs @@ -204,6 +204,14 @@ pub fn max_tokens_for_model(model: &str) -> u32 { ) } +/// Returns the effective max output tokens for a model, preferring a plugin +/// override when present. Falls back to [`max_tokens_for_model`] when the +/// override is `None`. +#[must_use] +pub fn max_tokens_for_model_with_override(model: &str, plugin_override: Option) -> u32 { + plugin_override.unwrap_or_else(|| max_tokens_for_model(model)) +} + #[must_use] pub fn model_token_limit(model: &str) -> Option { let canonical = resolve_model_alias(model); @@ -323,8 +331,9 @@ mod tests { }; use super::{ - detect_provider_kind, load_dotenv_file, max_tokens_for_model, model_token_limit, - parse_dotenv, preflight_message_request, resolve_model_alias, ProviderKind, + detect_provider_kind, load_dotenv_file, max_tokens_for_model, + max_tokens_for_model_with_override, model_token_limit, parse_dotenv, + preflight_message_request, resolve_model_alias, ProviderKind, }; #[test] @@ -349,6 +358,56 @@ mod tests { assert_eq!(max_tokens_for_model("grok-3"), 64_000); } + #[test] + fn plugin_config_max_output_tokens_overrides_model_default() { + // given + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + let root = std::env::temp_dir().join(format!("api-plugin-max-tokens-{nanos}")); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + std::fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); + std::fs::create_dir_all(&home).expect("home config dir"); + std::fs::write( + home.join("settings.json"), + r#"{ + "plugins": { + "maxOutputTokens": 12345 + } + }"#, + ) + .expect("write plugin settings"); + + // when + let loaded = runtime::ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + let plugin_override = loaded.plugins().max_output_tokens(); + let effective = max_tokens_for_model_with_override("claude-opus-4-6", plugin_override); + + // then + assert_eq!(plugin_override, Some(12345)); + assert_eq!(effective, 12345); + assert_ne!(effective, max_tokens_for_model("claude-opus-4-6")); + + std::fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn max_tokens_for_model_with_override_falls_back_when_plugin_unset() { + // given + let plugin_override: Option = None; + + // when + let effective = max_tokens_for_model_with_override("claude-opus-4-6", plugin_override); + + // then + assert_eq!(effective, max_tokens_for_model("claude-opus-4-6")); + assert_eq!(effective, 32_000); + } + #[test] fn returns_context_window_metadata_for_supported_models() { assert_eq!(