mirror of
https://github.com/instructkr/claw-code.git
synced 2026-04-06 08:04:50 +08:00
The runtime already tracked rough token estimates for compaction, but provider-bound requests still relied on naive model output limits and could be sent upstream even when the selected model could not fit the estimated prompt plus requested output. This adds a small model token/context registry in the API layer, estimates request size from the serialized prompt payload, and fails locally with a dedicated context-window error before Anthropic or xAI calls are made. Focused integration coverage asserts the preflight fires before any HTTP request leaves the process. Constraint: Keep the first pass minimal and reusable across both Anthropic and OpenAI-compatible providers Rejected: Auto-compact-and-retry in the same patch | broader control-flow change than the requested minimal preflight Confidence: medium Scope-risk: narrow Reversibility: clean Directive: Expand the model registry before enabling preflight for additional providers or aliases Tested: cargo build -p api -p tools -p rusty-claude-cli; cargo test -p api Not-tested: End-to-end CLI auto-compaction or retry behavior after a local context_window_blocked failure
378 lines
12 KiB
Rust
378 lines
12 KiB
Rust
use std::future::Future;
|
|
use std::pin::Pin;
|
|
|
|
use serde::Serialize;
|
|
|
|
use crate::error::ApiError;
|
|
use crate::types::{MessageRequest, MessageResponse};
|
|
|
|
pub mod anthropic;
|
|
pub mod openai_compat;
|
|
|
|
/// Boxed, `Send` future returned by provider trait methods.
///
/// Erases the concrete future type so `Provider` stays object-safe; the `'a`
/// lifetime ties the future to the borrowed request it captures.
#[allow(dead_code)]
pub type ProviderFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, ApiError>> + Send + 'a>>;
|
|
|
|
/// Common interface for upstream LLM backends (Anthropic, xAI, OpenAI-compatible).
///
/// Implementors expose one-shot and streaming message calls; both borrow the
/// request for the duration of the returned future.
#[allow(dead_code)]
pub trait Provider {
    /// Provider-specific stream type yielded by [`Provider::stream_message`].
    type Stream;

