mirror of
https://github.com/instructkr/claw-code.git
synced 2026-04-06 16:14:49 +08:00
Block oversized requests before providers hard-fail
The runtime already tracked rough token estimates for compaction, but provider-bound requests still relied on naive model output limits and could be sent upstream even when the selected model could not fit the estimated prompt plus requested output. This adds a small model token/context registry in the API layer, estimates request size from the serialized prompt payload, and fails locally with a dedicated context-window error before Anthropic or xAI calls are made. Focused integration coverage asserts the preflight fires before any HTTP request leaves the process. Constraint: Keep the first pass minimal and reusable across both Anthropic and OpenAI-compatible providers Rejected: Auto-compact-and-retry in the same patch | broader control-flow change than the requested minimal preflight Confidence: medium Scope-risk: narrow Reversibility: clean Directive: Expand the model registry before enabling preflight for additional providers or aliases Tested: cargo build -p api -p tools -p rusty-claude-cli; cargo test -p api Not-tested: End-to-end CLI auto-compaction or retry behavior after a local context_window_blocked failure
This commit is contained in:
@@ -8,6 +8,13 @@ pub enum ApiError {
|
|||||||
provider: &'static str,
|
provider: &'static str,
|
||||||
env_vars: &'static [&'static str],
|
env_vars: &'static [&'static str],
|
||||||
},
|
},
|
||||||
|
ContextWindowExceeded {
|
||||||
|
model: String,
|
||||||
|
estimated_input_tokens: u32,
|
||||||
|
requested_output_tokens: u32,
|
||||||
|
estimated_total_tokens: u32,
|
||||||
|
context_window_tokens: u32,
|
||||||
|
},
|
||||||
ExpiredOAuthToken,
|
ExpiredOAuthToken,
|
||||||
Auth(String),
|
Auth(String),
|
||||||
InvalidApiKeyEnv(VarError),
|
InvalidApiKeyEnv(VarError),
|
||||||
@@ -48,6 +55,7 @@ impl ApiError {
|
|||||||
Self::Api { retryable, .. } => *retryable,
|
Self::Api { retryable, .. } => *retryable,
|
||||||
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
|
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
|
||||||
Self::MissingCredentials { .. }
|
Self::MissingCredentials { .. }
|
||||||
|
| Self::ContextWindowExceeded { .. }
|
||||||
| Self::ExpiredOAuthToken
|
| Self::ExpiredOAuthToken
|
||||||
| Self::Auth(_)
|
| Self::Auth(_)
|
||||||
| Self::InvalidApiKeyEnv(_)
|
| Self::InvalidApiKeyEnv(_)
|
||||||
@@ -67,6 +75,16 @@ impl Display for ApiError {
|
|||||||
"missing {provider} credentials; export {} before calling the {provider} API",
|
"missing {provider} credentials; export {} before calling the {provider} API",
|
||||||
env_vars.join(" or ")
|
env_vars.join(" or ")
|
||||||
),
|
),
|
||||||
|
Self::ContextWindowExceeded {
|
||||||
|
model,
|
||||||
|
estimated_input_tokens,
|
||||||
|
requested_output_tokens,
|
||||||
|
estimated_total_tokens,
|
||||||
|
context_window_tokens,
|
||||||
|
} => write!(
|
||||||
|
f,
|
||||||
|
"context_window_blocked for {model}: estimated input {estimated_input_tokens} + requested output {requested_output_tokens} = {estimated_total_tokens} tokens exceeds the {context_window_tokens}-token context window; compact the session or reduce request size before retrying"
|
||||||
|
),
|
||||||
Self::ExpiredOAuthToken => {
|
Self::ExpiredOAuthToken => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, Session
|
|||||||
use crate::error::ApiError;
|
use crate::error::ApiError;
|
||||||
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
|
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
|
||||||
|
|
||||||
use super::{Provider, ProviderFuture};
|
use super::{preflight_message_request, Provider, ProviderFuture};
|
||||||
use crate::sse::SseParser;
|
use crate::sse::SseParser;
|
||||||
use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage};
|
use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage};
|
||||||
|
|
||||||
@@ -294,6 +294,8 @@ impl AnthropicClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
preflight_message_request(&request)?;
|
||||||
|
|
||||||
let response = self.send_with_retry(&request).await?;
|
let response = self.send_with_retry(&request).await?;
|
||||||
let request_id = request_id_from_headers(response.headers());
|
let request_id = request_id_from_headers(response.headers());
|
||||||
let mut response = response
|
let mut response = response
|
||||||
@@ -337,6 +339,7 @@ impl AnthropicClient {
|
|||||||
&self,
|
&self,
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
) -> Result<MessageStream, ApiError> {
|
) -> Result<MessageStream, ApiError> {
|
||||||
|
preflight_message_request(request)?;
|
||||||
let response = self
|
let response = self
|
||||||
.send_with_retry(&request.clone().with_streaming())
|
.send_with_retry(&request.clone().with_streaming())
|
||||||
.await?;
|
.await?;
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
use crate::error::ApiError;
|
use crate::error::ApiError;
|
||||||
use crate::types::{MessageRequest, MessageResponse};
|
use crate::types::{MessageRequest, MessageResponse};
|
||||||
|
|
||||||
@@ -40,6 +42,12 @@ pub struct ProviderMetadata {
|
|||||||
pub default_base_url: &'static str,
|
pub default_base_url: &'static str,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub struct ModelTokenLimit {
|
||||||
|
pub max_output_tokens: u32,
|
||||||
|
pub context_window_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
|
const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
|
||||||
(
|
(
|
||||||
"opus",
|
"opus",
|
||||||
@@ -182,17 +190,86 @@ pub fn detect_provider_kind(model: &str) -> ProviderKind {
|
|||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn max_tokens_for_model(model: &str) -> u32 {
|
pub fn max_tokens_for_model(model: &str) -> u32 {
|
||||||
|
model_token_limit(model).map_or_else(
|
||||||
|
|| {
|
||||||
|
let canonical = resolve_model_alias(model);
|
||||||
|
if canonical.contains("opus") {
|
||||||
|
32_000
|
||||||
|
} else {
|
||||||
|
64_000
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|limit| limit.max_output_tokens,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn model_token_limit(model: &str) -> Option<ModelTokenLimit> {
|
||||||
let canonical = resolve_model_alias(model);
|
let canonical = resolve_model_alias(model);
|
||||||
if canonical.contains("opus") {
|
match canonical.as_str() {
|
||||||
32_000
|
"claude-opus-4-6" => Some(ModelTokenLimit {
|
||||||
} else {
|
max_output_tokens: 32_000,
|
||||||
64_000
|
context_window_tokens: 200_000,
|
||||||
|
}),
|
||||||
|
"claude-sonnet-4-6" | "claude-haiku-4-5-20251213" => Some(ModelTokenLimit {
|
||||||
|
max_output_tokens: 64_000,
|
||||||
|
context_window_tokens: 200_000,
|
||||||
|
}),
|
||||||
|
"grok-3" | "grok-3-mini" => Some(ModelTokenLimit {
|
||||||
|
max_output_tokens: 64_000,
|
||||||
|
context_window_tokens: 131_072,
|
||||||
|
}),
|
||||||
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn preflight_message_request(request: &MessageRequest) -> Result<(), ApiError> {
|
||||||
|
let Some(limit) = model_token_limit(&request.model) else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let estimated_input_tokens = estimate_message_request_input_tokens(request);
|
||||||
|
let estimated_total_tokens = estimated_input_tokens.saturating_add(request.max_tokens);
|
||||||
|
if estimated_total_tokens > limit.context_window_tokens {
|
||||||
|
return Err(ApiError::ContextWindowExceeded {
|
||||||
|
model: resolve_model_alias(&request.model),
|
||||||
|
estimated_input_tokens,
|
||||||
|
requested_output_tokens: request.max_tokens,
|
||||||
|
estimated_total_tokens,
|
||||||
|
context_window_tokens: limit.context_window_tokens,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn estimate_message_request_input_tokens(request: &MessageRequest) -> u32 {
|
||||||
|
let mut estimate = estimate_serialized_tokens(&request.messages);
|
||||||
|
estimate = estimate.saturating_add(estimate_serialized_tokens(&request.system));
|
||||||
|
estimate = estimate.saturating_add(estimate_serialized_tokens(&request.tools));
|
||||||
|
estimate = estimate.saturating_add(estimate_serialized_tokens(&request.tool_choice));
|
||||||
|
estimate
|
||||||
|
}
|
||||||
|
|
||||||
|
fn estimate_serialized_tokens<T: Serialize>(value: &T) -> u32 {
|
||||||
|
serde_json::to_vec(value)
|
||||||
|
.ok()
|
||||||
|
.map_or(0, |bytes| (bytes.len() / 4 + 1) as u32)
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind};
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::error::ApiError;
|
||||||
|
use crate::types::{
|
||||||
|
InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
detect_provider_kind, max_tokens_for_model, model_token_limit, preflight_message_request,
|
||||||
|
resolve_model_alias, ProviderKind,
|
||||||
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolves_grok_aliases() {
|
fn resolves_grok_aliases() {
|
||||||
@@ -215,4 +292,86 @@ mod tests {
|
|||||||
assert_eq!(max_tokens_for_model("opus"), 32_000);
|
assert_eq!(max_tokens_for_model("opus"), 32_000);
|
||||||
assert_eq!(max_tokens_for_model("grok-3"), 64_000);
|
assert_eq!(max_tokens_for_model("grok-3"), 64_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn returns_context_window_metadata_for_supported_models() {
|
||||||
|
assert_eq!(
|
||||||
|
model_token_limit("claude-sonnet-4-6")
|
||||||
|
.expect("claude-sonnet-4-6 should be registered")
|
||||||
|
.context_window_tokens,
|
||||||
|
200_000
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
model_token_limit("grok-mini")
|
||||||
|
.expect("grok-mini should resolve to a registered model")
|
||||||
|
.context_window_tokens,
|
||||||
|
131_072
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preflight_blocks_requests_that_exceed_the_model_context_window() {
|
||||||
|
let request = MessageRequest {
|
||||||
|
model: "claude-sonnet-4-6".to_string(),
|
||||||
|
max_tokens: 64_000,
|
||||||
|
messages: vec![InputMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![InputContentBlock::Text {
|
||||||
|
text: "x".repeat(600_000),
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
system: Some("Keep the answer short.".to_string()),
|
||||||
|
tools: Some(vec![ToolDefinition {
|
||||||
|
name: "weather".to_string(),
|
||||||
|
description: Some("Fetches weather".to_string()),
|
||||||
|
input_schema: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": { "city": { "type": "string" } },
|
||||||
|
}),
|
||||||
|
}]),
|
||||||
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let error = preflight_message_request(&request)
|
||||||
|
.expect_err("oversized request should be rejected before the provider call");
|
||||||
|
|
||||||
|
match error {
|
||||||
|
ApiError::ContextWindowExceeded {
|
||||||
|
model,
|
||||||
|
estimated_input_tokens,
|
||||||
|
requested_output_tokens,
|
||||||
|
estimated_total_tokens,
|
||||||
|
context_window_tokens,
|
||||||
|
} => {
|
||||||
|
assert_eq!(model, "claude-sonnet-4-6");
|
||||||
|
assert!(estimated_input_tokens > 136_000);
|
||||||
|
assert_eq!(requested_output_tokens, 64_000);
|
||||||
|
assert!(estimated_total_tokens > context_window_tokens);
|
||||||
|
assert_eq!(context_window_tokens, 200_000);
|
||||||
|
}
|
||||||
|
other => panic!("expected context-window preflight failure, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preflight_skips_unknown_models() {
|
||||||
|
let request = MessageRequest {
|
||||||
|
model: "unknown-model".to_string(),
|
||||||
|
max_tokens: 64_000,
|
||||||
|
messages: vec![InputMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![InputContentBlock::Text {
|
||||||
|
text: "x".repeat(600_000),
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
system: None,
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
preflight_message_request(&request)
|
||||||
|
.expect("models without context metadata should skip the guarded preflight");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ use crate::types::{
|
|||||||
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
|
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{Provider, ProviderFuture};
|
use super::{preflight_message_request, Provider, ProviderFuture};
|
||||||
|
|
||||||
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
|
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
|
||||||
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
|
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
|
||||||
@@ -128,6 +128,7 @@ impl OpenAiCompatClient {
|
|||||||
stream: false,
|
stream: false,
|
||||||
..request.clone()
|
..request.clone()
|
||||||
};
|
};
|
||||||
|
preflight_message_request(&request)?;
|
||||||
let response = self.send_with_retry(&request).await?;
|
let response = self.send_with_retry(&request).await?;
|
||||||
let request_id = request_id_from_headers(response.headers());
|
let request_id = request_id_from_headers(response.headers());
|
||||||
let payload = response.json::<ChatCompletionResponse>().await?;
|
let payload = response.json::<ChatCompletionResponse>().await?;
|
||||||
@@ -142,6 +143,7 @@ impl OpenAiCompatClient {
|
|||||||
&self,
|
&self,
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
) -> Result<MessageStream, ApiError> {
|
) -> Result<MessageStream, ApiError> {
|
||||||
|
preflight_message_request(request)?;
|
||||||
let response = self
|
let response = self
|
||||||
.send_with_retry(&request.clone().with_streaming())
|
.send_with_retry(&request.clone().with_streaming())
|
||||||
.await?;
|
.await?;
|
||||||
|
|||||||
@@ -103,6 +103,41 @@ async fn send_message_posts_json_and_parses_response() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn send_message_blocks_oversized_requests_before_the_http_call() {
|
||||||
|
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||||
|
let server = spawn_server(
|
||||||
|
state.clone(),
|
||||||
|
vec![http_response("200 OK", "application/json", "{}")],
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
|
||||||
|
let error = client
|
||||||
|
.send_message(&MessageRequest {
|
||||||
|
model: "claude-sonnet-4-6".to_string(),
|
||||||
|
max_tokens: 64_000,
|
||||||
|
messages: vec![InputMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![InputContentBlock::Text {
|
||||||
|
text: "x".repeat(600_000),
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
system: Some("Keep the answer short.".to_string()),
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
|
stream: false,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect_err("oversized request should fail local context-window preflight");
|
||||||
|
|
||||||
|
assert!(matches!(error, ApiError::ContextWindowExceeded { .. }));
|
||||||
|
assert!(
|
||||||
|
state.lock().await.is_empty(),
|
||||||
|
"preflight failure should avoid any upstream HTTP request"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn send_message_applies_request_profile_and_records_telemetry() {
|
async fn send_message_applies_request_profile_and_records_telemetry() {
|
||||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||||
|
|||||||
@@ -4,10 +4,10 @@ use std::sync::Arc;
|
|||||||
use std::sync::{Mutex as StdMutex, OnceLock};
|
use std::sync::{Mutex as StdMutex, OnceLock};
|
||||||
|
|
||||||
use api::{
|
use api::{
|
||||||
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
|
ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
|
||||||
InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OpenAiCompatClient,
|
ContentBlockStopEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest,
|
||||||
OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent, ToolChoice,
|
OpenAiCompatClient, OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent,
|
||||||
ToolDefinition,
|
ToolChoice, ToolDefinition,
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
@@ -63,6 +63,42 @@ async fn send_message_uses_openai_compatible_endpoint_and_auth() {
|
|||||||
assert_eq!(body["tools"][0]["type"], json!("function"));
|
assert_eq!(body["tools"][0]["type"], json!("function"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn send_message_blocks_oversized_xai_requests_before_the_http_call() {
|
||||||
|
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||||
|
let server = spawn_server(
|
||||||
|
state.clone(),
|
||||||
|
vec![http_response("200 OK", "application/json", "{}")],
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
|
||||||
|
.with_base_url(server.base_url());
|
||||||
|
let error = client
|
||||||
|
.send_message(&MessageRequest {
|
||||||
|
model: "grok-3".to_string(),
|
||||||
|
max_tokens: 64_000,
|
||||||
|
messages: vec![InputMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![InputContentBlock::Text {
|
||||||
|
text: "x".repeat(300_000),
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
system: Some("Keep the answer short.".to_string()),
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
|
stream: false,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect_err("oversized request should fail local context-window preflight");
|
||||||
|
|
||||||
|
assert!(matches!(error, ApiError::ContextWindowExceeded { .. }));
|
||||||
|
assert!(
|
||||||
|
state.lock().await.is_empty(),
|
||||||
|
"preflight failure should avoid any upstream HTTP request"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn send_message_accepts_full_chat_completions_endpoint_override() {
|
async fn send_message_accepts_full_chat_completions_endpoint_override() {
|
||||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||||
|
|||||||
Reference in New Issue
Block a user