mirror of
https://github.com/instructkr/claw-code.git
synced 2026-04-05 23:54:50 +08:00
fix: restore anthropic request profile integration
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
mod client;
|
||||
mod error;
|
||||
mod prompt_cache;
|
||||
mod providers;
|
||||
mod sse;
|
||||
mod types;
|
||||
@@ -9,6 +10,10 @@ pub use client::{
|
||||
resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient,
|
||||
};
|
||||
pub use error::ApiError;
|
||||
pub use prompt_cache::{
|
||||
CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord,
|
||||
PromptCacheStats,
|
||||
};
|
||||
pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource};
|
||||
pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig};
|
||||
pub use providers::{
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use runtime::format_usd;
|
||||
use runtime::{
|
||||
load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
|
||||
OAuthTokenExchangeRequest,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use telemetry::SessionTracer;
|
||||
use serde_json::{Map, Value};
|
||||
use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, SessionTracer};
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
|
||||
|
||||
use super::{Provider, ProviderFuture};
|
||||
use crate::sse::SseParser;
|
||||
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
|
||||
use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage};
|
||||
|
||||
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
|
||||
const ANTHROPIC_VERSION: &str = "2023-06-01";
|
||||
@@ -114,6 +118,10 @@ pub struct AnthropicClient {
|
||||
max_retries: u32,
|
||||
initial_backoff: Duration,
|
||||
max_backoff: Duration,
|
||||
request_profile: AnthropicRequestProfile,
|
||||
session_tracer: Option<SessionTracer>,
|
||||
prompt_cache: Option<PromptCache>,
|
||||
last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
@@ -126,6 +134,10 @@ impl AnthropicClient {
|
||||
max_retries: DEFAULT_MAX_RETRIES,
|
||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||
request_profile: AnthropicRequestProfile::default(),
|
||||
session_tracer: None,
|
||||
prompt_cache: None,
|
||||
last_prompt_cache_record: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,6 +150,10 @@ impl AnthropicClient {
|
||||
max_retries: DEFAULT_MAX_RETRIES,
|
||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||
max_backoff: DEFAULT_MAX_BACKOFF,
|
||||
request_profile: AnthropicRequestProfile::default(),
|
||||
session_tracer: None,
|
||||
prompt_cache: None,
|
||||
last_prompt_cache_record: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,7 +212,66 @@ impl AnthropicClient {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_session_tracer(self, _session_tracer: SessionTracer) -> Self {
|
||||
pub fn with_session_tracer(mut self, session_tracer: SessionTracer) -> Self {
|
||||
self.session_tracer = Some(session_tracer);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_client_identity(mut self, client_identity: ClientIdentity) -> Self {
|
||||
self.request_profile.client_identity = client_identity;
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_beta(mut self, beta: impl Into<String>) -> Self {
|
||||
self.request_profile = self.request_profile.with_beta(beta);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_extra_body_param(mut self, key: impl Into<String>, value: Value) -> Self {
|
||||
self.request_profile = self.request_profile.with_extra_body(key, value);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self {
|
||||
self.prompt_cache = Some(prompt_cache);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn prompt_cache_stats(&self) -> Option<PromptCacheStats> {
|
||||
self.prompt_cache.as_ref().map(PromptCache::stats)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn request_profile(&self) -> &AnthropicRequestProfile {
|
||||
&self.request_profile
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn session_tracer(&self) -> Option<&SessionTracer> {
|
||||
self.session_tracer.as_ref()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn prompt_cache(&self) -> Option<&PromptCache> {
|
||||
self.prompt_cache.as_ref()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
|
||||
self.last_prompt_cache_record
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.take()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_request_profile(mut self, request_profile: AnthropicRequestProfile) -> Self {
|
||||
self.request_profile = request_profile;
|
||||
self
|
||||
}
|
||||
|
||||
@@ -213,6 +288,13 @@ impl AnthropicClient {
|
||||
stream: false,
|
||||
..request.clone()
|
||||
};
|
||||
|
||||
if let Some(prompt_cache) = &self.prompt_cache {
|
||||
if let Some(response) = prompt_cache.lookup_completion(&request) {
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
|
||||
let response = self.send_with_retry(&request).await?;
|
||||
let request_id = request_id_from_headers(response.headers());
|
||||
let mut response = response
|
||||
@@ -222,6 +304,30 @@ impl AnthropicClient {
|
||||
if response.request_id.is_none() {
|
||||
response.request_id = request_id;
|
||||
}
|
||||
|
||||
if let Some(prompt_cache) = &self.prompt_cache {
|
||||
let record = prompt_cache.record_response(&request, &response);
|
||||
self.store_last_prompt_cache_record(record);
|
||||
}
|
||||
if let Some(session_tracer) = &self.session_tracer {
|
||||
session_tracer.record_analytics(
|
||||
AnalyticsEvent::new("api", "message_usage")
|
||||
.with_property(
|
||||
"request_id",
|
||||
response
|
||||
.request_id
|
||||
.clone()
|
||||
.map_or(Value::Null, Value::String),
|
||||
)
|
||||
.with_property("total_tokens", Value::from(response.total_tokens()))
|
||||
.with_property(
|
||||
"estimated_cost_usd",
|
||||
Value::String(format_usd(
|
||||
response.usage.estimated_cost_usd(&response.model).total_cost_usd(),
|
||||
)),
|
||||
),
|
||||
);
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -238,6 +344,11 @@ impl AnthropicClient {
|
||||
parser: SseParser::new(),
|
||||
pending: VecDeque::new(),
|
||||
done: false,
|
||||
request: request.clone(),
|
||||
prompt_cache: self.prompt_cache.clone(),
|
||||
latest_usage: None,
|
||||
usage_recorded: false,
|
||||
last_prompt_cache_record: Arc::clone(&self.last_prompt_cache_record),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -290,18 +401,46 @@ impl AnthropicClient {
|
||||
|
||||
loop {
|
||||
attempts += 1;
|
||||
if let Some(session_tracer) = &self.session_tracer {
|
||||
session_tracer.record_http_request_started(
|
||||
attempts,
|
||||
"POST",
|
||||
"/v1/messages",
|
||||
Map::new(),
|
||||
);
|
||||
}
|
||||
match self.send_raw_request(request).await {
|
||||
Ok(response) => match expect_success(response).await {
|
||||
Ok(response) => return Ok(response),
|
||||
Ok(response) => {
|
||||
if let Some(session_tracer) = &self.session_tracer {
|
||||
session_tracer.record_http_request_succeeded(
|
||||
attempts,
|
||||
"POST",
|
||||
"/v1/messages",
|
||||
response.status().as_u16(),
|
||||
request_id_from_headers(response.headers()),
|
||||
Map::new(),
|
||||
);
|
||||
}
|
||||
return Ok(response);
|
||||
}
|
||||
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
||||
self.record_request_failure(attempts, &error);
|
||||
last_error = Some(error);
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
Err(error) => {
|
||||
self.record_request_failure(attempts, &error);
|
||||
return Err(error);
|
||||
}
|
||||
},
|
||||
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
||||
self.record_request_failure(attempts, &error);
|
||||
last_error = Some(error);
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
Err(error) => {
|
||||
self.record_request_failure(attempts, &error);
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
|
||||
if attempts > self.max_retries {
|
||||
@@ -325,14 +464,37 @@ impl AnthropicClient {
|
||||
let request_builder = self
|
||||
.http
|
||||
.post(&request_url)
|
||||
.header("anthropic-version", ANTHROPIC_VERSION)
|
||||
.header("content-type", "application/json");
|
||||
let mut request_builder = self.auth.apply(request_builder);
|
||||
for (header_name, header_value) in self.request_profile.header_pairs() {
|
||||
request_builder = request_builder.header(header_name, header_value);
|
||||
}
|
||||
|
||||
request_builder = request_builder.json(request);
|
||||
let request_body = self.request_profile.render_json_body(request)?;
|
||||
request_builder = request_builder.json(&request_body);
|
||||
request_builder.send().await.map_err(ApiError::from)
|
||||
}
|
||||
|
||||
fn record_request_failure(&self, attempt: u32, error: &ApiError) {
|
||||
if let Some(session_tracer) = &self.session_tracer {
|
||||
session_tracer.record_http_request_failed(
|
||||
attempt,
|
||||
"POST",
|
||||
"/v1/messages",
|
||||
error.to_string(),
|
||||
error.is_retryable(),
|
||||
Map::new(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn store_last_prompt_cache_record(&self, record: PromptCacheRecord) {
|
||||
*self
|
||||
.last_prompt_cache_record
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
|
||||
}
|
||||
|
||||
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
||||
let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
|
||||
return Err(ApiError::BackoffOverflow {
|
||||
@@ -571,6 +733,11 @@ pub struct MessageStream {
|
||||
parser: SseParser,
|
||||
pending: VecDeque<StreamEvent>,
|
||||
done: bool,
|
||||
request: MessageRequest,
|
||||
prompt_cache: Option<PromptCache>,
|
||||
latest_usage: Option<Usage>,
|
||||
usage_recorded: bool,
|
||||
last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
|
||||
}
|
||||
|
||||
impl MessageStream {
|
||||
@@ -582,6 +749,7 @@ impl MessageStream {
|
||||
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
|
||||
loop {
|
||||
if let Some(event) = self.pending.pop_front() {
|
||||
self.observe_event(&event);
|
||||
return Ok(Some(event));
|
||||
}
|
||||
|
||||
@@ -604,6 +772,29 @@ impl MessageStream {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn observe_event(&mut self, event: &StreamEvent) {
|
||||
match event {
|
||||
StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
|
||||
self.latest_usage = Some(usage.clone());
|
||||
}
|
||||
StreamEvent::MessageStop(_) => {
|
||||
if !self.usage_recorded {
|
||||
if let (Some(prompt_cache), Some(usage)) =
|
||||
(&self.prompt_cache, self.latest_usage.as_ref())
|
||||
{
|
||||
let record = prompt_cache.record_usage(&self.request, usage);
|
||||
*self
|
||||
.last_prompt_cache_record
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
|
||||
}
|
||||
self.usage_recorded = true;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
|
||||
|
||||
Reference in New Issue
Block a user