fix: restore anthropic request profile integration

This commit is contained in:
YeonGyu-Kim
2026-04-02 11:31:53 +09:00
parent 8476d713a8
commit de589d47a5
2 changed files with 204 additions and 8 deletions

View File

@@ -1,5 +1,6 @@
mod client;
mod error;
mod prompt_cache;
mod providers;
mod sse;
mod types;
@@ -9,6 +10,10 @@ pub use client::{
resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient,
};
pub use error::ApiError;
pub use prompt_cache::{
CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord,
PromptCacheStats,
};
pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource};
pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig};
pub use providers::{

View File

@@ -1,18 +1,22 @@
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use runtime::format_usd;
use runtime::{
load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
OAuthTokenExchangeRequest,
};
use serde::Deserialize;
use telemetry::SessionTracer;
use serde_json::{Map, Value};
use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, SessionTracer};
use crate::error::ApiError;
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
use super::{Provider, ProviderFuture};
use crate::sse::SseParser;
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage};
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
const ANTHROPIC_VERSION: &str = "2023-06-01";
@@ -114,6 +118,10 @@ pub struct AnthropicClient {
max_retries: u32,
initial_backoff: Duration,
max_backoff: Duration,
request_profile: AnthropicRequestProfile,
session_tracer: Option<SessionTracer>,
prompt_cache: Option<PromptCache>,
last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
}
impl AnthropicClient {
@@ -126,6 +134,10 @@ impl AnthropicClient {
max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF,
request_profile: AnthropicRequestProfile::default(),
session_tracer: None,
prompt_cache: None,
last_prompt_cache_record: Arc::new(Mutex::new(None)),
}
}
@@ -138,6 +150,10 @@ impl AnthropicClient {
max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF,
request_profile: AnthropicRequestProfile::default(),
session_tracer: None,
prompt_cache: None,
last_prompt_cache_record: Arc::new(Mutex::new(None)),
}
}
@@ -196,7 +212,66 @@ impl AnthropicClient {
}
#[must_use]
pub fn with_session_tracer(self, _session_tracer: SessionTracer) -> Self {
pub fn with_session_tracer(mut self, session_tracer: SessionTracer) -> Self {
self.session_tracer = Some(session_tracer);
self
}
/// Sets the client identity carried by the request profile; all other
/// profile settings are left untouched.
#[must_use]
pub fn with_client_identity(mut self, client_identity: ClientIdentity) -> Self {
    let profile = &mut self.request_profile;
    profile.client_identity = client_identity;
    self
}
/// Adds a beta flag to the request profile via
/// `AnthropicRequestProfile::with_beta`.
#[must_use]
pub fn with_beta(mut self, beta: impl Into<String>) -> Self {
    let updated = self.request_profile.with_beta(beta);
    self.request_profile = updated;
    self
}
/// Registers an extra key/value pair to be merged into the JSON request
/// body via `AnthropicRequestProfile::with_extra_body`.
#[must_use]
pub fn with_extra_body_param(mut self, key: impl Into<String>, value: Value) -> Self {
    let updated = self.request_profile.with_extra_body(key, value);
    self.request_profile = updated;
    self
}
#[must_use]
pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self {
self.prompt_cache = Some(prompt_cache);
self
}
#[must_use]
pub fn prompt_cache_stats(&self) -> Option<PromptCacheStats> {
self.prompt_cache.as_ref().map(PromptCache::stats)
}
/// Borrows the request profile (extra headers, betas, extra body params)
/// applied to outgoing requests.
#[must_use]
pub fn request_profile(&self) -> &AnthropicRequestProfile {
&self.request_profile
}
/// Borrows the attached session tracer, if one was configured.
#[must_use]
pub fn session_tracer(&self) -> Option<&SessionTracer> {
self.session_tracer.as_ref()
}
/// Borrows the configured prompt cache, if any.
#[must_use]
pub fn prompt_cache(&self) -> Option<&PromptCache> {
self.prompt_cache.as_ref()
}
/// Removes and returns the most recent prompt-cache record, leaving `None`
/// in the shared slot. A poisoned lock is recovered rather than propagated.
#[must_use]
pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
    let mut slot = self
        .last_prompt_cache_record
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    slot.take()
}
#[must_use]
pub fn with_request_profile(mut self, request_profile: AnthropicRequestProfile) -> Self {
self.request_profile = request_profile;
self
}
@@ -213,6 +288,13 @@ impl AnthropicClient {
stream: false,
..request.clone()
};
if let Some(prompt_cache) = &self.prompt_cache {
if let Some(response) = prompt_cache.lookup_completion(&request) {
return Ok(response);
}
}
let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers());
let mut response = response
@@ -222,6 +304,30 @@ impl AnthropicClient {
if response.request_id.is_none() {
response.request_id = request_id;
}
if let Some(prompt_cache) = &self.prompt_cache {
let record = prompt_cache.record_response(&request, &response);
self.store_last_prompt_cache_record(record);
}
if let Some(session_tracer) = &self.session_tracer {
session_tracer.record_analytics(
AnalyticsEvent::new("api", "message_usage")
.with_property(
"request_id",
response
.request_id
.clone()
.map_or(Value::Null, Value::String),
)
.with_property("total_tokens", Value::from(response.total_tokens()))
.with_property(
"estimated_cost_usd",
Value::String(format_usd(
response.usage.estimated_cost_usd(&response.model).total_cost_usd(),
)),
),
);
}
Ok(response)
}
@@ -238,6 +344,11 @@ impl AnthropicClient {
parser: SseParser::new(),
pending: VecDeque::new(),
done: false,
request: request.clone(),
prompt_cache: self.prompt_cache.clone(),
latest_usage: None,
usage_recorded: false,
last_prompt_cache_record: Arc::clone(&self.last_prompt_cache_record),
})
}
@@ -290,18 +401,46 @@ impl AnthropicClient {
loop {
attempts += 1;
if let Some(session_tracer) = &self.session_tracer {
session_tracer.record_http_request_started(
attempts,
"POST",
"/v1/messages",
Map::new(),
);
}
match self.send_raw_request(request).await {
Ok(response) => match expect_success(response).await {
Ok(response) => return Ok(response),
Ok(response) => {
if let Some(session_tracer) = &self.session_tracer {
session_tracer.record_http_request_succeeded(
attempts,
"POST",
"/v1/messages",
response.status().as_u16(),
request_id_from_headers(response.headers()),
Map::new(),
);
}
return Ok(response);
}
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
self.record_request_failure(attempts, &error);
last_error = Some(error);
}
Err(error) => return Err(error),
Err(error) => {
self.record_request_failure(attempts, &error);
return Err(error);
}
},
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
self.record_request_failure(attempts, &error);
last_error = Some(error);
}
Err(error) => return Err(error),
Err(error) => {
self.record_request_failure(attempts, &error);
return Err(error);
}
}
if attempts > self.max_retries {
@@ -325,14 +464,37 @@ impl AnthropicClient {
let request_builder = self
.http
.post(&request_url)
.header("anthropic-version", ANTHROPIC_VERSION)
.header("content-type", "application/json");
let mut request_builder = self.auth.apply(request_builder);
for (header_name, header_value) in self.request_profile.header_pairs() {
request_builder = request_builder.header(header_name, header_value);
}
request_builder = request_builder.json(request);
let request_body = self.request_profile.render_json_body(request)?;
request_builder = request_builder.json(&request_body);
request_builder.send().await.map_err(ApiError::from)
}
/// Reports a failed HTTP attempt to the session tracer; a no-op when no
/// tracer is attached.
fn record_request_failure(&self, attempt: u32, error: &ApiError) {
    let Some(tracer) = self.session_tracer.as_ref() else {
        return;
    };
    tracer.record_http_request_failed(
        attempt,
        "POST",
        "/v1/messages",
        error.to_string(),
        error.is_retryable(),
        Map::new(),
    );
}
/// Overwrites the shared "last record" slot with `record`. A poisoned
/// mutex is recovered so a panicked writer cannot wedge later callers.
fn store_last_prompt_cache_record(&self, record: PromptCacheRecord) {
    let mut slot = self
        .last_prompt_cache_record
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    *slot = Some(record);
}
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
return Err(ApiError::BackoffOverflow {
@@ -571,6 +733,11 @@ pub struct MessageStream {
parser: SseParser,
pending: VecDeque<StreamEvent>,
done: bool,
request: MessageRequest,
prompt_cache: Option<PromptCache>,
latest_usage: Option<Usage>,
usage_recorded: bool,
last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
}
impl MessageStream {
@@ -582,6 +749,7 @@ impl MessageStream {
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
loop {
if let Some(event) = self.pending.pop_front() {
self.observe_event(&event);
return Ok(Some(event));
}
@@ -604,6 +772,29 @@ impl MessageStream {
}
}
}
/// Inspects each stream event before it is returned to the caller.
///
/// Remembers the usage payload from the most recent `message_delta`; on the
/// first `message_stop`, if both a prompt cache and a usage payload are
/// present, the usage is recorded against the originating request and the
/// resulting record is published through the shared last-record slot.
/// Recording happens at most once per stream (`usage_recorded`).
fn observe_event(&mut self, event: &StreamEvent) {
    match event {
        StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
            self.latest_usage = Some(usage.clone());
        }
        StreamEvent::MessageStop(_) if !self.usage_recorded => {
            if let (Some(prompt_cache), Some(usage)) =
                (self.prompt_cache.as_ref(), self.latest_usage.as_ref())
            {
                let record = prompt_cache.record_usage(&self.request, usage);
                // Recover from poison so a panicked holder cannot block updates.
                *self
                    .last_prompt_cache_record
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
            }
            self.usage_recorded = true;
        }
        _ => {}
    }
}
}
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {