From be561bfdeb92fce7011938e748ee20051460d6a4 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Mon, 6 Apr 2026 09:38:21 +0000 Subject: [PATCH] Use Anthropic count tokens for preflight --- rust/crates/api/src/providers/anthropic.rs | 61 +++++++++++++++++++--- 1 file changed, 55 insertions(+), 6 deletions(-) diff --git a/rust/crates/api/src/providers/anthropic.rs b/rust/crates/api/src/providers/anthropic.rs index a85924b..e19a589 100644 --- a/rust/crates/api/src/providers/anthropic.rs +++ b/rust/crates/api/src/providers/anthropic.rs @@ -14,7 +14,7 @@ use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, Session use crate::error::ApiError; use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats}; -use super::{preflight_message_request, Provider, ProviderFuture}; +use super::{model_token_limit, resolve_model_alias, Provider, ProviderFuture}; use crate::sse::SseParser; use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage}; @@ -294,7 +294,7 @@ impl AnthropicClient { } } - preflight_message_request(&request)?; + self.preflight_message_request(&request).await?; let response = self.send_with_retry(&request).await?; let request_id = request_id_from_headers(response.headers()); @@ -339,7 +339,7 @@ impl AnthropicClient { &self, request: &MessageRequest, ) -> Result { - preflight_message_request(request)?; + self.preflight_message_request(request).await?; let response = self .send_with_retry(&request.clone().with_streaming()) .await?; @@ -466,18 +466,67 @@ impl AnthropicClient { request: &MessageRequest, ) -> Result { let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); + let request_body = self.request_profile.render_json_body(request)?; + let request_builder = self.build_request(&request_url).json(&request_body); + request_builder.send().await.map_err(ApiError::from) + } + + fn build_request(&self, request_url: &str) -> reqwest::RequestBuilder { let request_builder = self .http - .post(&request_url) + .post(request_url) .header("content-type", "application/json"); let mut request_builder = self.auth.apply(request_builder); for (header_name, header_value) in self.request_profile.header_pairs() { request_builder = request_builder.header(header_name, header_value); } + request_builder + } + async fn preflight_message_request(&self, request: &MessageRequest) -> Result<(), ApiError> { + let Some(limit) = model_token_limit(&request.model) else { + return Ok(()); + }; + + let counted_input_tokens = match self.count_tokens(request).await { + Ok(count) => count, + Err(_) => return Ok(()), + }; + let estimated_total_tokens = counted_input_tokens.saturating_add(request.max_tokens); + if estimated_total_tokens > limit.context_window_tokens { + return Err(ApiError::ContextWindowExceeded { + model: resolve_model_alias(&request.model), + estimated_input_tokens: counted_input_tokens, + requested_output_tokens: request.max_tokens, + estimated_total_tokens, + context_window_tokens: limit.context_window_tokens, + }); + } + + Ok(()) + } + + async fn count_tokens(&self, request: &MessageRequest) -> Result { + #[derive(serde::Deserialize)] + struct CountTokensResponse { + input_tokens: u32, + } + + let request_url = format!("{}/v1/messages/count_tokens", self.base_url.trim_end_matches('/')); let request_body = self.request_profile.render_json_body(request)?; - request_builder = request_builder.json(&request_body); - request_builder.send().await.map_err(ApiError::from) + let response = self + .build_request(&request_url) + .json(&request_body) + .send() + .await + .map_err(ApiError::from)?; + + let parsed = expect_success(response) + .await? + .json::() + .await + .map_err(ApiError::from)?; + Ok(parsed.input_tokens) } fn record_request_failure(&self, attempt: u32, error: &ApiError) {