mirror of
https://github.com/instructkr/claw-code.git
synced 2026-04-07 00:24:50 +08:00
Use Anthropic count tokens for preflight
This commit is contained in:
@@ -14,7 +14,7 @@ use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, Session
|
||||
use crate::error::ApiError;
|
||||
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
|
||||
|
||||
use super::{preflight_message_request, Provider, ProviderFuture};
|
||||
use super::{model_token_limit, resolve_model_alias, Provider, ProviderFuture};
|
||||
use crate::sse::SseParser;
|
||||
use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage};
|
||||
|
||||
@@ -294,7 +294,7 @@ impl AnthropicClient {
|
||||
}
|
||||
}
|
||||
|
||||
preflight_message_request(&request)?;
|
||||
self.preflight_message_request(&request).await?;
|
||||
|
||||
let response = self.send_with_retry(&request).await?;
|
||||
let request_id = request_id_from_headers(response.headers());
|
||||
@@ -339,7 +339,7 @@ impl AnthropicClient {
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
) -> Result<MessageStream, ApiError> {
|
||||
preflight_message_request(request)?;
|
||||
self.preflight_message_request(request).await?;
|
||||
let response = self
|
||||
.send_with_retry(&request.clone().with_streaming())
|
||||
.await?;
|
||||
@@ -466,18 +466,67 @@ impl AnthropicClient {
|
||||
request: &MessageRequest,
|
||||
) -> Result<reqwest::Response, ApiError> {
|
||||
let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
|
||||
let request_body = self.request_profile.render_json_body(request)?;
|
||||
let request_builder = self.build_request(&request_url).json(&request_body);
|
||||
request_builder.send().await.map_err(ApiError::from)
|
||||
}
|
||||
|
||||
fn build_request(&self, request_url: &str) -> reqwest::RequestBuilder {
|
||||
let request_builder = self
|
||||
.http
|
||||
.post(&request_url)
|
||||
.post(request_url)
|
||||
.header("content-type", "application/json");
|
||||
let mut request_builder = self.auth.apply(request_builder);
|
||||
for (header_name, header_value) in self.request_profile.header_pairs() {
|
||||
request_builder = request_builder.header(header_name, header_value);
|
||||
}
|
||||
request_builder
|
||||
}
|
||||
|
||||
async fn preflight_message_request(&self, request: &MessageRequest) -> Result<(), ApiError> {
|
||||
let Some(limit) = model_token_limit(&request.model) else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let counted_input_tokens = match self.count_tokens(request).await {
|
||||
Ok(count) => count,
|
||||
Err(_) => return Ok(()),
|
||||
};
|
||||
let estimated_total_tokens = counted_input_tokens.saturating_add(request.max_tokens);
|
||||
if estimated_total_tokens > limit.context_window_tokens {
|
||||
return Err(ApiError::ContextWindowExceeded {
|
||||
model: resolve_model_alias(&request.model),
|
||||
estimated_input_tokens: counted_input_tokens,
|
||||
requested_output_tokens: request.max_tokens,
|
||||
estimated_total_tokens,
|
||||
context_window_tokens: limit.context_window_tokens,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn count_tokens(&self, request: &MessageRequest) -> Result<u32, ApiError> {
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CountTokensResponse {
|
||||
input_tokens: u32,
|
||||
}
|
||||
|
||||
let request_url = format!("{}/v1/messages/count_tokens", self.base_url.trim_end_matches('/'));
|
||||
let request_body = self.request_profile.render_json_body(request)?;
|
||||
request_builder = request_builder.json(&request_body);
|
||||
request_builder.send().await.map_err(ApiError::from)
|
||||
let response = self
|
||||
.build_request(&request_url)
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(ApiError::from)?;
|
||||
|
||||
let parsed = expect_success(response)
|
||||
.await?
|
||||
.json::<CountTokensResponse>()
|
||||
.await
|
||||
.map_err(ApiError::from)?;
|
||||
Ok(parsed.input_tokens)
|
||||
}
|
||||
|
||||
fn record_request_failure(&self, attempt: u32, error: &ApiError) {
|
||||
|
||||
Reference in New Issue
Block a user