    /// Sends a single request and resolves to the complete response.
    fn send_message<'a>(
        &'a self,
        request: &'a MessageRequest,
    ) -> ProviderFuture<'a, MessageResponse>;

    /// Sends a request and resolves to a provider-specific event stream.
    fn stream_message<'a>(
        &'a self,
        request: &'a MessageRequest,
    ) -> ProviderFuture<'a, Self::Stream>;
}
|
|
|
|
/// Which upstream backend a model routes to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Anthropic Messages API (`claude-*` models).
    Anthropic,
    /// xAI's OpenAI-compatible API (`grok-*` models).
    Xai,
    /// Generic OpenAI-compatible endpoint.
    OpenAi,
}
|
|
|
|
/// Static routing/configuration facts for a provider.
///
/// All fields are `'static` string constants so the whole struct is `Copy`
/// and can live in the const [`MODEL_REGISTRY`] table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderMetadata {
    /// Backend this model belongs to.
    pub provider: ProviderKind,
    /// Environment variable holding the API key (e.g. `ANTHROPIC_API_KEY`).
    pub auth_env: &'static str,
    /// Environment variable that may override the base URL.
    pub base_url_env: &'static str,
    /// Base URL used when the override variable is unset.
    pub default_base_url: &'static str,
}
|
|
|
|
/// Token budget for a known model, used by the local context-window preflight.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelTokenLimit {
    /// Maximum tokens the model may emit in one response.
    pub max_output_tokens: u32,
    /// Total context window (input + output) the model accepts.
    pub context_window_tokens: u32,
}
|
|
|
|
/// Alias → provider-metadata table, keyed by *lowercase* alias.
///
/// Lookups happen after `to_ascii_lowercase()` in [`resolve_model_alias`], so
/// every key here must already be lowercase. Expand this table before enabling
/// preflight for additional providers or aliases.
const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
    // Anthropic short aliases; canonical `claude-*` names are resolved in
    // `resolve_model_alias`.
    (
        "opus",
        ProviderMetadata {
            provider: ProviderKind::Anthropic,
            auth_env: "ANTHROPIC_API_KEY",
            base_url_env: "ANTHROPIC_BASE_URL",
            default_base_url: anthropic::DEFAULT_BASE_URL,
        },
    ),
    (
        "sonnet",
        ProviderMetadata {
            provider: ProviderKind::Anthropic,
            auth_env: "ANTHROPIC_API_KEY",
            base_url_env: "ANTHROPIC_BASE_URL",
            default_base_url: anthropic::DEFAULT_BASE_URL,
        },
    ),
    (
        "haiku",
        ProviderMetadata {
            provider: ProviderKind::Anthropic,
            auth_env: "ANTHROPIC_API_KEY",
            base_url_env: "ANTHROPIC_BASE_URL",
            default_base_url: anthropic::DEFAULT_BASE_URL,
        },
    ),
    // xAI aliases; both the short and versioned spellings are accepted and
    // collapse to the versioned name during alias resolution.
    (
        "grok",
        ProviderMetadata {
            provider: ProviderKind::Xai,
            auth_env: "XAI_API_KEY",
            base_url_env: "XAI_BASE_URL",
            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
        },
    ),
    (
        "grok-3",
        ProviderMetadata {
            provider: ProviderKind::Xai,
            auth_env: "XAI_API_KEY",
            base_url_env: "XAI_BASE_URL",
            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
        },
    ),
    (
        "grok-mini",
        ProviderMetadata {
            provider: ProviderKind::Xai,
            auth_env: "XAI_API_KEY",
            base_url_env: "XAI_BASE_URL",
            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
        },
    ),
    (
        "grok-3-mini",
        ProviderMetadata {
            provider: ProviderKind::Xai,
            auth_env: "XAI_API_KEY",
            base_url_env: "XAI_BASE_URL",
            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
        },
    ),
    (
        "grok-2",
        ProviderMetadata {
            provider: ProviderKind::Xai,
            auth_env: "XAI_API_KEY",
            base_url_env: "XAI_BASE_URL",
            default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
        },
    ),
];
|
|
|
|
#[must_use]
|
|
pub fn resolve_model_alias(model: &str) -> String {
|
|
let trimmed = model.trim();
|
|
let lower = trimmed.to_ascii_lowercase();
|
|
MODEL_REGISTRY
|
|
.iter()
|
|
.find_map(|(alias, metadata)| {
|
|
(*alias == lower).then_some(match metadata.provider {
|
|
ProviderKind::Anthropic => match *alias {
|
|
"opus" => "claude-opus-4-6",
|
|
"sonnet" => "claude-sonnet-4-6",
|
|
"haiku" => "claude-haiku-4-5-20251213",
|
|
_ => trimmed,
|
|
},
|
|
ProviderKind::Xai => match *alias {
|
|
"grok" | "grok-3" => "grok-3",
|
|
"grok-mini" | "grok-3-mini" => "grok-3-mini",
|
|
"grok-2" => "grok-2",
|
|
_ => trimmed,
|
|
},
|
|
ProviderKind::OpenAi => trimmed,
|
|
})
|
|
})
|
|
.map_or_else(|| trimmed.to_string(), ToOwned::to_owned)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
|
|
let canonical = resolve_model_alias(model);
|
|
if canonical.starts_with("claude") {
|
|
return Some(ProviderMetadata {
|
|
provider: ProviderKind::Anthropic,
|
|
auth_env: "ANTHROPIC_API_KEY",
|
|
base_url_env: "ANTHROPIC_BASE_URL",
|
|
default_base_url: anthropic::DEFAULT_BASE_URL,
|
|
});
|
|
}
|
|
if canonical.starts_with("grok") {
|
|
return Some(ProviderMetadata {
|
|
provider: ProviderKind::Xai,
|
|
auth_env: "XAI_API_KEY",
|
|
base_url_env: "XAI_BASE_URL",
|
|
default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
|
|
});
|
|
}
|
|
None
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn detect_provider_kind(model: &str) -> ProviderKind {
|
|
if let Some(metadata) = metadata_for_model(model) {
|
|
return metadata.provider;
|
|
}
|
|
if anthropic::has_auth_from_env_or_saved().unwrap_or(false) {
|
|
return ProviderKind::Anthropic;
|
|
}
|
|
if openai_compat::has_api_key("OPENAI_API_KEY") {
|
|
return ProviderKind::OpenAi;
|
|
}
|
|
if openai_compat::has_api_key("XAI_API_KEY") {
|
|
return ProviderKind::Xai;
|
|
}
|
|
ProviderKind::Anthropic
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn max_tokens_for_model(model: &str) -> u32 {
|
|
model_token_limit(model).map_or_else(
|
|
|| {
|
|
let canonical = resolve_model_alias(model);
|
|
if canonical.contains("opus") {
|
|
32_000
|
|
} else {
|
|
64_000
|
|
}
|
|
},
|
|
|limit| limit.max_output_tokens,
|
|
)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn model_token_limit(model: &str) -> Option<ModelTokenLimit> {
|
|
let canonical = resolve_model_alias(model);
|
|
match canonical.as_str() {
|
|
"claude-opus-4-6" => Some(ModelTokenLimit {
|
|
max_output_tokens: 32_000,
|
|
context_window_tokens: 200_000,
|
|
}),
|
|
"claude-sonnet-4-6" | "claude-haiku-4-5-20251213" => Some(ModelTokenLimit {
|
|
max_output_tokens: 64_000,
|
|
context_window_tokens: 200_000,
|
|
}),
|
|
"grok-3" | "grok-3-mini" => Some(ModelTokenLimit {
|
|
max_output_tokens: 64_000,
|
|
context_window_tokens: 131_072,
|
|
}),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
pub fn preflight_message_request(request: &MessageRequest) -> Result<(), ApiError> {
|
|
let Some(limit) = model_token_limit(&request.model) else {
|
|
return Ok(());
|
|
};
|
|
|
|
let estimated_input_tokens = estimate_message_request_input_tokens(request);
|
|
let estimated_total_tokens = estimated_input_tokens.saturating_add(request.max_tokens);
|
|
if estimated_total_tokens > limit.context_window_tokens {
|
|
return Err(ApiError::ContextWindowExceeded {
|
|
model: resolve_model_alias(&request.model),
|
|
estimated_input_tokens,
|
|
requested_output_tokens: request.max_tokens,
|
|
estimated_total_tokens,
|
|
context_window_tokens: limit.context_window_tokens,
|
|
});
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn estimate_message_request_input_tokens(request: &MessageRequest) -> u32 {
|
|
let mut estimate = estimate_serialized_tokens(&request.messages);
|
|
estimate = estimate.saturating_add(estimate_serialized_tokens(&request.system));
|
|
estimate = estimate.saturating_add(estimate_serialized_tokens(&request.tools));
|
|
estimate = estimate.saturating_add(estimate_serialized_tokens(&request.tool_choice));
|
|
estimate
|
|
}
|
|
|
|
fn estimate_serialized_tokens<T: Serialize>(value: &T) -> u32 {
|
|
serde_json::to_vec(value)
|
|
.ok()
|
|
.map_or(0, |bytes| (bytes.len() / 4 + 1) as u32)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::error::ApiError;
    use crate::types::{
        InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition,
    };

    use super::{
        detect_provider_kind, max_tokens_for_model, model_token_limit, preflight_message_request,
        resolve_model_alias, ProviderKind,
    };

    /// Short xAI aliases collapse to versioned ids; already-canonical ids pass through.
    #[test]
    fn resolves_grok_aliases() {
        assert_eq!(resolve_model_alias("grok"), "grok-3");
        assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini");
        assert_eq!(resolve_model_alias("grok-2"), "grok-2");
    }

    /// Name-based routing must take precedence over environment-based detection.
    #[test]
    fn detects_provider_from_model_name_first() {
        assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai);
        assert_eq!(
            detect_provider_kind("claude-sonnet-4-6"),
            ProviderKind::Anthropic
        );
    }

    /// Pre-registry output-token defaults (32k opus / 64k otherwise) are preserved.
    #[test]
    fn keeps_existing_max_token_heuristic() {
        assert_eq!(max_tokens_for_model("opus"), 32_000);
        assert_eq!(max_tokens_for_model("grok-3"), 64_000);
    }

    /// Registered models expose context windows; aliases resolve to the same entry.
    #[test]
    fn returns_context_window_metadata_for_supported_models() {
        assert_eq!(
            model_token_limit("claude-sonnet-4-6")
                .expect("claude-sonnet-4-6 should be registered")
                .context_window_tokens,
            200_000
        );
        assert_eq!(
            model_token_limit("grok-mini")
                .expect("grok-mini should resolve to a registered model")
                .context_window_tokens,
            131_072
        );
    }

    /// An oversized prompt must fail locally — before any HTTP request is made —
    /// with a fully-populated context-window error.
    #[test]
    fn preflight_blocks_requests_that_exceed_the_model_context_window() {
        // 600k bytes of text ≈ 150k estimated input tokens; plus 64k output
        // this exceeds the 200k context window of claude-sonnet-4-6.
        let request = MessageRequest {
            model: "claude-sonnet-4-6".to_string(),
            max_tokens: 64_000,
            messages: vec![InputMessage {
                role: "user".to_string(),
                content: vec![InputContentBlock::Text {
                    text: "x".repeat(600_000),
                }],
            }],
            system: Some("Keep the answer short.".to_string()),
            tools: Some(vec![ToolDefinition {
                name: "weather".to_string(),
                description: Some("Fetches weather".to_string()),
                input_schema: json!({
                    "type": "object",
                    "properties": { "city": { "type": "string" } },
                }),
            }]),
            tool_choice: Some(ToolChoice::Auto),
            stream: true,
        };

        let error = preflight_message_request(&request)
            .expect_err("oversized request should be rejected before the provider call");

        match error {
            ApiError::ContextWindowExceeded {
                model,
                estimated_input_tokens,
                requested_output_tokens,
                estimated_total_tokens,
                context_window_tokens,
            } => {
                // Error should carry the canonical model name, not an alias.
                assert_eq!(model, "claude-sonnet-4-6");
                // Loose lower bound on the byte-based token estimate.
                assert!(estimated_input_tokens > 136_000);
                assert_eq!(requested_output_tokens, 64_000);
                assert!(estimated_total_tokens > context_window_tokens);
                assert_eq!(context_window_tokens, 200_000);
            }
            other => panic!("expected context-window preflight failure, got {other:?}"),
        }
    }

    /// Models without registered limits must bypass the preflight, even when huge.
    #[test]
    fn preflight_skips_unknown_models() {
        let request = MessageRequest {
            model: "unknown-model".to_string(),
            max_tokens: 64_000,
            messages: vec![InputMessage {
                role: "user".to_string(),
                content: vec![InputContentBlock::Text {
                    text: "x".repeat(600_000),
                }],
            }],
            system: None,
            tools: None,
            tool_choice: None,
            stream: false,
        };

        preflight_message_request(&request)
            .expect("models without context metadata should skip the guarded preflight");
    }
}
|