From d509f16b5a250efb6ee348de51a69de9e289e6a5 Mon Sep 17 00:00:00 2001 From: YeonGyu-Kim Date: Tue, 7 Apr 2026 14:51:12 +0900 Subject: [PATCH] =?UTF-8?q?feat:=20b5-skip-perms-flag=20=E2=80=94=20batch?= =?UTF-8?q?=205=20upstream=20parity?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rust/crates/runtime/src/config.rs | 120 +++++++++ rust/crates/runtime/src/lib.rs | 4 +- rust/crates/tools/src/lib.rs | 429 ++++++++++++++++++++++-------- 3 files changed, 444 insertions(+), 109 deletions(-) diff --git a/rust/crates/runtime/src/config.rs b/rust/crates/runtime/src/config.rs index 3120ac5..1159b54 100644 --- a/rust/crates/runtime/src/config.rs +++ b/rust/crates/runtime/src/config.rs @@ -61,6 +61,16 @@ pub struct RuntimeFeatureConfig { permission_mode: Option, permission_rules: RuntimePermissionRuleConfig, sandbox: SandboxConfig, + provider_fallbacks: ProviderFallbackConfig, +} + +/// Ordered chain of fallback model identifiers used when the primary +/// provider returns a retryable failure (429/500/503/etc.). The chain is +/// strict: each entry is tried in order until one succeeds. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ProviderFallbackConfig { + primary: Option, + fallbacks: Vec, } /// Hook command lists grouped by lifecycle stage. @@ -283,6 +293,7 @@ impl ConfigLoader { permission_mode: parse_optional_permission_mode(&merged_value)?, permission_rules: parse_optional_permission_rules(&merged_value)?, sandbox: parse_optional_sandbox_config(&merged_value)?, + provider_fallbacks: parse_optional_provider_fallbacks(&merged_value)?, }; Ok(RuntimeConfig { @@ -367,6 +378,11 @@ impl RuntimeConfig { pub fn sandbox(&self) -> &SandboxConfig { &self.feature_config.sandbox } + + #[must_use] + pub fn provider_fallbacks(&self) -> &ProviderFallbackConfig { + &self.feature_config.provider_fallbacks + } } impl RuntimeFeatureConfig { @@ -421,6 +437,33 @@ impl RuntimeFeatureConfig { pub fn sandbox(&self) -> &SandboxConfig { &self.sandbox } + + #[must_use] + pub fn provider_fallbacks(&self) -> &ProviderFallbackConfig { + &self.provider_fallbacks + } +} + +impl ProviderFallbackConfig { + #[must_use] + pub fn new(primary: Option, fallbacks: Vec) -> Self { + Self { primary, fallbacks } + } + + #[must_use] + pub fn primary(&self) -> Option<&str> { + self.primary.as_deref() + } + + #[must_use] + pub fn fallbacks(&self) -> &[String] { + &self.fallbacks + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.fallbacks.is_empty() + } } impl RuntimePluginConfig { @@ -776,6 +819,23 @@ fn parse_optional_sandbox_config(root: &JsonValue) -> Result Result { + let Some(object) = root.as_object() else { + return Ok(ProviderFallbackConfig::default()); + }; + let Some(value) = object.get("providerFallbacks") else { + return Ok(ProviderFallbackConfig::default()); + }; + let entry = expect_object(value, "merged settings.providerFallbacks")?; + let primary = + optional_string(entry, "primary", "merged settings.providerFallbacks")?.map(str::to_string); + let fallbacks = optional_string_array(entry, "fallbacks", "merged settings.providerFallbacks")? + .unwrap_or_default(); + Ok(ProviderFallbackConfig { primary, fallbacks }) +} + fn parse_filesystem_mode_label(value: &str) -> Result { match value { "off" => Ok(FilesystemIsolationMode::Off), @@ -1247,6 +1307,66 @@ mod tests { fs::remove_dir_all(root).expect("cleanup temp dir"); } + #[test] + fn parses_provider_fallbacks_chain_with_primary_and_ordered_fallbacks() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); + fs::create_dir_all(&home).expect("home config dir"); + fs::write( + home.join("settings.json"), + r#"{ + "providerFallbacks": { + "primary": "claude-opus-4-6", + "fallbacks": ["grok-3", "grok-3-mini"] + } + }"#, + ) + .expect("write provider fallback settings"); + + // when + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + // then + let chain = loaded.provider_fallbacks(); + assert_eq!(chain.primary(), Some("claude-opus-4-6")); + assert_eq!( + chain.fallbacks(), + &["grok-3".to_string(), "grok-3-mini".to_string()] + ); + assert!(!chain.is_empty()); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn provider_fallbacks_default_is_empty_when_unset() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write(home.join("settings.json"), "{}").expect("write empty settings"); + + // when + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + // then + let chain = loaded.provider_fallbacks(); + assert_eq!(chain.primary(), None); + assert!(chain.fallbacks().is_empty()); + assert!(chain.is_empty()); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + #[test] fn parses_typed_mcp_and_oauth_config() { let root = temp_dir(); diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index 4614a6c..ef56269 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -57,8 +57,8 @@ pub use config::{ ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpConfigCollection, McpManagedProxyServerConfig, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig, - ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, RuntimeHookConfig, - RuntimePermissionRuleConfig, RuntimePluginConfig, ScopedMcpServerConfig, + ProviderFallbackConfig, ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, + RuntimeHookConfig, RuntimePermissionRuleConfig, RuntimePluginConfig, ScopedMcpServerConfig, CLAW_SETTINGS_SCHEMA_NAME, }; pub use conversation::{ diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index 7de712f..c69e67b 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -4,8 +4,8 @@ use std::process::Command; use std::time::{Duration, Instant}; use api::{ - max_tokens_for_model, resolve_model_alias, ContentBlockDelta, InputContentBlock, InputMessage, - MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, + max_tokens_for_model, resolve_model_alias, ApiError, ContentBlockDelta, InputContentBlock, + InputMessage, MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; use plugins::PluginTool; @@ -22,10 +22,11 @@ use runtime::{ team_cron_registry::{CronRegistry, TeamRegistry}, worker_boot::{WorkerReadySnapshot, WorkerRegistry}, write_file, ApiClient, ApiRequest, AssistantEvent, BashCommandInput, BashCommandOutput, - BranchFreshness, ContentBlock, ConversationMessage, ConversationRuntime, GrepSearchInput, - LaneCommitProvenance, LaneEvent, LaneEventBlocker, LaneEventName, LaneEventStatus, - LaneFailureClass, McpDegradedReport, MessageRole, PermissionMode, PermissionPolicy, - PromptCacheEvent, RuntimeError, Session, TaskPacket, ToolError, ToolExecutor, + BranchFreshness, ConfigLoader, ContentBlock, ConversationMessage, ConversationRuntime, + GrepSearchInput, LaneCommitProvenance, LaneEvent, LaneEventBlocker, LaneEventName, + LaneEventStatus, LaneFailureClass, McpDegradedReport, MessageRole, PermissionMode, + PermissionPolicy, PromptCacheEvent, ProviderFallbackConfig, RuntimeError, Session, TaskPacket, + ToolError, ToolExecutor, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -3699,29 +3700,73 @@ fn classify_lane_failure(error: &str) -> LaneFailureClass { } } +struct ProviderEntry { + model: String, + client: ProviderClient, +} + struct ProviderRuntimeClient { runtime: tokio::runtime::Runtime, - client: ProviderClient, - model: String, + chain: Vec, allowed_tools: BTreeSet, } impl ProviderRuntimeClient { #[allow(clippy::needless_pass_by_value)] fn new(model: String, allowed_tools: BTreeSet) -> Result { - let model = resolve_model_alias(&model).clone(); - let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?; + let fallback_config = load_provider_fallback_config(); + Self::new_with_fallback_config(model, allowed_tools, &fallback_config) + } + + #[allow(clippy::needless_pass_by_value)] + fn new_with_fallback_config( + model: String, + allowed_tools: BTreeSet, + fallback_config: &ProviderFallbackConfig, + ) -> Result { + let primary_model = fallback_config + .primary() + .map(str::to_string) + .unwrap_or(model); + let primary = build_provider_entry(&primary_model)?; + let mut chain = vec![primary]; + for fallback_model in fallback_config.fallbacks() { + match build_provider_entry(fallback_model) { + Ok(entry) => chain.push(entry), + Err(error) => { + eprintln!( + "warning: skipping unavailable fallback provider {fallback_model}: {error}" + ); + } + } + } Ok(Self { runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, - client, - model, + chain, allowed_tools, }) } } +fn build_provider_entry(model: &str) -> Result { + let resolved = resolve_model_alias(model).clone(); + let client = ProviderClient::from_model(&resolved).map_err(|error| error.to_string())?; + Ok(ProviderEntry { + model: resolved, + client, + }) +} + +fn load_provider_fallback_config() -> ProviderFallbackConfig { + std::env::current_dir() + .ok() + .and_then(|cwd| ConfigLoader::default_for(cwd).load().ok()) + .map_or_else(ProviderFallbackConfig::default, |config| { + config.provider_fallbacks().clone() + }) +} + impl ApiClient for ProviderRuntimeClient { - #[allow(clippy::too_many_lines)] fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools)) .into_iter() @@ -3731,106 +3776,129 @@ impl ApiClient for ProviderRuntimeClient { input_schema: spec.input_schema, }) .collect::>(); - let message_request = MessageRequest { - model: self.model.clone(), - max_tokens: max_tokens_for_model(&self.model), - messages: convert_messages(&request.messages), - system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), - tools: (!tools.is_empty()).then_some(tools), - tool_choice: (!self.allowed_tools.is_empty()).then_some(ToolChoice::Auto), - stream: true, - }; + let messages = convert_messages(&request.messages); + let system = (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")); + let tool_choice = (!self.allowed_tools.is_empty()).then_some(ToolChoice::Auto); - self.runtime.block_on(async { - let mut stream = self - .client - .stream_message(&message_request) - .await - .map_err(|error| RuntimeError::new(error.to_string()))?; - let mut events = Vec::new(); - let mut pending_tools: BTreeMap = BTreeMap::new(); - let mut saw_stop = false; + let runtime = &self.runtime; + let chain = &self.chain; + let mut last_error: Option = None; + for (index, entry) in chain.iter().enumerate() { + let message_request = MessageRequest { + model: entry.model.clone(), + max_tokens: max_tokens_for_model(&entry.model), + messages: messages.clone(), + system: system.clone(), + tools: (!tools.is_empty()).then(|| tools.clone()), + tool_choice: tool_choice.clone(), + stream: true, + }; - while let Some(event) = stream - .next_event() - .await - .map_err(|error| RuntimeError::new(error.to_string()))? - { - match event { - ApiStreamEvent::MessageStart(start) => { - for block in start.message.content { - push_output_block(block, 0, &mut events, &mut pending_tools, true); - } - } - ApiStreamEvent::ContentBlockStart(start) => { - push_output_block( - start.content_block, - start.index, - &mut events, - &mut pending_tools, - true, - ); - } - ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta { - ContentBlockDelta::TextDelta { text } => { - if !text.is_empty() { - events.push(AssistantEvent::TextDelta(text)); - } - } - ContentBlockDelta::InputJsonDelta { partial_json } => { - if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { - input.push_str(&partial_json); - } - } - ContentBlockDelta::ThinkingDelta { .. } - | ContentBlockDelta::SignatureDelta { .. } => {} - }, - ApiStreamEvent::ContentBlockStop(stop) => { - if let Some((id, name, input)) = pending_tools.remove(&stop.index) { - events.push(AssistantEvent::ToolUse { id, name, input }); - } - } - ApiStreamEvent::MessageDelta(delta) => { - events.push(AssistantEvent::Usage(delta.usage.token_usage())); - } - ApiStreamEvent::MessageStop(_) => { - saw_stop = true; - events.push(AssistantEvent::MessageStop); - } + let attempt = runtime.block_on(stream_with_provider(&entry.client, &message_request)); + match attempt { + Ok(events) => return Ok(events), + Err(error) if error.is_retryable() && index + 1 < chain.len() => { + eprintln!( + "provider {} failed with retryable error, falling back: {error}", + entry.model + ); + last_error = Some(error); + continue; + } + Err(error) => return Err(RuntimeError::new(error.to_string())), + } + } + + Err(RuntimeError::new( + last_error + .map(|error| error.to_string()) + .unwrap_or_else(|| String::from("provider chain exhausted with no attempts")), + )) + } +} + +#[allow(clippy::too_many_lines)] +async fn stream_with_provider( + client: &ProviderClient, + message_request: &MessageRequest, +) -> Result, ApiError> { + let mut stream = client.stream_message(message_request).await?; + let mut events = Vec::new(); + let mut pending_tools: BTreeMap = BTreeMap::new(); + let mut saw_stop = false; + + while let Some(event) = stream.next_event().await? { + match event { + ApiStreamEvent::MessageStart(start) => { + for block in start.message.content { + push_output_block(block, 0, &mut events, &mut pending_tools, true); } } - - push_prompt_cache_record(&self.client, &mut events); - - if !saw_stop - && events.iter().any(|event| { - matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) - || matches!(event, AssistantEvent::ToolUse { .. }) - }) - { + ApiStreamEvent::ContentBlockStart(start) => { + push_output_block( + start.content_block, + start.index, + &mut events, + &mut pending_tools, + true, + ); + } + ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta { + ContentBlockDelta::TextDelta { text } => { + if !text.is_empty() { + events.push(AssistantEvent::TextDelta(text)); + } + } + ContentBlockDelta::InputJsonDelta { partial_json } => { + if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { + input.push_str(&partial_json); + } + } + ContentBlockDelta::ThinkingDelta { .. } + | ContentBlockDelta::SignatureDelta { .. } => {} + }, + ApiStreamEvent::ContentBlockStop(stop) => { + if let Some((id, name, input)) = pending_tools.remove(&stop.index) { + events.push(AssistantEvent::ToolUse { id, name, input }); + } + } + ApiStreamEvent::MessageDelta(delta) => { + events.push(AssistantEvent::Usage(delta.usage.token_usage())); + } + ApiStreamEvent::MessageStop(_) => { + saw_stop = true; events.push(AssistantEvent::MessageStop); } - - if events - .iter() - .any(|event| matches!(event, AssistantEvent::MessageStop)) - { - return Ok(events); - } - - let response = self - .client - .send_message(&MessageRequest { - stream: false, - ..message_request.clone() - }) - .await - .map_err(|error| RuntimeError::new(error.to_string()))?; - let mut events = response_to_events(response); - push_prompt_cache_record(&self.client, &mut events); - Ok(events) - }) + } } + + push_prompt_cache_record(client, &mut events); + + if !saw_stop + && events.iter().any(|event| { + matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) + || matches!(event, AssistantEvent::ToolUse { .. }) + }) + { + events.push(AssistantEvent::MessageStop); + } + + if events + .iter() + .any(|event| matches!(event, AssistantEvent::MessageStop)) + { + return Ok(events); + } + + let response = client + .send_message(&MessageRequest { + stream: false, + ..message_request.clone() + }) + .await?; + let mut events = response_to_events(response); + push_prompt_cache_record(client, &mut events); + Ok(events) } struct SubagentToolExecutor { @@ -5257,8 +5325,10 @@ mod tests { derive_agent_state, execute_agent_with_spawn, execute_tool, final_assistant_text, maybe_commit_provenance, mvp_tool_specs, permission_mode_from_plugin, persist_agent_terminal_state, push_output_block, run_task_packet, AgentInput, AgentJob, - GlobalToolRegistry, LaneEventName, LaneFailureClass, SubagentToolExecutor, + GlobalToolRegistry, LaneEventName, LaneFailureClass, ProviderRuntimeClient, + SubagentToolExecutor, }; + use runtime::ProviderFallbackConfig; use api::OutputContentBlock; use runtime::{ permission_enforcer::PermissionEnforcer, ApiRequest, AssistantEvent, ConversationRuntime, @@ -7769,6 +7839,151 @@ printf 'pwsh:%s' "$1" assert_eq!(output["stdout"], "ok"); } + #[test] + fn provider_runtime_client_chain_uses_only_primary_when_no_fallbacks_configured() { + // given + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY"); + std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key"); + let fallback_config = ProviderFallbackConfig::default(); + + // when + let client = ProviderRuntimeClient::new_with_fallback_config( + "claude-sonnet-4-6".to_string(), + BTreeSet::new(), + &fallback_config, + ) + .expect("primary-only chain should construct"); + + // then + assert_eq!(client.chain.len(), 1); + assert_eq!(client.chain[0].model, "claude-sonnet-4-6"); + + match original_anthropic { + Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value), + None => std::env::remove_var("ANTHROPIC_API_KEY"), + } + } + + #[test] + fn provider_runtime_client_chain_appends_configured_fallbacks_in_order() { + // given + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY"); + let original_xai = std::env::var_os("XAI_API_KEY"); + std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key"); + std::env::set_var("XAI_API_KEY", "xai-test-key"); + let fallback_config = ProviderFallbackConfig::new( + None, + vec!["grok-3".to_string(), "grok-3-mini".to_string()], + ); + + // when + let client = ProviderRuntimeClient::new_with_fallback_config( + "claude-sonnet-4-6".to_string(), + BTreeSet::new(), + &fallback_config, + ) + .expect("chain with fallbacks should construct"); + + // then + assert_eq!(client.chain.len(), 3); + assert_eq!(client.chain[0].model, "claude-sonnet-4-6"); + assert_eq!(client.chain[1].model, "grok-3"); + assert_eq!(client.chain[2].model, "grok-3-mini"); + + match original_anthropic { + Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value), + None => std::env::remove_var("ANTHROPIC_API_KEY"), + } + match original_xai { + Some(value) => std::env::set_var("XAI_API_KEY", value), + None => std::env::remove_var("XAI_API_KEY"), + } + } + + #[test] + fn provider_runtime_client_chain_primary_override_replaces_constructor_model() { + // given + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY"); + let original_xai = std::env::var_os("XAI_API_KEY"); + std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key"); + std::env::set_var("XAI_API_KEY", "xai-test-key"); + let fallback_config = ProviderFallbackConfig::new( + Some("grok-3".to_string()), + vec!["claude-sonnet-4-6".to_string()], + ); + + // when + let client = ProviderRuntimeClient::new_with_fallback_config( + "claude-haiku-4-5-20251213".to_string(), + BTreeSet::new(), + &fallback_config, + ) + .expect("chain with primary override should construct"); + + // then + assert_eq!(client.chain.len(), 2); + assert_eq!(client.chain[0].model, "grok-3"); + assert_eq!(client.chain[1].model, "claude-sonnet-4-6"); + + match original_anthropic { + Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value), + None => std::env::remove_var("ANTHROPIC_API_KEY"), + } + match original_xai { + Some(value) => std::env::set_var("XAI_API_KEY", value), + None => std::env::remove_var("XAI_API_KEY"), + } + } + + #[test] + fn provider_runtime_client_chain_skips_fallbacks_missing_credentials() { + // given + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY"); + let original_xai = std::env::var_os("XAI_API_KEY"); + std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key"); + std::env::remove_var("XAI_API_KEY"); + let fallback_config = ProviderFallbackConfig::new( + None, + vec![ + "grok-3".to_string(), + "claude-haiku-4-5-20251213".to_string(), + ], + ); + + // when + let client = ProviderRuntimeClient::new_with_fallback_config( + "claude-sonnet-4-6".to_string(), + BTreeSet::new(), + &fallback_config, + ) + .expect("chain construction should not fail when only some fallbacks are unavailable"); + + // then + assert_eq!(client.chain.len(), 2); + assert_eq!(client.chain[0].model, "claude-sonnet-4-6"); + assert_eq!(client.chain[1].model, "claude-haiku-4-5-20251213"); + + match original_anthropic { + Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value), + None => std::env::remove_var("ANTHROPIC_API_KEY"), + } + if let Some(value) = original_xai { + std::env::set_var("XAI_API_KEY", value); + } + } + #[test] fn run_task_packet_creates_packet_backed_task() { let result = run_task_packet(TaskPacket {