mirror of
https://github.com/instructkr/claw-code.git
synced 2026-04-08 00:54:49 +08:00
Use Anthropic count tokens for preflight
This commit is contained in:
@@ -14,7 +14,7 @@ use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, Session
|
|||||||
use crate::error::ApiError;
|
use crate::error::ApiError;
|
||||||
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
|
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
|
||||||
|
|
||||||
use super::{preflight_message_request, Provider, ProviderFuture};
|
use super::{model_token_limit, resolve_model_alias, Provider, ProviderFuture};
|
||||||
use crate::sse::SseParser;
|
use crate::sse::SseParser;
|
||||||
use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage};
|
use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage};
|
||||||
|
|
||||||
@@ -294,7 +294,7 @@ impl AnthropicClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
preflight_message_request(&request)?;
|
self.preflight_message_request(&request).await?;
|
||||||
|
|
||||||
let response = self.send_with_retry(&request).await?;
|
let response = self.send_with_retry(&request).await?;
|
||||||
let request_id = request_id_from_headers(response.headers());
|
let request_id = request_id_from_headers(response.headers());
|
||||||
@@ -339,7 +339,7 @@ impl AnthropicClient {
|
|||||||
&self,
|
&self,
|
||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
) -> Result<MessageStream, ApiError> {
|
) -> Result<MessageStream, ApiError> {
|
||||||
preflight_message_request(request)?;
|
self.preflight_message_request(request).await?;
|
||||||
let response = self
|
let response = self
|
||||||
.send_with_retry(&request.clone().with_streaming())
|
.send_with_retry(&request.clone().with_streaming())
|
||||||
.await?;
|
.await?;
|
||||||
@@ -466,18 +466,67 @@ impl AnthropicClient {
|
|||||||
request: &MessageRequest,
|
request: &MessageRequest,
|
||||||
) -> Result<reqwest::Response, ApiError> {
|
) -> Result<reqwest::Response, ApiError> {
|
||||||
let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
|
let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
|
||||||
|
let request_body = self.request_profile.render_json_body(request)?;
|
||||||
|
let request_builder = self.build_request(&request_url).json(&request_body);
|
||||||
|
request_builder.send().await.map_err(ApiError::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_request(&self, request_url: &str) -> reqwest::RequestBuilder {
|
||||||
let request_builder = self
|
let request_builder = self
|
||||||
.http
|
.http
|
||||||
.post(&request_url)
|
.post(request_url)
|
||||||
.header("content-type", "application/json");
|
.header("content-type", "application/json");
|
||||||
let mut request_builder = self.auth.apply(request_builder);
|
let mut request_builder = self.auth.apply(request_builder);
|
||||||
for (header_name, header_value) in self.request_profile.header_pairs() {
|
for (header_name, header_value) in self.request_profile.header_pairs() {
|
||||||
request_builder = request_builder.header(header_name, header_value);
|
request_builder = request_builder.header(header_name, header_value);
|
||||||
}
|
}
|
||||||
|
request_builder
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn preflight_message_request(&self, request: &MessageRequest) -> Result<(), ApiError> {
|
||||||
|
let Some(limit) = model_token_limit(&request.model) else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let counted_input_tokens = match self.count_tokens(request).await {
|
||||||
|
Ok(count) => count,
|
||||||
|
Err(_) => return Ok(()),
|
||||||
|
};
|
||||||
|
let estimated_total_tokens = counted_input_tokens.saturating_add(request.max_tokens);
|
||||||
|
if estimated_total_tokens > limit.context_window_tokens {
|
||||||
|
return Err(ApiError::ContextWindowExceeded {
|
||||||
|
model: resolve_model_alias(&request.model),
|
||||||
|
estimated_input_tokens: counted_input_tokens,
|
||||||
|
requested_output_tokens: request.max_tokens,
|
||||||
|
estimated_total_tokens,
|
||||||
|
context_window_tokens: limit.context_window_tokens,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn count_tokens(&self, request: &MessageRequest) -> Result<u32, ApiError> {
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
struct CountTokensResponse {
|
||||||
|
input_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
let request_url = format!("{}/v1/messages/count_tokens", self.base_url.trim_end_matches('/'));
|
||||||
let request_body = self.request_profile.render_json_body(request)?;
|
let request_body = self.request_profile.render_json_body(request)?;
|
||||||
request_builder = request_builder.json(&request_body);
|
let response = self
|
||||||
request_builder.send().await.map_err(ApiError::from)
|
.build_request(&request_url)
|
||||||
|
.json(&request_body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(ApiError::from)?;
|
||||||
|
|
||||||
|
let parsed = expect_success(response)
|
||||||
|
.await?
|
||||||
|
.json::<CountTokensResponse>()
|
||||||
|
.await
|
||||||
|
.map_err(ApiError::from)?;
|
||||||
|
Ok(parsed.input_tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn record_request_failure(&self, attempt: u32, error: &ApiError) {
|
fn record_request_failure(&self, attempt: u32, error: &ApiError) {
|
||||||
|
|||||||
Reference in New Issue
Block a user