mirror of
https://github.com/instructkr/claw-code.git
synced 2026-04-07 16:44:50 +08:00
feat: b5-skip-perms-flag — batch 5 upstream parity
This commit is contained in:
@@ -61,6 +61,16 @@ pub struct RuntimeFeatureConfig {
|
||||
permission_mode: Option<ResolvedPermissionMode>,
|
||||
permission_rules: RuntimePermissionRuleConfig,
|
||||
sandbox: SandboxConfig,
|
||||
provider_fallbacks: ProviderFallbackConfig,
|
||||
}
|
||||
|
||||
/// Ordered chain of fallback model identifiers used when the primary
|
||||
/// provider returns a retryable failure (429/500/503/etc.). The chain is
|
||||
/// strict: each entry is tried in order until one succeeds.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct ProviderFallbackConfig {
|
||||
primary: Option<String>,
|
||||
fallbacks: Vec<String>,
|
||||
}
|
||||
|
||||
/// Hook command lists grouped by lifecycle stage.
|
||||
@@ -283,6 +293,7 @@ impl ConfigLoader {
|
||||
permission_mode: parse_optional_permission_mode(&merged_value)?,
|
||||
permission_rules: parse_optional_permission_rules(&merged_value)?,
|
||||
sandbox: parse_optional_sandbox_config(&merged_value)?,
|
||||
provider_fallbacks: parse_optional_provider_fallbacks(&merged_value)?,
|
||||
};
|
||||
|
||||
Ok(RuntimeConfig {
|
||||
@@ -367,6 +378,11 @@ impl RuntimeConfig {
|
||||
pub fn sandbox(&self) -> &SandboxConfig {
|
||||
&self.feature_config.sandbox
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn provider_fallbacks(&self) -> &ProviderFallbackConfig {
|
||||
&self.feature_config.provider_fallbacks
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimeFeatureConfig {
|
||||
@@ -421,6 +437,33 @@ impl RuntimeFeatureConfig {
|
||||
pub fn sandbox(&self) -> &SandboxConfig {
|
||||
&self.sandbox
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn provider_fallbacks(&self) -> &ProviderFallbackConfig {
|
||||
&self.provider_fallbacks
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderFallbackConfig {
|
||||
#[must_use]
|
||||
pub fn new(primary: Option<String>, fallbacks: Vec<String>) -> Self {
|
||||
Self { primary, fallbacks }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn primary(&self) -> Option<&str> {
|
||||
self.primary.as_deref()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn fallbacks(&self) -> &[String] {
|
||||
&self.fallbacks
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.fallbacks.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimePluginConfig {
|
||||
@@ -776,6 +819,23 @@ fn parse_optional_sandbox_config(root: &JsonValue) -> Result<SandboxConfig, Conf
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_optional_provider_fallbacks(
|
||||
root: &JsonValue,
|
||||
) -> Result<ProviderFallbackConfig, ConfigError> {
|
||||
let Some(object) = root.as_object() else {
|
||||
return Ok(ProviderFallbackConfig::default());
|
||||
};
|
||||
let Some(value) = object.get("providerFallbacks") else {
|
||||
return Ok(ProviderFallbackConfig::default());
|
||||
};
|
||||
let entry = expect_object(value, "merged settings.providerFallbacks")?;
|
||||
let primary =
|
||||
optional_string(entry, "primary", "merged settings.providerFallbacks")?.map(str::to_string);
|
||||
let fallbacks = optional_string_array(entry, "fallbacks", "merged settings.providerFallbacks")?
|
||||
.unwrap_or_default();
|
||||
Ok(ProviderFallbackConfig { primary, fallbacks })
|
||||
}
|
||||
|
||||
fn parse_filesystem_mode_label(value: &str) -> Result<FilesystemIsolationMode, ConfigError> {
|
||||
match value {
|
||||
"off" => Ok(FilesystemIsolationMode::Off),
|
||||
@@ -1247,6 +1307,66 @@ mod tests {
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_provider_fallbacks_chain_with_primary_and_ordered_fallbacks() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
let cwd = root.join("project");
|
||||
let home = root.join("home").join(".claw");
|
||||
fs::create_dir_all(cwd.join(".claw")).expect("project config dir");
|
||||
fs::create_dir_all(&home).expect("home config dir");
|
||||
fs::write(
|
||||
home.join("settings.json"),
|
||||
r#"{
|
||||
"providerFallbacks": {
|
||||
"primary": "claude-opus-4-6",
|
||||
"fallbacks": ["grok-3", "grok-3-mini"]
|
||||
}
|
||||
}"#,
|
||||
)
|
||||
.expect("write provider fallback settings");
|
||||
|
||||
// when
|
||||
let loaded = ConfigLoader::new(&cwd, &home)
|
||||
.load()
|
||||
.expect("config should load");
|
||||
|
||||
// then
|
||||
let chain = loaded.provider_fallbacks();
|
||||
assert_eq!(chain.primary(), Some("claude-opus-4-6"));
|
||||
assert_eq!(
|
||||
chain.fallbacks(),
|
||||
&["grok-3".to_string(), "grok-3-mini".to_string()]
|
||||
);
|
||||
assert!(!chain.is_empty());
|
||||
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_fallbacks_default_is_empty_when_unset() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
let cwd = root.join("project");
|
||||
let home = root.join("home").join(".claw");
|
||||
fs::create_dir_all(&home).expect("home config dir");
|
||||
fs::create_dir_all(&cwd).expect("project dir");
|
||||
fs::write(home.join("settings.json"), "{}").expect("write empty settings");
|
||||
|
||||
// when
|
||||
let loaded = ConfigLoader::new(&cwd, &home)
|
||||
.load()
|
||||
.expect("config should load");
|
||||
|
||||
// then
|
||||
let chain = loaded.provider_fallbacks();
|
||||
assert_eq!(chain.primary(), None);
|
||||
assert!(chain.fallbacks().is_empty());
|
||||
assert!(chain.is_empty());
|
||||
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_typed_mcp_and_oauth_config() {
|
||||
let root = temp_dir();
|
||||
|
||||
@@ -57,8 +57,8 @@ pub use config::{
|
||||
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpConfigCollection,
|
||||
McpManagedProxyServerConfig, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
|
||||
McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
|
||||
ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, RuntimeHookConfig,
|
||||
RuntimePermissionRuleConfig, RuntimePluginConfig, ScopedMcpServerConfig,
|
||||
ProviderFallbackConfig, ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig,
|
||||
RuntimeHookConfig, RuntimePermissionRuleConfig, RuntimePluginConfig, ScopedMcpServerConfig,
|
||||
CLAW_SETTINGS_SCHEMA_NAME,
|
||||
};
|
||||
pub use conversation::{
|
||||
|
||||
@@ -4,8 +4,8 @@ use std::process::Command;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use api::{
|
||||
max_tokens_for_model, resolve_model_alias, ContentBlockDelta, InputContentBlock, InputMessage,
|
||||
MessageRequest, MessageResponse, OutputContentBlock, ProviderClient,
|
||||
max_tokens_for_model, resolve_model_alias, ApiError, ContentBlockDelta, InputContentBlock,
|
||||
InputMessage, MessageRequest, MessageResponse, OutputContentBlock, ProviderClient,
|
||||
StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
|
||||
};
|
||||
use plugins::PluginTool;
|
||||
@@ -22,10 +22,11 @@ use runtime::{
|
||||
team_cron_registry::{CronRegistry, TeamRegistry},
|
||||
worker_boot::{WorkerReadySnapshot, WorkerRegistry},
|
||||
write_file, ApiClient, ApiRequest, AssistantEvent, BashCommandInput, BashCommandOutput,
|
||||
BranchFreshness, ContentBlock, ConversationMessage, ConversationRuntime, GrepSearchInput,
|
||||
LaneCommitProvenance, LaneEvent, LaneEventBlocker, LaneEventName, LaneEventStatus,
|
||||
LaneFailureClass, McpDegradedReport, MessageRole, PermissionMode, PermissionPolicy,
|
||||
PromptCacheEvent, RuntimeError, Session, TaskPacket, ToolError, ToolExecutor,
|
||||
BranchFreshness, ConfigLoader, ContentBlock, ConversationMessage, ConversationRuntime,
|
||||
GrepSearchInput, LaneCommitProvenance, LaneEvent, LaneEventBlocker, LaneEventName,
|
||||
LaneEventStatus, LaneFailureClass, McpDegradedReport, MessageRole, PermissionMode,
|
||||
PermissionPolicy, PromptCacheEvent, ProviderFallbackConfig, RuntimeError, Session, TaskPacket,
|
||||
ToolError, ToolExecutor,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
@@ -3699,29 +3700,73 @@ fn classify_lane_failure(error: &str) -> LaneFailureClass {
|
||||
}
|
||||
}
|
||||
|
||||
struct ProviderEntry {
|
||||
model: String,
|
||||
client: ProviderClient,
|
||||
}
|
||||
|
||||
struct ProviderRuntimeClient {
|
||||
runtime: tokio::runtime::Runtime,
|
||||
client: ProviderClient,
|
||||
model: String,
|
||||
chain: Vec<ProviderEntry>,
|
||||
allowed_tools: BTreeSet<String>,
|
||||
}
|
||||
|
||||
impl ProviderRuntimeClient {
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> {
|
||||
let model = resolve_model_alias(&model).clone();
|
||||
let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?;
|
||||
let fallback_config = load_provider_fallback_config();
|
||||
Self::new_with_fallback_config(model, allowed_tools, &fallback_config)
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
fn new_with_fallback_config(
|
||||
model: String,
|
||||
allowed_tools: BTreeSet<String>,
|
||||
fallback_config: &ProviderFallbackConfig,
|
||||
) -> Result<Self, String> {
|
||||
let primary_model = fallback_config
|
||||
.primary()
|
||||
.map(str::to_string)
|
||||
.unwrap_or(model);
|
||||
let primary = build_provider_entry(&primary_model)?;
|
||||
let mut chain = vec![primary];
|
||||
for fallback_model in fallback_config.fallbacks() {
|
||||
match build_provider_entry(fallback_model) {
|
||||
Ok(entry) => chain.push(entry),
|
||||
Err(error) => {
|
||||
eprintln!(
|
||||
"warning: skipping unavailable fallback provider {fallback_model}: {error}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Self {
|
||||
runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
|
||||
client,
|
||||
model,
|
||||
chain,
|
||||
allowed_tools,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn build_provider_entry(model: &str) -> Result<ProviderEntry, String> {
|
||||
let resolved = resolve_model_alias(model).clone();
|
||||
let client = ProviderClient::from_model(&resolved).map_err(|error| error.to_string())?;
|
||||
Ok(ProviderEntry {
|
||||
model: resolved,
|
||||
client,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_provider_fallback_config() -> ProviderFallbackConfig {
|
||||
std::env::current_dir()
|
||||
.ok()
|
||||
.and_then(|cwd| ConfigLoader::default_for(cwd).load().ok())
|
||||
.map_or_else(ProviderFallbackConfig::default, |config| {
|
||||
config.provider_fallbacks().clone()
|
||||
})
|
||||
}
|
||||
|
||||
impl ApiClient for ProviderRuntimeClient {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
||||
let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools))
|
||||
.into_iter()
|
||||
@@ -3731,106 +3776,129 @@ impl ApiClient for ProviderRuntimeClient {
|
||||
input_schema: spec.input_schema,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let message_request = MessageRequest {
|
||||
model: self.model.clone(),
|
||||
max_tokens: max_tokens_for_model(&self.model),
|
||||
messages: convert_messages(&request.messages),
|
||||
system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")),
|
||||
tools: (!tools.is_empty()).then_some(tools),
|
||||
tool_choice: (!self.allowed_tools.is_empty()).then_some(ToolChoice::Auto),
|
||||
stream: true,
|
||||
};
|
||||
let messages = convert_messages(&request.messages);
|
||||
let system = (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n"));
|
||||
let tool_choice = (!self.allowed_tools.is_empty()).then_some(ToolChoice::Auto);
|
||||
|
||||
self.runtime.block_on(async {
|
||||
let mut stream = self
|
||||
.client
|
||||
.stream_message(&message_request)
|
||||
.await
|
||||
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
||||
let mut events = Vec::new();
|
||||
let mut pending_tools: BTreeMap<u32, (String, String, String)> = BTreeMap::new();
|
||||
let mut saw_stop = false;
|
||||
let runtime = &self.runtime;
|
||||
let chain = &self.chain;
|
||||
let mut last_error: Option<ApiError> = None;
|
||||
for (index, entry) in chain.iter().enumerate() {
|
||||
let message_request = MessageRequest {
|
||||
model: entry.model.clone(),
|
||||
max_tokens: max_tokens_for_model(&entry.model),
|
||||
messages: messages.clone(),
|
||||
system: system.clone(),
|
||||
tools: (!tools.is_empty()).then(|| tools.clone()),
|
||||
tool_choice: tool_choice.clone(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
while let Some(event) = stream
|
||||
.next_event()
|
||||
.await
|
||||
.map_err(|error| RuntimeError::new(error.to_string()))?
|
||||
{
|
||||
match event {
|
||||
ApiStreamEvent::MessageStart(start) => {
|
||||
for block in start.message.content {
|
||||
push_output_block(block, 0, &mut events, &mut pending_tools, true);
|
||||
}
|
||||
}
|
||||
ApiStreamEvent::ContentBlockStart(start) => {
|
||||
push_output_block(
|
||||
start.content_block,
|
||||
start.index,
|
||||
&mut events,
|
||||
&mut pending_tools,
|
||||
true,
|
||||
);
|
||||
}
|
||||
ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta {
|
||||
ContentBlockDelta::TextDelta { text } => {
|
||||
if !text.is_empty() {
|
||||
events.push(AssistantEvent::TextDelta(text));
|
||||
}
|
||||
}
|
||||
ContentBlockDelta::InputJsonDelta { partial_json } => {
|
||||
if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) {
|
||||
input.push_str(&partial_json);
|
||||
}
|
||||
}
|
||||
ContentBlockDelta::ThinkingDelta { .. }
|
||||
| ContentBlockDelta::SignatureDelta { .. } => {}
|
||||
},
|
||||
ApiStreamEvent::ContentBlockStop(stop) => {
|
||||
if let Some((id, name, input)) = pending_tools.remove(&stop.index) {
|
||||
events.push(AssistantEvent::ToolUse { id, name, input });
|
||||
}
|
||||
}
|
||||
ApiStreamEvent::MessageDelta(delta) => {
|
||||
events.push(AssistantEvent::Usage(delta.usage.token_usage()));
|
||||
}
|
||||
ApiStreamEvent::MessageStop(_) => {
|
||||
saw_stop = true;
|
||||
events.push(AssistantEvent::MessageStop);
|
||||
}
|
||||
let attempt = runtime.block_on(stream_with_provider(&entry.client, &message_request));
|
||||
match attempt {
|
||||
Ok(events) => return Ok(events),
|
||||
Err(error) if error.is_retryable() && index + 1 < chain.len() => {
|
||||
eprintln!(
|
||||
"provider {} failed with retryable error, falling back: {error}",
|
||||
entry.model
|
||||
);
|
||||
last_error = Some(error);
|
||||
continue;
|
||||
}
|
||||
Err(error) => return Err(RuntimeError::new(error.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
Err(RuntimeError::new(
|
||||
last_error
|
||||
.map(|error| error.to_string())
|
||||
.unwrap_or_else(|| String::from("provider chain exhausted with no attempts")),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn stream_with_provider(
|
||||
client: &ProviderClient,
|
||||
message_request: &MessageRequest,
|
||||
) -> Result<Vec<AssistantEvent>, ApiError> {
|
||||
let mut stream = client.stream_message(message_request).await?;
|
||||
let mut events = Vec::new();
|
||||
let mut pending_tools: BTreeMap<u32, (String, String, String)> = BTreeMap::new();
|
||||
let mut saw_stop = false;
|
||||
|
||||
while let Some(event) = stream.next_event().await? {
|
||||
match event {
|
||||
ApiStreamEvent::MessageStart(start) => {
|
||||
for block in start.message.content {
|
||||
push_output_block(block, 0, &mut events, &mut pending_tools, true);
|
||||
}
|
||||
}
|
||||
|
||||
push_prompt_cache_record(&self.client, &mut events);
|
||||
|
||||
if !saw_stop
|
||||
&& events.iter().any(|event| {
|
||||
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
|
||||
|| matches!(event, AssistantEvent::ToolUse { .. })
|
||||
})
|
||||
{
|
||||
ApiStreamEvent::ContentBlockStart(start) => {
|
||||
push_output_block(
|
||||
start.content_block,
|
||||
start.index,
|
||||
&mut events,
|
||||
&mut pending_tools,
|
||||
true,
|
||||
);
|
||||
}
|
||||
ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta {
|
||||
ContentBlockDelta::TextDelta { text } => {
|
||||
if !text.is_empty() {
|
||||
events.push(AssistantEvent::TextDelta(text));
|
||||
}
|
||||
}
|
||||
ContentBlockDelta::InputJsonDelta { partial_json } => {
|
||||
if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) {
|
||||
input.push_str(&partial_json);
|
||||
}
|
||||
}
|
||||
ContentBlockDelta::ThinkingDelta { .. }
|
||||
| ContentBlockDelta::SignatureDelta { .. } => {}
|
||||
},
|
||||
ApiStreamEvent::ContentBlockStop(stop) => {
|
||||
if let Some((id, name, input)) = pending_tools.remove(&stop.index) {
|
||||
events.push(AssistantEvent::ToolUse { id, name, input });
|
||||
}
|
||||
}
|
||||
ApiStreamEvent::MessageDelta(delta) => {
|
||||
events.push(AssistantEvent::Usage(delta.usage.token_usage()));
|
||||
}
|
||||
ApiStreamEvent::MessageStop(_) => {
|
||||
saw_stop = true;
|
||||
events.push(AssistantEvent::MessageStop);
|
||||
}
|
||||
|
||||
if events
|
||||
.iter()
|
||||
.any(|event| matches!(event, AssistantEvent::MessageStop))
|
||||
{
|
||||
return Ok(events);
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.send_message(&MessageRequest {
|
||||
stream: false,
|
||||
..message_request.clone()
|
||||
})
|
||||
.await
|
||||
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
||||
let mut events = response_to_events(response);
|
||||
push_prompt_cache_record(&self.client, &mut events);
|
||||
Ok(events)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
push_prompt_cache_record(client, &mut events);
|
||||
|
||||
if !saw_stop
|
||||
&& events.iter().any(|event| {
|
||||
matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty())
|
||||
|| matches!(event, AssistantEvent::ToolUse { .. })
|
||||
})
|
||||
{
|
||||
events.push(AssistantEvent::MessageStop);
|
||||
}
|
||||
|
||||
if events
|
||||
.iter()
|
||||
.any(|event| matches!(event, AssistantEvent::MessageStop))
|
||||
{
|
||||
return Ok(events);
|
||||
}
|
||||
|
||||
let response = client
|
||||
.send_message(&MessageRequest {
|
||||
stream: false,
|
||||
..message_request.clone()
|
||||
})
|
||||
.await?;
|
||||
let mut events = response_to_events(response);
|
||||
push_prompt_cache_record(client, &mut events);
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
struct SubagentToolExecutor {
|
||||
@@ -5257,8 +5325,10 @@ mod tests {
|
||||
derive_agent_state, execute_agent_with_spawn, execute_tool, final_assistant_text,
|
||||
maybe_commit_provenance, mvp_tool_specs, permission_mode_from_plugin,
|
||||
persist_agent_terminal_state, push_output_block, run_task_packet, AgentInput, AgentJob,
|
||||
GlobalToolRegistry, LaneEventName, LaneFailureClass, SubagentToolExecutor,
|
||||
GlobalToolRegistry, LaneEventName, LaneFailureClass, ProviderRuntimeClient,
|
||||
SubagentToolExecutor,
|
||||
};
|
||||
use runtime::ProviderFallbackConfig;
|
||||
use api::OutputContentBlock;
|
||||
use runtime::{
|
||||
permission_enforcer::PermissionEnforcer, ApiRequest, AssistantEvent, ConversationRuntime,
|
||||
@@ -7769,6 +7839,151 @@ printf 'pwsh:%s' "$1"
|
||||
assert_eq!(output["stdout"], "ok");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_runtime_client_chain_uses_only_primary_when_no_fallbacks_configured() {
|
||||
// given
|
||||
let _guard = env_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY");
|
||||
std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key");
|
||||
let fallback_config = ProviderFallbackConfig::default();
|
||||
|
||||
// when
|
||||
let client = ProviderRuntimeClient::new_with_fallback_config(
|
||||
"claude-sonnet-4-6".to_string(),
|
||||
BTreeSet::new(),
|
||||
&fallback_config,
|
||||
)
|
||||
.expect("primary-only chain should construct");
|
||||
|
||||
// then
|
||||
assert_eq!(client.chain.len(), 1);
|
||||
assert_eq!(client.chain[0].model, "claude-sonnet-4-6");
|
||||
|
||||
match original_anthropic {
|
||||
Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value),
|
||||
None => std::env::remove_var("ANTHROPIC_API_KEY"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_runtime_client_chain_appends_configured_fallbacks_in_order() {
|
||||
// given
|
||||
let _guard = env_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY");
|
||||
let original_xai = std::env::var_os("XAI_API_KEY");
|
||||
std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key");
|
||||
std::env::set_var("XAI_API_KEY", "xai-test-key");
|
||||
let fallback_config = ProviderFallbackConfig::new(
|
||||
None,
|
||||
vec!["grok-3".to_string(), "grok-3-mini".to_string()],
|
||||
);
|
||||
|
||||
// when
|
||||
let client = ProviderRuntimeClient::new_with_fallback_config(
|
||||
"claude-sonnet-4-6".to_string(),
|
||||
BTreeSet::new(),
|
||||
&fallback_config,
|
||||
)
|
||||
.expect("chain with fallbacks should construct");
|
||||
|
||||
// then
|
||||
assert_eq!(client.chain.len(), 3);
|
||||
assert_eq!(client.chain[0].model, "claude-sonnet-4-6");
|
||||
assert_eq!(client.chain[1].model, "grok-3");
|
||||
assert_eq!(client.chain[2].model, "grok-3-mini");
|
||||
|
||||
match original_anthropic {
|
||||
Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value),
|
||||
None => std::env::remove_var("ANTHROPIC_API_KEY"),
|
||||
}
|
||||
match original_xai {
|
||||
Some(value) => std::env::set_var("XAI_API_KEY", value),
|
||||
None => std::env::remove_var("XAI_API_KEY"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_runtime_client_chain_primary_override_replaces_constructor_model() {
|
||||
// given
|
||||
let _guard = env_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY");
|
||||
let original_xai = std::env::var_os("XAI_API_KEY");
|
||||
std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key");
|
||||
std::env::set_var("XAI_API_KEY", "xai-test-key");
|
||||
let fallback_config = ProviderFallbackConfig::new(
|
||||
Some("grok-3".to_string()),
|
||||
vec!["claude-sonnet-4-6".to_string()],
|
||||
);
|
||||
|
||||
// when
|
||||
let client = ProviderRuntimeClient::new_with_fallback_config(
|
||||
"claude-haiku-4-5-20251213".to_string(),
|
||||
BTreeSet::new(),
|
||||
&fallback_config,
|
||||
)
|
||||
.expect("chain with primary override should construct");
|
||||
|
||||
// then
|
||||
assert_eq!(client.chain.len(), 2);
|
||||
assert_eq!(client.chain[0].model, "grok-3");
|
||||
assert_eq!(client.chain[1].model, "claude-sonnet-4-6");
|
||||
|
||||
match original_anthropic {
|
||||
Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value),
|
||||
None => std::env::remove_var("ANTHROPIC_API_KEY"),
|
||||
}
|
||||
match original_xai {
|
||||
Some(value) => std::env::set_var("XAI_API_KEY", value),
|
||||
None => std::env::remove_var("XAI_API_KEY"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_runtime_client_chain_skips_fallbacks_missing_credentials() {
|
||||
// given
|
||||
let _guard = env_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY");
|
||||
let original_xai = std::env::var_os("XAI_API_KEY");
|
||||
std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key");
|
||||
std::env::remove_var("XAI_API_KEY");
|
||||
let fallback_config = ProviderFallbackConfig::new(
|
||||
None,
|
||||
vec![
|
||||
"grok-3".to_string(),
|
||||
"claude-haiku-4-5-20251213".to_string(),
|
||||
],
|
||||
);
|
||||
|
||||
// when
|
||||
let client = ProviderRuntimeClient::new_with_fallback_config(
|
||||
"claude-sonnet-4-6".to_string(),
|
||||
BTreeSet::new(),
|
||||
&fallback_config,
|
||||
)
|
||||
.expect("chain construction should not fail when only some fallbacks are unavailable");
|
||||
|
||||
// then
|
||||
assert_eq!(client.chain.len(), 2);
|
||||
assert_eq!(client.chain[0].model, "claude-sonnet-4-6");
|
||||
assert_eq!(client.chain[1].model, "claude-haiku-4-5-20251213");
|
||||
|
||||
match original_anthropic {
|
||||
Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value),
|
||||
None => std::env::remove_var("ANTHROPIC_API_KEY"),
|
||||
}
|
||||
if let Some(value) = original_xai {
|
||||
std::env::set_var("XAI_API_KEY", value);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_task_packet_creates_packet_backed_task() {
|
||||
let result = run_task_packet(TaskPacket {
|
||||
|
||||
Reference in New Issue
Block a user