diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 27b83f4..76b1896 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -28,12 +28,86 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "base64" version = "0.22.1" @@ -262,7 +336,7 @@ checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" dependencies = [ "cfg-if", "rustix 1.1.4", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -318,6 +392,17 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.32" @@ -338,6 +423,7 @@ checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -451,6 +537,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.9.0" @@ -464,6 +556,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -708,12 +801,24 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "memchr" version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -1091,12 +1196,14 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-rustls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", ] @@ -1145,7 +1252,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.4.15", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1288,6 +1395,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1300,6 +1418,19 @@ dependencies = [ "serde", ] +[[package]] +name = "server" +version = "0.1.0" +dependencies = [ + "async-stream", + "axum", + "reqwest", + "runtime", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "sha2" version = "0.10.9" @@ -1553,6 +1684,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tools" version = "0.1.0" @@ -1579,6 +1723,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -1617,6 +1762,7 @@ version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ + "log", "pin-project-lite", "tracing-core", ] @@ -1791,6 +1937,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.93" diff --git a/rust/crates/server/Cargo.toml b/rust/crates/server/Cargo.toml new file mode 100644 index 0000000..9151aef --- /dev/null +++ b/rust/crates/server/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "server" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true + +[dependencies] +async-stream = "0.3" +axum = "0.8" +runtime = { path = "../runtime" } +serde = { version = "1", features = ["derive"] } +serde_json.workspace = true +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "net", "time"] } + +[dev-dependencies] +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "stream"] } + +[lints] +workspace = true diff --git a/rust/crates/server/src/lib.rs b/rust/crates/server/src/lib.rs new file mode 100644 index 0000000..b3386ea --- /dev/null +++ b/rust/crates/server/src/lib.rs @@ -0,0 +1,442 @@ +use std::collections::HashMap; +use std::convert::Infallible; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use async_stream::stream; +use axum::extract::{Path, State}; +use axum::http::StatusCode; +use axum::response::sse::{Event, KeepAlive, Sse}; +use axum::response::IntoResponse; +use axum::routing::{get, post}; +use axum::{Json, Router}; +use runtime::{ConversationMessage, Session as RuntimeSession}; +use serde::{Deserialize, Serialize}; +use tokio::sync::{broadcast, RwLock}; + +pub type SessionId = String; +pub type SessionStore = Arc>>; + +const BROADCAST_CAPACITY: usize = 64; + +#[derive(Clone)] +pub struct AppState { + sessions: SessionStore, + next_session_id: Arc, +} + +impl AppState { + #[must_use] + pub fn new() -> Self { + Self { + sessions: Arc::new(RwLock::new(HashMap::new())), + next_session_id: Arc::new(AtomicU64::new(1)), + } + } + + fn allocate_session_id(&self) -> SessionId { + let id = self.next_session_id.fetch_add(1, Ordering::Relaxed); + format!("session-{id}") + } +} + +impl Default for AppState { + fn default() -> Self { + Self::new() + } +} + +#[derive(Clone)] +pub struct Session { + pub id: SessionId, + pub created_at: u64, + pub conversation: RuntimeSession, + events: broadcast::Sender, +} + +impl Session { + fn new(id: SessionId) -> Self { + let (events, _) = broadcast::channel(BROADCAST_CAPACITY); + Self { + id, + created_at: unix_timestamp_millis(), + conversation: RuntimeSession::new(), + events, + } + } + + fn subscribe(&self) -> broadcast::Receiver { + self.events.subscribe() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +enum SessionEvent { + Snapshot { + session_id: SessionId, + session: RuntimeSession, + }, + Message { + session_id: SessionId, + message: ConversationMessage, + }, +} + +impl SessionEvent { + fn event_name(&self) -> &'static str { + match self { + Self::Snapshot { .. } => "snapshot", + Self::Message { .. } => "message", + } + } + + fn to_sse_event(&self) -> Result { + Ok(Event::default() + .event(self.event_name()) + .data(serde_json::to_string(self)?)) + } +} + +#[derive(Debug, Serialize)] +struct ErrorResponse { + error: String, +} + +type ApiError = (StatusCode, Json); +type ApiResult = Result; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct CreateSessionResponse { + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionSummary { + pub id: SessionId, + pub created_at: u64, + pub message_count: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ListSessionsResponse { + pub sessions: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionDetailsResponse { + pub id: SessionId, + pub created_at: u64, + pub session: RuntimeSession, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SendMessageRequest { + pub message: String, +} + +#[must_use] +pub fn app(state: AppState) -> Router { + Router::new() + .route("/sessions", post(create_session).get(list_sessions)) + .route("/sessions/{id}", get(get_session)) + .route("/sessions/{id}/events", get(stream_session_events)) + .route("/sessions/{id}/message", post(send_message)) + .with_state(state) +} + +async fn create_session( + State(state): State, +) -> (StatusCode, Json) { + let session_id = state.allocate_session_id(); + let session = Session::new(session_id.clone()); + + state + .sessions + .write() + .await + .insert(session_id.clone(), session); + + ( + StatusCode::CREATED, + Json(CreateSessionResponse { session_id }), + ) +} + +async fn list_sessions(State(state): State) -> Json { + let sessions = state.sessions.read().await; + let mut summaries = sessions + .values() + .map(|session| SessionSummary { + id: session.id.clone(), + created_at: session.created_at, + message_count: session.conversation.messages.len(), + }) + .collect::>(); + summaries.sort_by(|left, right| left.id.cmp(&right.id)); + + Json(ListSessionsResponse { + sessions: summaries, + }) +} + +async fn get_session( + State(state): State, + Path(id): Path, +) -> ApiResult> { + let sessions = state.sessions.read().await; + let session = sessions + .get(&id) + .ok_or_else(|| not_found(format!("session `{id}` not found")))?; + + Ok(Json(SessionDetailsResponse { + id: session.id.clone(), + created_at: session.created_at, + session: session.conversation.clone(), + })) +} + +async fn send_message( + State(state): State, + Path(id): Path, + Json(payload): Json, +) -> ApiResult { + let message = ConversationMessage::user_text(payload.message); + let broadcaster = { + let mut sessions = state.sessions.write().await; + let session = sessions + .get_mut(&id) + .ok_or_else(|| not_found(format!("session `{id}` not found")))?; + session.conversation.messages.push(message.clone()); + session.events.clone() + }; + + let _ = broadcaster.send(SessionEvent::Message { + session_id: id, + message, + }); + + Ok(StatusCode::NO_CONTENT) +} + +async fn stream_session_events( + State(state): State, + Path(id): Path, +) -> ApiResult { + let (snapshot, mut receiver) = { + let sessions = state.sessions.read().await; + let session = sessions + .get(&id) + .ok_or_else(|| not_found(format!("session `{id}` not found")))?; + ( + SessionEvent::Snapshot { + session_id: session.id.clone(), + session: session.conversation.clone(), + }, + session.subscribe(), + ) + }; + + let stream = stream! { + if let Ok(event) = snapshot.to_sse_event() { + yield Ok::(event); + } + + loop { + match receiver.recv().await { + Ok(event) => { + if let Ok(sse_event) = event.to_sse_event() { + yield Ok::(sse_event); + } + } + Err(broadcast::error::RecvError::Lagged(_)) => continue, + Err(broadcast::error::RecvError::Closed) => break, + } + } + }; + + Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))) +} + +fn unix_timestamp_millis() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be after epoch") + .as_millis() as u64 +} + +fn not_found(message: String) -> ApiError { + ( + StatusCode::NOT_FOUND, + Json(ErrorResponse { error: message }), + ) +} + +#[cfg(test)] +mod tests { + use super::{ + app, AppState, CreateSessionResponse, ListSessionsResponse, SessionDetailsResponse, + }; + use reqwest::Client; + use std::net::SocketAddr; + use std::time::Duration; + use tokio::net::TcpListener; + use tokio::task::JoinHandle; + use tokio::time::timeout; + + struct TestServer { + address: SocketAddr, + handle: JoinHandle<()>, + } + + impl TestServer { + async fn spawn() -> Self { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("test listener should bind"); + let address = listener + .local_addr() + .expect("listener should report local address"); + let handle = tokio::spawn(async move { + axum::serve(listener, app(AppState::default())) + .await + .expect("server should run"); + }); + + Self { address, handle } + } + + fn url(&self, path: &str) -> String { + format!("http://{}{}", self.address, path) + } + } + + impl Drop for TestServer { + fn drop(&mut self) { + self.handle.abort(); + } + } + + async fn create_session(client: &Client, server: &TestServer) -> CreateSessionResponse { + client + .post(server.url("/sessions")) + .send() + .await + .expect("create request should succeed") + .error_for_status() + .expect("create request should return success") + .json::() + .await + .expect("create response should parse") + } + + async fn next_sse_frame(response: &mut reqwest::Response, buffer: &mut String) -> String { + loop { + if let Some(index) = buffer.find("\n\n") { + let frame = buffer[..index].to_string(); + let remainder = buffer[index + 2..].to_string(); + *buffer = remainder; + return frame; + } + + let next_chunk = timeout(Duration::from_secs(5), response.chunk()) + .await + .expect("SSE stream should yield within timeout") + .expect("SSE stream should remain readable") + .expect("SSE stream should stay open"); + buffer.push_str(&String::from_utf8_lossy(&next_chunk)); + } + } + + #[tokio::test] + async fn creates_and_lists_sessions() { + let server = TestServer::spawn().await; + let client = Client::new(); + + // given + let created = create_session(&client, &server).await; + + // when + let sessions = client + .get(server.url("/sessions")) + .send() + .await + .expect("list request should succeed") + .error_for_status() + .expect("list request should return success") + .json::() + .await + .expect("list response should parse"); + let details = client + .get(server.url(&format!("/sessions/{}", created.session_id))) + .send() + .await + .expect("details request should succeed") + .error_for_status() + .expect("details request should return success") + .json::() + .await + .expect("details response should parse"); + + // then + assert_eq!(created.session_id, "session-1"); + assert_eq!(sessions.sessions.len(), 1); + assert_eq!(sessions.sessions[0].id, created.session_id); + assert_eq!(sessions.sessions[0].message_count, 0); + assert_eq!(details.id, "session-1"); + assert!(details.session.messages.is_empty()); + } + + #[tokio::test] + async fn streams_message_events_and_persists_message_flow() { + let server = TestServer::spawn().await; + let client = Client::new(); + + // given + let created = create_session(&client, &server).await; + let mut response = client + .get(server.url(&format!("/sessions/{}/events", created.session_id))) + .send() + .await + .expect("events request should succeed") + .error_for_status() + .expect("events request should return success"); + let mut buffer = String::new(); + let snapshot_frame = next_sse_frame(&mut response, &mut buffer).await; + + // when + let send_status = client + .post(server.url(&format!("/sessions/{}/message", created.session_id))) + .json(&super::SendMessageRequest { + message: "hello from test".to_string(), + }) + .send() + .await + .expect("message request should succeed") + .status(); + let message_frame = next_sse_frame(&mut response, &mut buffer).await; + let details = client + .get(server.url(&format!("/sessions/{}", created.session_id))) + .send() + .await + .expect("details request should succeed") + .error_for_status() + .expect("details request should return success") + .json::() + .await + .expect("details response should parse"); + + // then + assert_eq!(send_status, reqwest::StatusCode::NO_CONTENT); + assert!(snapshot_frame.contains("event: snapshot")); + assert!(snapshot_frame.contains("\"session_id\":\"session-1\"")); + assert!(message_frame.contains("event: message")); + assert!(message_frame.contains("hello from test")); + assert_eq!(details.session.messages.len(), 1); + assert_eq!( + details.session.messages[0], + runtime::ConversationMessage::user_text("hello from test") + ); + } +}