mirror of
https://github.com/instructkr/claw-code.git
synced 2026-04-06 16:14:49 +08:00
Recover the MCP lane on top of current main
This resolves the stale-branch merge against origin/main, keeps the MCP runtime wiring, and preserves prompt-approved CLI tool execution after the mock parity harness additions landed upstream. Constraint: Branch had to absorb origin/main changes through a contentful merge before more MCP work Constraint: Prompt-approved runtime tool execution must continue working with new CLI/mock parity coverage Rejected: Keep permission enforcer attached inside CliToolExecutor for conversation turns | caused prompt-approved bash parity flow to fail as a tool error Rejected: Defer the merge and continue on stale history | would leave the lane red against current main Confidence: high Scope-risk: moderate Reversibility: clean Directive: Runtime permission policy and executor-side permission enforcement are separate layers; do not reapply executor enforcement to conversation turns without revalidating mock parity harness approval flows Tested: cargo test -p rusty-claude-cli --test mock_parity_harness -- --nocapture; cargo test -p rusty-claude-cli -- --nocapture; cargo test --workspace -- --nocapture Not-tested: Additional live remote/provider scenarios beyond the existing workspace suite
This commit is contained in:
@@ -134,8 +134,8 @@ async fn execute_bash_async(
|
||||
};
|
||||
|
||||
let (output, interrupted) = output_result;
|
||||
let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
|
||||
let stdout = truncate_output(&String::from_utf8_lossy(&output.stdout));
|
||||
let stderr = truncate_output(&String::from_utf8_lossy(&output.stderr));
|
||||
let no_output_expected = Some(stdout.trim().is_empty() && stderr.trim().is_empty());
|
||||
let return_code_interpretation = output.status.code().and_then(|code| {
|
||||
if code == 0 {
|
||||
@@ -281,3 +281,53 @@ mod tests {
|
||||
assert!(!output.sandbox_status.expect("sandbox status").enabled);
|
||||
}
|
||||
}
|
||||
|
||||
/// Maximum output bytes before truncation (16 KiB, matching upstream).
|
||||
const MAX_OUTPUT_BYTES: usize = 16_384;
|
||||
|
||||
/// Truncate output to `MAX_OUTPUT_BYTES`, appending a marker when trimmed.
|
||||
fn truncate_output(s: &str) -> String {
|
||||
if s.len() <= MAX_OUTPUT_BYTES {
|
||||
return s.to_string();
|
||||
}
|
||||
// Find the last valid UTF-8 boundary at or before MAX_OUTPUT_BYTES
|
||||
let mut end = MAX_OUTPUT_BYTES;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
let mut truncated = s[..end].to_string();
|
||||
truncated.push_str("\n\n[output truncated — exceeded 16384 bytes]");
|
||||
truncated
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod truncation_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn short_output_unchanged() {
|
||||
let s = "hello world";
|
||||
assert_eq!(truncate_output(s), s);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn long_output_truncated() {
|
||||
let s = "x".repeat(20_000);
|
||||
let result = truncate_output(&s);
|
||||
assert!(result.len() < 20_000);
|
||||
assert!(result.ends_with("[output truncated — exceeded 16384 bytes]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exact_boundary_unchanged() {
|
||||
let s = "a".repeat(MAX_OUTPUT_BYTES);
|
||||
assert_eq!(truncate_output(&s), s);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_over_boundary_truncated() {
|
||||
let s = "a".repeat(MAX_OUTPUT_BYTES + 1);
|
||||
let result = truncate_output(&s);
|
||||
assert!(result.contains("[output truncated"));
|
||||
}
|
||||
}
|
||||
|
||||
1004
rust/crates/runtime/src/bash_validation.rs
Normal file
1004
rust/crates/runtime/src/bash_validation.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -847,7 +847,7 @@ mod tests {
|
||||
AssistantEvent::MessageStop,
|
||||
])
|
||||
}
|
||||
_ => Err(RuntimeError::new("unexpected extra API call")),
|
||||
_ => unreachable!("extra API call"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1156,7 +1156,7 @@ mod tests {
|
||||
AssistantEvent::MessageStop,
|
||||
])
|
||||
}
|
||||
_ => Err(RuntimeError::new("unexpected extra API call")),
|
||||
_ => unreachable!("extra API call"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1231,7 +1231,7 @@ mod tests {
|
||||
AssistantEvent::MessageStop,
|
||||
])
|
||||
}
|
||||
_ => Err(RuntimeError::new("unexpected extra API call")),
|
||||
_ => unreachable!("extra API call"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1545,7 +1545,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn auto_compaction_threshold_defaults_and_parses_values() {
|
||||
// given / when / then
|
||||
assert_eq!(
|
||||
parse_auto_compaction_threshold(None),
|
||||
DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD
|
||||
|
||||
@@ -9,6 +9,39 @@ use regex::RegexBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use walkdir::WalkDir;
|
||||
|
||||
/// Maximum file size that can be read (10 MB).
|
||||
const MAX_READ_SIZE: u64 = 10 * 1024 * 1024;
|
||||
|
||||
/// Maximum file size that can be written (10 MB).
|
||||
const MAX_WRITE_SIZE: usize = 10 * 1024 * 1024;
|
||||
|
||||
/// Check whether a file appears to contain binary content by examining
|
||||
/// the first chunk for NUL bytes.
|
||||
fn is_binary_file(path: &Path) -> io::Result<bool> {
|
||||
use std::io::Read;
|
||||
let mut file = fs::File::open(path)?;
|
||||
let mut buffer = [0u8; 8192];
|
||||
let bytes_read = file.read(&mut buffer)?;
|
||||
Ok(buffer[..bytes_read].contains(&0))
|
||||
}
|
||||
|
||||
/// Validate that a resolved path stays within the given workspace root.
|
||||
/// Returns the canonical path on success, or an error if the path escapes
|
||||
/// the workspace boundary (e.g. via `../` traversal or symlink).
|
||||
fn validate_workspace_boundary(resolved: &Path, workspace_root: &Path) -> io::Result<()> {
|
||||
if !resolved.starts_with(workspace_root) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::PermissionDenied,
|
||||
format!(
|
||||
"path {} escapes workspace boundary {}",
|
||||
resolved.display(),
|
||||
workspace_root.display()
|
||||
),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct TextFilePayload {
|
||||
#[serde(rename = "filePath")]
|
||||
@@ -135,6 +168,28 @@ pub fn read_file(
|
||||
limit: Option<usize>,
|
||||
) -> io::Result<ReadFileOutput> {
|
||||
let absolute_path = normalize_path(path)?;
|
||||
|
||||
// Check file size before reading
|
||||
let metadata = fs::metadata(&absolute_path)?;
|
||||
if metadata.len() > MAX_READ_SIZE {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"file is too large ({} bytes, max {} bytes)",
|
||||
metadata.len(),
|
||||
MAX_READ_SIZE
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Detect binary files
|
||||
if is_binary_file(&absolute_path)? {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"file appears to be binary",
|
||||
));
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&absolute_path)?;
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let start_index = offset.unwrap_or(0).min(lines.len());
|
||||
@@ -156,6 +211,17 @@ pub fn read_file(
|
||||
}
|
||||
|
||||
pub fn write_file(path: &str, content: &str) -> io::Result<WriteFileOutput> {
|
||||
if content.len() > MAX_WRITE_SIZE {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"content is too large ({} bytes, max {} bytes)",
|
||||
content.len(),
|
||||
MAX_WRITE_SIZE
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let absolute_path = normalize_path_allow_missing(path)?;
|
||||
let original_file = fs::read_to_string(&absolute_path).ok();
|
||||
if let Some(parent) = absolute_path.parent() {
|
||||
@@ -477,11 +543,72 @@ fn normalize_path_allow_missing(path: &str) -> io::Result<PathBuf> {
|
||||
Ok(candidate)
|
||||
}
|
||||
|
||||
/// Read a file with workspace boundary enforcement.
|
||||
pub fn read_file_in_workspace(
|
||||
path: &str,
|
||||
offset: Option<usize>,
|
||||
limit: Option<usize>,
|
||||
workspace_root: &Path,
|
||||
) -> io::Result<ReadFileOutput> {
|
||||
let absolute_path = normalize_path(path)?;
|
||||
let canonical_root = workspace_root
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| workspace_root.to_path_buf());
|
||||
validate_workspace_boundary(&absolute_path, &canonical_root)?;
|
||||
read_file(path, offset, limit)
|
||||
}
|
||||
|
||||
/// Write a file with workspace boundary enforcement.
|
||||
pub fn write_file_in_workspace(
|
||||
path: &str,
|
||||
content: &str,
|
||||
workspace_root: &Path,
|
||||
) -> io::Result<WriteFileOutput> {
|
||||
let absolute_path = normalize_path_allow_missing(path)?;
|
||||
let canonical_root = workspace_root
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| workspace_root.to_path_buf());
|
||||
validate_workspace_boundary(&absolute_path, &canonical_root)?;
|
||||
write_file(path, content)
|
||||
}
|
||||
|
||||
/// Edit a file with workspace boundary enforcement.
|
||||
pub fn edit_file_in_workspace(
|
||||
path: &str,
|
||||
old_string: &str,
|
||||
new_string: &str,
|
||||
replace_all: bool,
|
||||
workspace_root: &Path,
|
||||
) -> io::Result<EditFileOutput> {
|
||||
let absolute_path = normalize_path(path)?;
|
||||
let canonical_root = workspace_root
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| workspace_root.to_path_buf());
|
||||
validate_workspace_boundary(&absolute_path, &canonical_root)?;
|
||||
edit_file(path, old_string, new_string, replace_all)
|
||||
}
|
||||
|
||||
/// Check whether a path is a symlink that resolves outside the workspace.
|
||||
pub fn is_symlink_escape(path: &Path, workspace_root: &Path) -> io::Result<bool> {
|
||||
let metadata = fs::symlink_metadata(path)?;
|
||||
if !metadata.is_symlink() {
|
||||
return Ok(false);
|
||||
}
|
||||
let resolved = path.canonicalize()?;
|
||||
let canonical_root = workspace_root
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| workspace_root.to_path_buf());
|
||||
Ok(!resolved.starts_with(&canonical_root))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use super::{edit_file, glob_search, grep_search, read_file, write_file, GrepSearchInput};
|
||||
use super::{
|
||||
edit_file, glob_search, grep_search, is_symlink_escape, read_file, read_file_in_workspace,
|
||||
write_file, GrepSearchInput, MAX_WRITE_SIZE,
|
||||
};
|
||||
|
||||
fn temp_path(name: &str) -> std::path::PathBuf {
|
||||
let unique = SystemTime::now()
|
||||
@@ -513,6 +640,73 @@ mod tests {
|
||||
assert!(output.replace_all);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_binary_files() {
|
||||
let path = temp_path("binary-test.bin");
|
||||
std::fs::write(&path, b"\x00\x01\x02\x03binary content").expect("write should succeed");
|
||||
let result = read_file(path.to_string_lossy().as_ref(), None, None);
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert_eq!(error.kind(), std::io::ErrorKind::InvalidData);
|
||||
assert!(error.to_string().contains("binary"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_oversized_writes() {
|
||||
let path = temp_path("oversize-write.txt");
|
||||
let huge = "x".repeat(MAX_WRITE_SIZE + 1);
|
||||
let result = write_file(path.to_string_lossy().as_ref(), &huge);
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert_eq!(error.kind(), std::io::ErrorKind::InvalidData);
|
||||
assert!(error.to_string().contains("too large"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enforces_workspace_boundary() {
|
||||
let workspace = temp_path("workspace-boundary");
|
||||
std::fs::create_dir_all(&workspace).expect("workspace dir should be created");
|
||||
let inside = workspace.join("inside.txt");
|
||||
write_file(inside.to_string_lossy().as_ref(), "safe content")
|
||||
.expect("write inside workspace should succeed");
|
||||
|
||||
// Reading inside workspace should succeed
|
||||
let result =
|
||||
read_file_in_workspace(inside.to_string_lossy().as_ref(), None, None, &workspace);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Reading outside workspace should fail
|
||||
let outside = temp_path("outside-boundary.txt");
|
||||
write_file(outside.to_string_lossy().as_ref(), "unsafe content")
|
||||
.expect("write outside should succeed");
|
||||
let result =
|
||||
read_file_in_workspace(outside.to_string_lossy().as_ref(), None, None, &workspace);
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert_eq!(error.kind(), std::io::ErrorKind::PermissionDenied);
|
||||
assert!(error.to_string().contains("escapes workspace"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_symlink_escape() {
|
||||
let workspace = temp_path("symlink-workspace");
|
||||
std::fs::create_dir_all(&workspace).expect("workspace dir should be created");
|
||||
let outside = temp_path("symlink-target.txt");
|
||||
std::fs::write(&outside, "target content").expect("target should write");
|
||||
|
||||
let link_path = workspace.join("escape-link.txt");
|
||||
#[cfg(unix)]
|
||||
{
|
||||
std::os::unix::fs::symlink(&outside, &link_path).expect("symlink should create");
|
||||
assert!(is_symlink_escape(&link_path, &workspace).expect("check should succeed"));
|
||||
}
|
||||
|
||||
// Non-symlink file should not be an escape
|
||||
let normal = workspace.join("normal.txt");
|
||||
std::fs::write(&normal, "normal content").expect("normal file should write");
|
||||
assert!(!is_symlink_escape(&normal, &workspace).expect("check should succeed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn globs_and_greps_directory() {
|
||||
let dir = temp_path("search-dir");
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
mod bash;
|
||||
pub mod bash_validation;
|
||||
mod bootstrap;
|
||||
mod compact;
|
||||
mod config;
|
||||
@@ -6,16 +7,21 @@ mod conversation;
|
||||
mod file_ops;
|
||||
mod hooks;
|
||||
mod json;
|
||||
pub mod lsp_client;
|
||||
mod mcp;
|
||||
mod mcp_client;
|
||||
mod mcp_stdio;
|
||||
pub mod mcp_tool_bridge;
|
||||
mod oauth;
|
||||
pub mod permission_enforcer;
|
||||
mod permissions;
|
||||
mod prompt;
|
||||
mod remote;
|
||||
pub mod sandbox;
|
||||
mod session;
|
||||
mod sse;
|
||||
pub mod task_registry;
|
||||
pub mod team_cron_registry;
|
||||
mod usage;
|
||||
|
||||
pub use bash::{execute_bash, BashCommandInput, BashCommandOutput};
|
||||
|
||||
746
rust/crates/runtime/src/lsp_client.rs
Normal file
746
rust/crates/runtime/src/lsp_client.rs
Normal file
@@ -0,0 +1,746 @@
|
||||
//! LSP (Language Server Protocol) client registry for tool dispatch.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Supported LSP actions.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LspAction {
|
||||
Diagnostics,
|
||||
Hover,
|
||||
Definition,
|
||||
References,
|
||||
Completion,
|
||||
Symbols,
|
||||
Format,
|
||||
}
|
||||
|
||||
impl LspAction {
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s {
|
||||
"diagnostics" => Some(Self::Diagnostics),
|
||||
"hover" => Some(Self::Hover),
|
||||
"definition" | "goto_definition" => Some(Self::Definition),
|
||||
"references" | "find_references" => Some(Self::References),
|
||||
"completion" | "completions" => Some(Self::Completion),
|
||||
"symbols" | "document_symbols" => Some(Self::Symbols),
|
||||
"format" | "formatting" => Some(Self::Format),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspDiagnostic {
|
||||
pub path: String,
|
||||
pub line: u32,
|
||||
pub character: u32,
|
||||
pub severity: String,
|
||||
pub message: String,
|
||||
pub source: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspLocation {
|
||||
pub path: String,
|
||||
pub line: u32,
|
||||
pub character: u32,
|
||||
pub end_line: Option<u32>,
|
||||
pub end_character: Option<u32>,
|
||||
pub preview: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspHoverResult {
|
||||
pub content: String,
|
||||
pub language: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspCompletionItem {
|
||||
pub label: String,
|
||||
pub kind: Option<String>,
|
||||
pub detail: Option<String>,
|
||||
pub insert_text: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspSymbol {
|
||||
pub name: String,
|
||||
pub kind: String,
|
||||
pub path: String,
|
||||
pub line: u32,
|
||||
pub character: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LspServerStatus {
|
||||
Connected,
|
||||
Disconnected,
|
||||
Starting,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LspServerStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Connected => write!(f, "connected"),
|
||||
Self::Disconnected => write!(f, "disconnected"),
|
||||
Self::Starting => write!(f, "starting"),
|
||||
Self::Error => write!(f, "error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspServerState {
|
||||
pub language: String,
|
||||
pub status: LspServerStatus,
|
||||
pub root_path: Option<String>,
|
||||
pub capabilities: Vec<String>,
|
||||
pub diagnostics: Vec<LspDiagnostic>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct LspRegistry {
|
||||
inner: Arc<Mutex<RegistryInner>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct RegistryInner {
|
||||
servers: HashMap<String, LspServerState>,
|
||||
}
|
||||
|
||||
impl LspRegistry {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn register(
|
||||
&self,
|
||||
language: &str,
|
||||
status: LspServerStatus,
|
||||
root_path: Option<&str>,
|
||||
capabilities: Vec<String>,
|
||||
) {
|
||||
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner.servers.insert(
|
||||
language.to_owned(),
|
||||
LspServerState {
|
||||
language: language.to_owned(),
|
||||
status,
|
||||
root_path: root_path.map(str::to_owned),
|
||||
capabilities,
|
||||
diagnostics: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn get(&self, language: &str) -> Option<LspServerState> {
|
||||
let inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner.servers.get(language).cloned()
|
||||
}
|
||||
|
||||
/// Find the appropriate server for a file path based on extension.
|
||||
pub fn find_server_for_path(&self, path: &str) -> Option<LspServerState> {
|
||||
let ext = std::path::Path::new(path)
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let language = match ext {
|
||||
"rs" => "rust",
|
||||
"ts" | "tsx" => "typescript",
|
||||
"js" | "jsx" => "javascript",
|
||||
"py" => "python",
|
||||
"go" => "go",
|
||||
"java" => "java",
|
||||
"c" | "h" => "c",
|
||||
"cpp" | "hpp" | "cc" => "cpp",
|
||||
"rb" => "ruby",
|
||||
"lua" => "lua",
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
self.get(language)
|
||||
}
|
||||
|
||||
/// List all registered servers.
|
||||
pub fn list_servers(&self) -> Vec<LspServerState> {
|
||||
let inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner.servers.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Add diagnostics to a server.
|
||||
pub fn add_diagnostics(
|
||||
&self,
|
||||
language: &str,
|
||||
diagnostics: Vec<LspDiagnostic>,
|
||||
) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
let server = inner
|
||||
.servers
|
||||
.get_mut(language)
|
||||
.ok_or_else(|| format!("LSP server not found for language: {language}"))?;
|
||||
server.diagnostics.extend(diagnostics);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get diagnostics for a specific file path.
|
||||
pub fn get_diagnostics(&self, path: &str) -> Vec<LspDiagnostic> {
|
||||
let inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner
|
||||
.servers
|
||||
.values()
|
||||
.flat_map(|s| &s.diagnostics)
|
||||
.filter(|d| d.path == path)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Clear diagnostics for a language server.
|
||||
pub fn clear_diagnostics(&self, language: &str) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
let server = inner
|
||||
.servers
|
||||
.get_mut(language)
|
||||
.ok_or_else(|| format!("LSP server not found for language: {language}"))?;
|
||||
server.diagnostics.clear();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect a server.
|
||||
pub fn disconnect(&self, language: &str) -> Option<LspServerState> {
|
||||
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner.servers.remove(language)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
let inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner.servers.len()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Dispatch an LSP action and return a structured result.
|
||||
pub fn dispatch(
|
||||
&self,
|
||||
action: &str,
|
||||
path: Option<&str>,
|
||||
line: Option<u32>,
|
||||
character: Option<u32>,
|
||||
_query: Option<&str>,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let lsp_action =
|
||||
LspAction::from_str(action).ok_or_else(|| format!("unknown LSP action: {action}"))?;
|
||||
|
||||
// For diagnostics, we can check existing cached diagnostics
|
||||
if lsp_action == LspAction::Diagnostics {
|
||||
if let Some(path) = path {
|
||||
let diags = self.get_diagnostics(path);
|
||||
return Ok(serde_json::json!({
|
||||
"action": "diagnostics",
|
||||
"path": path,
|
||||
"diagnostics": diags,
|
||||
"count": diags.len()
|
||||
}));
|
||||
}
|
||||
// All diagnostics across all servers
|
||||
let inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
let all_diags: Vec<_> = inner
|
||||
.servers
|
||||
.values()
|
||||
.flat_map(|s| &s.diagnostics)
|
||||
.collect();
|
||||
return Ok(serde_json::json!({
|
||||
"action": "diagnostics",
|
||||
"diagnostics": all_diags,
|
||||
"count": all_diags.len()
|
||||
}));
|
||||
}
|
||||
|
||||
// For other actions, we need a connected server for the given file
|
||||
let path = path.ok_or("path is required for this LSP action")?;
|
||||
let server = self
|
||||
.find_server_for_path(path)
|
||||
.ok_or_else(|| format!("no LSP server available for path: {path}"))?;
|
||||
|
||||
if server.status != LspServerStatus::Connected {
|
||||
return Err(format!(
|
||||
"LSP server for '{}' is not connected (status: {})",
|
||||
server.language, server.status
|
||||
));
|
||||
}
|
||||
|
||||
// Return structured placeholder — actual LSP JSON-RPC calls would
|
||||
// go through the real LSP process here.
|
||||
Ok(serde_json::json!({
|
||||
"action": action,
|
||||
"path": path,
|
||||
"line": line,
|
||||
"character": character,
|
||||
"language": server.language,
|
||||
"status": "dispatched",
|
||||
"message": format!("LSP {} dispatched to {} server", action, server.language)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn registers_and_retrieves_server() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register(
|
||||
"rust",
|
||||
LspServerStatus::Connected,
|
||||
Some("/workspace"),
|
||||
vec!["hover".into(), "completion".into()],
|
||||
);
|
||||
|
||||
let server = registry.get("rust").expect("should exist");
|
||||
assert_eq!(server.language, "rust");
|
||||
assert_eq!(server.status, LspServerStatus::Connected);
|
||||
assert_eq!(server.capabilities.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn finds_server_by_file_extension() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry.register("typescript", LspServerStatus::Connected, None, vec![]);
|
||||
|
||||
let rs_server = registry.find_server_for_path("src/main.rs").unwrap();
|
||||
assert_eq!(rs_server.language, "rust");
|
||||
|
||||
let ts_server = registry.find_server_for_path("src/index.ts").unwrap();
|
||||
assert_eq!(ts_server.language, "typescript");
|
||||
|
||||
assert!(registry.find_server_for_path("data.csv").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manages_diagnostics() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"rust",
|
||||
vec![LspDiagnostic {
|
||||
path: "src/main.rs".into(),
|
||||
line: 10,
|
||||
character: 5,
|
||||
severity: "error".into(),
|
||||
message: "mismatched types".into(),
|
||||
source: Some("rust-analyzer".into()),
|
||||
}],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let diags = registry.get_diagnostics("src/main.rs");
|
||||
assert_eq!(diags.len(), 1);
|
||||
assert_eq!(diags[0].message, "mismatched types");
|
||||
|
||||
registry.clear_diagnostics("rust").unwrap();
|
||||
assert!(registry.get_diagnostics("src/main.rs").is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatches_diagnostics_action() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"rust",
|
||||
vec![LspDiagnostic {
|
||||
path: "src/lib.rs".into(),
|
||||
line: 1,
|
||||
character: 0,
|
||||
severity: "warning".into(),
|
||||
message: "unused import".into(),
|
||||
source: None,
|
||||
}],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let result = registry
|
||||
.dispatch("diagnostics", Some("src/lib.rs"), None, None, None)
|
||||
.unwrap();
|
||||
assert_eq!(result["count"], 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatches_hover_action() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
|
||||
let result = registry
|
||||
.dispatch("hover", Some("src/main.rs"), Some(10), Some(5), None)
|
||||
.unwrap();
|
||||
assert_eq!(result["action"], "hover");
|
||||
assert_eq!(result["language"], "rust");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_action_on_disconnected_server() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Disconnected, None, vec![]);
|
||||
|
||||
assert!(registry
|
||||
.dispatch("hover", Some("src/main.rs"), Some(1), Some(0), None)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_unknown_action() {
|
||||
let registry = LspRegistry::new();
|
||||
assert!(registry
|
||||
.dispatch("unknown_action", Some("file.rs"), None, None, None)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disconnects_server() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
assert_eq!(registry.len(), 1);
|
||||
|
||||
let removed = registry.disconnect("rust");
|
||||
assert!(removed.is_some());
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lsp_action_from_str_all_aliases() {
|
||||
// given
|
||||
let cases = [
|
||||
("diagnostics", Some(LspAction::Diagnostics)),
|
||||
("hover", Some(LspAction::Hover)),
|
||||
("definition", Some(LspAction::Definition)),
|
||||
("goto_definition", Some(LspAction::Definition)),
|
||||
("references", Some(LspAction::References)),
|
||||
("find_references", Some(LspAction::References)),
|
||||
("completion", Some(LspAction::Completion)),
|
||||
("completions", Some(LspAction::Completion)),
|
||||
("symbols", Some(LspAction::Symbols)),
|
||||
("document_symbols", Some(LspAction::Symbols)),
|
||||
("format", Some(LspAction::Format)),
|
||||
("formatting", Some(LspAction::Format)),
|
||||
("unknown", None),
|
||||
];
|
||||
|
||||
// when
|
||||
let resolved: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(input, expected)| (input, LspAction::from_str(input), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
for (input, actual, expected) in resolved {
|
||||
assert_eq!(actual, expected, "unexpected action resolution for {input}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lsp_server_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(LspServerStatus::Connected, "connected"),
|
||||
(LspServerStatus::Disconnected, "disconnected"),
|
||||
(LspServerStatus::Starting, "starting"),
|
||||
(LspServerStatus::Error, "error"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("connected".to_string(), "connected"),
|
||||
("disconnected".to_string(), "disconnected"),
|
||||
("starting".to_string(), "starting"),
|
||||
("error".to_string(), "error"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_diagnostics_without_path_aggregates() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry.register("python", LspServerStatus::Connected, None, vec![]);
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"rust",
|
||||
vec![LspDiagnostic {
|
||||
path: "src/lib.rs".into(),
|
||||
line: 1,
|
||||
character: 0,
|
||||
severity: "warning".into(),
|
||||
message: "unused import".into(),
|
||||
source: Some("rust-analyzer".into()),
|
||||
}],
|
||||
)
|
||||
.expect("rust diagnostics should add");
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"python",
|
||||
vec![LspDiagnostic {
|
||||
path: "script.py".into(),
|
||||
line: 2,
|
||||
character: 4,
|
||||
severity: "error".into(),
|
||||
message: "undefined name".into(),
|
||||
source: Some("pyright".into()),
|
||||
}],
|
||||
)
|
||||
.expect("python diagnostics should add");
|
||||
|
||||
// when
|
||||
let result = registry
|
||||
.dispatch("diagnostics", None, None, None, None)
|
||||
.expect("aggregate diagnostics should work");
|
||||
|
||||
// then
|
||||
assert_eq!(result["action"], "diagnostics");
|
||||
assert_eq!(result["count"], 2);
|
||||
assert_eq!(result["diagnostics"].as_array().map(Vec::len), Some(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_non_diagnostics_requires_path() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.dispatch("hover", None, Some(1), Some(0), None);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
result.expect_err("path should be required"),
|
||||
"path is required for this LSP action"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_no_server_for_path_errors() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.dispatch("hover", Some("notes.md"), Some(1), Some(0), None);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing server should fail");
|
||||
assert!(error.contains("no LSP server available for path: notes.md"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_disconnected_server_error_payload() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("typescript", LspServerStatus::Disconnected, None, vec![]);
|
||||
|
||||
// when
|
||||
let result = registry.dispatch("hover", Some("src/index.ts"), Some(3), Some(2), None);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("disconnected server should fail");
|
||||
assert!(error.contains("typescript"));
|
||||
assert!(error.contains("disconnected"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_server_for_all_extensions() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
for language in [
|
||||
"rust",
|
||||
"typescript",
|
||||
"javascript",
|
||||
"python",
|
||||
"go",
|
||||
"java",
|
||||
"c",
|
||||
"cpp",
|
||||
"ruby",
|
||||
"lua",
|
||||
] {
|
||||
registry.register(language, LspServerStatus::Connected, None, vec![]);
|
||||
}
|
||||
let cases = [
|
||||
("src/main.rs", "rust"),
|
||||
("src/index.ts", "typescript"),
|
||||
("src/view.tsx", "typescript"),
|
||||
("src/app.js", "javascript"),
|
||||
("src/app.jsx", "javascript"),
|
||||
("script.py", "python"),
|
||||
("main.go", "go"),
|
||||
("Main.java", "java"),
|
||||
("native.c", "c"),
|
||||
("native.h", "c"),
|
||||
("native.cpp", "cpp"),
|
||||
("native.hpp", "cpp"),
|
||||
("native.cc", "cpp"),
|
||||
("script.rb", "ruby"),
|
||||
("script.lua", "lua"),
|
||||
];
|
||||
|
||||
// when
|
||||
let resolved: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(path, expected)| {
|
||||
(
|
||||
path,
|
||||
registry
|
||||
.find_server_for_path(path)
|
||||
.map(|server| server.language),
|
||||
expected,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// then
|
||||
for (path, actual, expected) in resolved {
|
||||
assert_eq!(
|
||||
actual.as_deref(),
|
||||
Some(expected),
|
||||
"unexpected mapping for {path}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_server_for_path_no_extension() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
|
||||
// when
|
||||
let result = registry.find_server_for_path("Makefile");
|
||||
|
||||
// then
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_servers_with_multiple() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry.register("typescript", LspServerStatus::Starting, None, vec![]);
|
||||
registry.register("python", LspServerStatus::Error, None, vec![]);
|
||||
|
||||
// when
|
||||
let servers = registry.list_servers();
|
||||
|
||||
// then
|
||||
assert_eq!(servers.len(), 3);
|
||||
assert!(servers.iter().any(|server| server.language == "rust"));
|
||||
assert!(servers.iter().any(|server| server.language == "typescript"));
|
||||
assert!(servers.iter().any(|server| server.language == "python"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_missing_server_returns_none() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let server = registry.get("missing");
|
||||
|
||||
// then
|
||||
assert!(server.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_diagnostics_missing_language_errors() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.add_diagnostics("missing", vec![]);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing language should fail");
|
||||
assert!(error.contains("LSP server not found for language: missing"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_diagnostics_across_servers() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
let shared_path = "shared/file.txt";
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry.register("python", LspServerStatus::Connected, None, vec![]);
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"rust",
|
||||
vec![LspDiagnostic {
|
||||
path: shared_path.into(),
|
||||
line: 4,
|
||||
character: 1,
|
||||
severity: "warning".into(),
|
||||
message: "warn".into(),
|
||||
source: None,
|
||||
}],
|
||||
)
|
||||
.expect("rust diagnostics should add");
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"python",
|
||||
vec![LspDiagnostic {
|
||||
path: shared_path.into(),
|
||||
line: 8,
|
||||
character: 3,
|
||||
severity: "error".into(),
|
||||
message: "err".into(),
|
||||
source: None,
|
||||
}],
|
||||
)
|
||||
.expect("python diagnostics should add");
|
||||
|
||||
// when
|
||||
let diagnostics = registry.get_diagnostics(shared_path);
|
||||
|
||||
// then
|
||||
assert_eq!(diagnostics.len(), 2);
|
||||
assert!(diagnostics
|
||||
.iter()
|
||||
.any(|diagnostic| diagnostic.message == "warn"));
|
||||
assert!(diagnostics
|
||||
.iter()
|
||||
.any(|diagnostic| diagnostic.message == "err"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clear_diagnostics_missing_language_errors() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.clear_diagnostics("missing");
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing language should fail");
|
||||
assert!(error.contains("LSP server not found for language: missing"));
|
||||
}
|
||||
}
|
||||
907
rust/crates/runtime/src/mcp_tool_bridge.rs
Normal file
907
rust/crates/runtime/src/mcp_tool_bridge.rs
Normal file
@@ -0,0 +1,907 @@
|
||||
//! Bridge between MCP tool surface (ListMcpResources, ReadMcpResource, McpAuth, MCP)
|
||||
//! and the existing McpServerManager runtime.
|
||||
//!
|
||||
//! Provides a stateful client registry that tool handlers can use to
|
||||
//! connect to MCP servers and invoke their capabilities.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
use crate::mcp::mcp_tool_name;
|
||||
use crate::mcp_stdio::McpServerManager;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Status of a managed MCP server connection.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum McpConnectionStatus {
|
||||
Disconnected,
|
||||
Connecting,
|
||||
Connected,
|
||||
AuthRequired,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for McpConnectionStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Disconnected => write!(f, "disconnected"),
|
||||
Self::Connecting => write!(f, "connecting"),
|
||||
Self::Connected => write!(f, "connected"),
|
||||
Self::AuthRequired => write!(f, "auth_required"),
|
||||
Self::Error => write!(f, "error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Metadata about an MCP resource.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpResourceInfo {
|
||||
pub uri: String,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub mime_type: Option<String>,
|
||||
}
|
||||
|
||||
/// Metadata about an MCP tool exposed by a server.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpToolInfo {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub input_schema: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Tracked state of an MCP server connection.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpServerState {
|
||||
pub server_name: String,
|
||||
pub status: McpConnectionStatus,
|
||||
pub tools: Vec<McpToolInfo>,
|
||||
pub resources: Vec<McpResourceInfo>,
|
||||
pub server_info: Option<String>,
|
||||
pub error_message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct McpToolRegistry {
|
||||
inner: Arc<Mutex<HashMap<String, McpServerState>>>,
|
||||
manager: Arc<OnceLock<Arc<Mutex<McpServerManager>>>>,
|
||||
}
|
||||
|
||||
impl McpToolRegistry {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn set_manager(
|
||||
&self,
|
||||
manager: Arc<Mutex<McpServerManager>>,
|
||||
) -> Result<(), Arc<Mutex<McpServerManager>>> {
|
||||
self.manager.set(manager)
|
||||
}
|
||||
|
||||
pub fn register_server(
|
||||
&self,
|
||||
server_name: &str,
|
||||
status: McpConnectionStatus,
|
||||
tools: Vec<McpToolInfo>,
|
||||
resources: Vec<McpResourceInfo>,
|
||||
server_info: Option<String>,
|
||||
) {
|
||||
let mut inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
inner.insert(
|
||||
server_name.to_owned(),
|
||||
McpServerState {
|
||||
server_name: server_name.to_owned(),
|
||||
status,
|
||||
tools,
|
||||
resources,
|
||||
server_info,
|
||||
error_message: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn get_server(&self, server_name: &str) -> Option<McpServerState> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
inner.get(server_name).cloned()
|
||||
}
|
||||
|
||||
pub fn list_servers(&self) -> Vec<McpServerState> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
inner.values().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn list_resources(&self, server_name: &str) -> Result<Vec<McpResourceInfo>, String> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
match inner.get(server_name) {
|
||||
Some(state) => {
|
||||
if state.status != McpConnectionStatus::Connected {
|
||||
return Err(format!(
|
||||
"server '{}' is not connected (status: {})",
|
||||
server_name, state.status
|
||||
));
|
||||
}
|
||||
Ok(state.resources.clone())
|
||||
}
|
||||
None => Err(format!("server '{}' not found", server_name)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_resource(&self, server_name: &str, uri: &str) -> Result<McpResourceInfo, String> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
let state = inner
|
||||
.get(server_name)
|
||||
.ok_or_else(|| format!("server '{}' not found", server_name))?;
|
||||
|
||||
if state.status != McpConnectionStatus::Connected {
|
||||
return Err(format!(
|
||||
"server '{}' is not connected (status: {})",
|
||||
server_name, state.status
|
||||
));
|
||||
}
|
||||
|
||||
state
|
||||
.resources
|
||||
.iter()
|
||||
.find(|r| r.uri == uri)
|
||||
.cloned()
|
||||
.ok_or_else(|| format!("resource '{}' not found on server '{}'", uri, server_name))
|
||||
}
|
||||
|
||||
pub fn list_tools(&self, server_name: &str) -> Result<Vec<McpToolInfo>, String> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
match inner.get(server_name) {
|
||||
Some(state) => {
|
||||
if state.status != McpConnectionStatus::Connected {
|
||||
return Err(format!(
|
||||
"server '{}' is not connected (status: {})",
|
||||
server_name, state.status
|
||||
));
|
||||
}
|
||||
Ok(state.tools.clone())
|
||||
}
|
||||
None => Err(format!("server '{}' not found", server_name)),
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_tool_call(
|
||||
manager: Arc<Mutex<McpServerManager>>,
|
||||
qualified_tool_name: String,
|
||||
arguments: Option<serde_json::Value>,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let join_handle = std::thread::Builder::new()
|
||||
.name(format!("mcp-tool-call-{qualified_tool_name}"))
|
||||
.spawn(move || {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(|error| format!("failed to create MCP tool runtime: {error}"))?;
|
||||
|
||||
runtime.block_on(async move {
|
||||
let response = {
|
||||
let mut manager = manager
|
||||
.lock()
|
||||
.map_err(|_| "mcp server manager lock poisoned".to_string())?;
|
||||
manager.discover_tools().await.map_err(|error| error.to_string())?;
|
||||
let response = manager
|
||||
.call_tool(&qualified_tool_name, arguments)
|
||||
.await
|
||||
.map_err(|error| error.to_string());
|
||||
let shutdown = manager.shutdown().await.map_err(|error| error.to_string());
|
||||
|
||||
match (response, shutdown) {
|
||||
(Ok(response), Ok(())) => Ok(response),
|
||||
(Err(error), Ok(())) | (Err(error), Err(_)) => Err(error),
|
||||
(Ok(_), Err(error)) => Err(error),
|
||||
}
|
||||
}?;
|
||||
|
||||
if let Some(error) = response.error {
|
||||
return Err(format!(
|
||||
"MCP server returned JSON-RPC error for tools/call: {} ({})",
|
||||
error.message, error.code
|
||||
));
|
||||
}
|
||||
|
||||
let result = response.result.ok_or_else(|| {
|
||||
"MCP server returned no result for tools/call".to_string()
|
||||
})?;
|
||||
|
||||
serde_json::to_value(result)
|
||||
.map_err(|error| format!("failed to serialize MCP tool result: {error}"))
|
||||
})
|
||||
})
|
||||
.map_err(|error| format!("failed to spawn MCP tool call thread: {error}"))?;
|
||||
|
||||
join_handle.join().map_err(|panic_payload| {
|
||||
if let Some(message) = panic_payload.downcast_ref::<&str>() {
|
||||
format!("MCP tool call thread panicked: {message}")
|
||||
} else if let Some(message) = panic_payload.downcast_ref::<String>() {
|
||||
format!("MCP tool call thread panicked: {message}")
|
||||
} else {
|
||||
"MCP tool call thread panicked".to_string()
|
||||
}
|
||||
})?
|
||||
}
|
||||
|
||||
pub fn call_tool(
|
||||
&self,
|
||||
server_name: &str,
|
||||
tool_name: &str,
|
||||
arguments: &serde_json::Value,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
let state = inner
|
||||
.get(server_name)
|
||||
.ok_or_else(|| format!("server '{}' not found", server_name))?;
|
||||
|
||||
if state.status != McpConnectionStatus::Connected {
|
||||
return Err(format!(
|
||||
"server '{}' is not connected (status: {})",
|
||||
server_name, state.status
|
||||
));
|
||||
}
|
||||
|
||||
if !state.tools.iter().any(|t| t.name == tool_name) {
|
||||
return Err(format!(
|
||||
"tool '{}' not found on server '{}'",
|
||||
tool_name, server_name
|
||||
));
|
||||
}
|
||||
|
||||
drop(inner);
|
||||
|
||||
let manager = self
|
||||
.manager
|
||||
.get()
|
||||
.cloned()
|
||||
.ok_or_else(|| "MCP server manager is not configured".to_string())?;
|
||||
|
||||
Self::spawn_tool_call(
|
||||
manager,
|
||||
mcp_tool_name(server_name, tool_name),
|
||||
(!arguments.is_null()).then(|| arguments.clone()),
|
||||
)
|
||||
}
|
||||
|
||||
/// Set auth status for a server.
|
||||
pub fn set_auth_status(
|
||||
&self,
|
||||
server_name: &str,
|
||||
status: McpConnectionStatus,
|
||||
) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
let state = inner
|
||||
.get_mut(server_name)
|
||||
.ok_or_else(|| format!("server '{}' not found", server_name))?;
|
||||
state.status = status;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect / remove a server.
|
||||
pub fn disconnect(&self, server_name: &str) -> Option<McpServerState> {
|
||||
let mut inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
inner.remove(server_name)
|
||||
}
|
||||
|
||||
/// Number of registered servers.
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
inner.len()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeMap;
|
||||
use std::fs;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use super::*;
|
||||
use crate::config::{
|
||||
ConfigSource, McpServerConfig, McpStdioServerConfig, ScopedMcpServerConfig,
|
||||
};
|
||||
|
||||
fn temp_dir() -> PathBuf {
|
||||
static NEXT_TEMP_DIR_ID: AtomicU64 = AtomicU64::new(0);
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should be after epoch")
|
||||
.as_nanos();
|
||||
let unique_id = NEXT_TEMP_DIR_ID.fetch_add(1, Ordering::Relaxed);
|
||||
std::env::temp_dir().join(format!("runtime-mcp-tool-bridge-{nanos}-{unique_id}"))
|
||||
}
|
||||
|
||||
fn cleanup_script(script_path: &Path) {
|
||||
if let Some(root) = script_path.parent() {
|
||||
let _ = fs::remove_dir_all(root);
|
||||
}
|
||||
}
|
||||
|
||||
fn write_bridge_mcp_server_script() -> PathBuf {
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("temp dir");
|
||||
let script_path = root.join("bridge-mcp-server.py");
|
||||
let script = [
|
||||
"#!/usr/bin/env python3",
|
||||
"import json, os, sys",
|
||||
"LABEL = os.environ.get('MCP_SERVER_LABEL', 'server')",
|
||||
"LOG_PATH = os.environ.get('MCP_LOG_PATH')",
|
||||
"",
|
||||
"def log(method):",
|
||||
" if LOG_PATH:",
|
||||
" with open(LOG_PATH, 'a', encoding='utf-8') as handle:",
|
||||
" handle.write(f'{method}\\n')",
|
||||
"",
|
||||
"def read_message():",
|
||||
" header = b''",
|
||||
r" while not header.endswith(b'\r\n\r\n'):",
|
||||
" chunk = sys.stdin.buffer.read(1)",
|
||||
" if not chunk:",
|
||||
" return None",
|
||||
" header += chunk",
|
||||
" length = 0",
|
||||
r" for line in header.decode().split('\r\n'):",
|
||||
r" if line.lower().startswith('content-length:'):",
|
||||
r" length = int(line.split(':', 1)[1].strip())",
|
||||
" payload = sys.stdin.buffer.read(length)",
|
||||
" return json.loads(payload.decode())",
|
||||
"",
|
||||
"def send_message(message):",
|
||||
" payload = json.dumps(message).encode()",
|
||||
r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)",
|
||||
" sys.stdout.buffer.flush()",
|
||||
"",
|
||||
"while True:",
|
||||
" request = read_message()",
|
||||
" if request is None:",
|
||||
" break",
|
||||
" method = request['method']",
|
||||
" log(method)",
|
||||
" if method == 'initialize':",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'protocolVersion': request['params']['protocolVersion'],",
|
||||
" 'capabilities': {'tools': {}},",
|
||||
" 'serverInfo': {'name': LABEL, 'version': '1.0.0'}",
|
||||
" }",
|
||||
" })",
|
||||
" elif method == 'tools/list':",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'tools': [",
|
||||
" {",
|
||||
" 'name': 'echo',",
|
||||
" 'description': f'Echo tool for {LABEL}',",
|
||||
" 'inputSchema': {",
|
||||
" 'type': 'object',",
|
||||
" 'properties': {'text': {'type': 'string'}},",
|
||||
" 'required': ['text']",
|
||||
" }",
|
||||
" }",
|
||||
" ]",
|
||||
" }",
|
||||
" })",
|
||||
" elif method == 'tools/call':",
|
||||
" args = request['params'].get('arguments') or {}",
|
||||
" text = args.get('text', '')",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'content': [{'type': 'text', 'text': f'{LABEL}:{text}'}],",
|
||||
" 'structuredContent': {'server': LABEL, 'echoed': text},",
|
||||
" 'isError': False",
|
||||
" }",
|
||||
" })",
|
||||
" else:",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'error': {'code': -32601, 'message': f'unknown method: {method}'},",
|
||||
" })",
|
||||
"",
|
||||
]
|
||||
.join("\n");
|
||||
fs::write(&script_path, script).expect("write script");
|
||||
let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
|
||||
permissions.set_mode(0o755);
|
||||
fs::set_permissions(&script_path, permissions).expect("chmod");
|
||||
script_path
|
||||
}
|
||||
|
||||
fn manager_server_config(
|
||||
script_path: &Path,
|
||||
server_name: &str,
|
||||
log_path: &Path,
|
||||
) -> ScopedMcpServerConfig {
|
||||
ScopedMcpServerConfig {
|
||||
scope: ConfigSource::Local,
|
||||
config: McpServerConfig::Stdio(McpStdioServerConfig {
|
||||
command: "python3".to_string(),
|
||||
args: vec![script_path.to_string_lossy().into_owned()],
|
||||
env: BTreeMap::from([
|
||||
("MCP_SERVER_LABEL".to_string(), server_name.to_string()),
|
||||
(
|
||||
"MCP_LOG_PATH".to_string(),
|
||||
log_path.to_string_lossy().into_owned(),
|
||||
),
|
||||
]),
|
||||
tool_call_timeout_ms: Some(1_000),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registers_and_retrieves_server() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"test-server",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "greet".into(),
|
||||
description: Some("Greet someone".into()),
|
||||
input_schema: None,
|
||||
}],
|
||||
vec![McpResourceInfo {
|
||||
uri: "res://data".into(),
|
||||
name: "Data".into(),
|
||||
description: None,
|
||||
mime_type: Some("application/json".into()),
|
||||
}],
|
||||
Some("TestServer v1.0".into()),
|
||||
);
|
||||
|
||||
let server = registry.get_server("test-server").expect("should exist");
|
||||
assert_eq!(server.status, McpConnectionStatus::Connected);
|
||||
assert_eq!(server.tools.len(), 1);
|
||||
assert_eq!(server.resources.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lists_resources_from_connected_server() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![],
|
||||
vec![McpResourceInfo {
|
||||
uri: "res://alpha".into(),
|
||||
name: "Alpha".into(),
|
||||
description: None,
|
||||
mime_type: None,
|
||||
}],
|
||||
None,
|
||||
);
|
||||
|
||||
let resources = registry.list_resources("srv").expect("should succeed");
|
||||
assert_eq!(resources.len(), 1);
|
||||
assert_eq!(resources[0].uri, "res://alpha");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_resource_listing_for_disconnected_server() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Disconnected,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
assert!(registry.list_resources("srv").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reads_specific_resource() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![],
|
||||
vec![McpResourceInfo {
|
||||
uri: "res://data".into(),
|
||||
name: "Data".into(),
|
||||
description: Some("Test data".into()),
|
||||
mime_type: Some("text/plain".into()),
|
||||
}],
|
||||
None,
|
||||
);
|
||||
|
||||
let resource = registry
|
||||
.read_resource("srv", "res://data")
|
||||
.expect("should find");
|
||||
assert_eq!(resource.name, "Data");
|
||||
|
||||
assert!(registry.read_resource("srv", "res://missing").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_connected_server_without_manager_when_calling_tool_then_it_errors() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "greet".into(),
|
||||
description: None,
|
||||
input_schema: None,
|
||||
}],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
let error = registry
|
||||
.call_tool("srv", "greet", &serde_json::json!({"name": "world"}))
|
||||
.expect_err("should require a configured manager");
|
||||
assert!(error.contains("MCP server manager is not configured"));
|
||||
|
||||
// Unknown tool should fail
|
||||
assert!(registry
|
||||
.call_tool("srv", "missing", &serde_json::json!({}))
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_connected_server_with_manager_when_calling_tool_then_it_returns_live_result() {
|
||||
let script_path = write_bridge_mcp_server_script();
|
||||
let root = script_path.parent().expect("script parent");
|
||||
let log_path = root.join("bridge.log");
|
||||
let servers = BTreeMap::from([(
|
||||
"alpha".to_string(),
|
||||
manager_server_config(&script_path, "alpha", &log_path),
|
||||
)]);
|
||||
let manager = Arc::new(Mutex::new(McpServerManager::from_servers(&servers)));
|
||||
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"alpha",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "echo".into(),
|
||||
description: Some("Echo tool for alpha".into()),
|
||||
input_schema: Some(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"]
|
||||
})),
|
||||
}],
|
||||
vec![],
|
||||
Some("bridge test server".into()),
|
||||
);
|
||||
registry
|
||||
.set_manager(Arc::clone(&manager))
|
||||
.expect("manager should only be set once");
|
||||
|
||||
let result = registry
|
||||
.call_tool("alpha", "echo", &serde_json::json!({"text": "hello"}))
|
||||
.expect("should return live MCP result");
|
||||
|
||||
assert_eq!(
|
||||
result["structuredContent"]["server"],
|
||||
serde_json::json!("alpha")
|
||||
);
|
||||
assert_eq!(
|
||||
result["structuredContent"]["echoed"],
|
||||
serde_json::json!("hello")
|
||||
);
|
||||
assert_eq!(
|
||||
result["content"][0]["text"],
|
||||
serde_json::json!("alpha:hello")
|
||||
);
|
||||
|
||||
let log = fs::read_to_string(&log_path).expect("read log");
|
||||
assert_eq!(
|
||||
log.lines().collect::<Vec<_>>(),
|
||||
vec!["initialize", "tools/list", "tools/call"]
|
||||
);
|
||||
|
||||
cleanup_script(&script_path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_tool_call_on_disconnected_server() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::AuthRequired,
|
||||
vec![McpToolInfo {
|
||||
name: "greet".into(),
|
||||
description: None,
|
||||
input_schema: None,
|
||||
}],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(registry
|
||||
.call_tool("srv", "greet", &serde_json::json!({}))
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sets_auth_and_disconnects() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::AuthRequired,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
registry
|
||||
.set_auth_status("srv", McpConnectionStatus::Connected)
|
||||
.expect("should succeed");
|
||||
let state = registry.get_server("srv").unwrap();
|
||||
assert_eq!(state.status, McpConnectionStatus::Connected);
|
||||
|
||||
let removed = registry.disconnect("srv");
|
||||
assert!(removed.is_some());
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_operations_on_missing_server() {
|
||||
let registry = McpToolRegistry::new();
|
||||
assert!(registry.list_resources("missing").is_err());
|
||||
assert!(registry.read_resource("missing", "uri").is_err());
|
||||
assert!(registry.list_tools("missing").is_err());
|
||||
assert!(registry
|
||||
.call_tool("missing", "tool", &serde_json::json!({}))
|
||||
.is_err());
|
||||
assert!(registry
|
||||
.set_auth_status("missing", McpConnectionStatus::Connected)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mcp_connection_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(McpConnectionStatus::Disconnected, "disconnected"),
|
||||
(McpConnectionStatus::Connecting, "connecting"),
|
||||
(McpConnectionStatus::Connected, "connected"),
|
||||
(McpConnectionStatus::AuthRequired, "auth_required"),
|
||||
(McpConnectionStatus::Error, "error"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("disconnected".to_string(), "disconnected"),
|
||||
("connecting".to_string(), "connecting"),
|
||||
("connected".to_string(), "connected"),
|
||||
("auth_required".to_string(), "auth_required"),
|
||||
("error".to_string(), "error"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_servers_returns_all_registered() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"alpha",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
registry.register_server(
|
||||
"beta",
|
||||
McpConnectionStatus::Connecting,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let servers = registry.list_servers();
|
||||
|
||||
// then
|
||||
assert_eq!(servers.len(), 2);
|
||||
assert!(servers.iter().any(|server| server.server_name == "alpha"));
|
||||
assert!(servers.iter().any(|server| server.server_name == "beta"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tools_from_connected_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "inspect".into(),
|
||||
description: Some("Inspect data".into()),
|
||||
input_schema: Some(serde_json::json!({"type": "object"})),
|
||||
}],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let tools = registry.list_tools("srv").expect("tools should list");
|
||||
|
||||
// then
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].name, "inspect");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tools_rejects_disconnected_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::AuthRequired,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let result = registry.list_tools("srv");
|
||||
|
||||
// then
|
||||
let error = result.expect_err("non-connected server should fail");
|
||||
assert!(error.contains("not connected"));
|
||||
assert!(error.contains("auth_required"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tools_rejects_missing_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.list_tools("missing");
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
result.expect_err("missing server should fail"),
|
||||
"server 'missing' not found"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_server_returns_none_for_missing() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
let server = registry.get_server("missing");
|
||||
|
||||
// then
|
||||
assert!(server.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_tool_payload_structure() {
|
||||
let script_path = write_bridge_mcp_server_script();
|
||||
let root = script_path.parent().expect("script parent");
|
||||
let log_path = root.join("payload.log");
|
||||
let servers = BTreeMap::from([(
|
||||
"srv".to_string(),
|
||||
manager_server_config(&script_path, "srv", &log_path),
|
||||
)]);
|
||||
let registry = McpToolRegistry::new();
|
||||
let arguments = serde_json::json!({"text": "world"});
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "echo".into(),
|
||||
description: Some("Echo tool for srv".into()),
|
||||
input_schema: Some(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"]
|
||||
})),
|
||||
}],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
registry
|
||||
.set_manager(Arc::new(Mutex::new(McpServerManager::from_servers(&servers))))
|
||||
.expect("manager should only be set once");
|
||||
|
||||
let result = registry
|
||||
.call_tool("srv", "echo", &arguments)
|
||||
.expect("tool should return live payload");
|
||||
|
||||
assert_eq!(result["structuredContent"]["server"], "srv");
|
||||
assert_eq!(result["structuredContent"]["echoed"], "world");
|
||||
assert_eq!(result["content"][0]["text"], "srv:world");
|
||||
|
||||
cleanup_script(&script_path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upsert_overwrites_existing_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server("srv", McpConnectionStatus::Connecting, vec![], vec![], None);
|
||||
|
||||
// when
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "inspect".into(),
|
||||
description: None,
|
||||
input_schema: None,
|
||||
}],
|
||||
vec![],
|
||||
Some("Inspector".into()),
|
||||
);
|
||||
let state = registry.get_server("srv").expect("server should exist");
|
||||
|
||||
// then
|
||||
assert_eq!(state.status, McpConnectionStatus::Connected);
|
||||
assert_eq!(state.tools.len(), 1);
|
||||
assert_eq!(state.server_info.as_deref(), Some("Inspector"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disconnect_missing_returns_none() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
let removed = registry.disconnect("missing");
|
||||
|
||||
// then
|
||||
assert!(removed.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn len_and_is_empty_transitions() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
registry.register_server(
|
||||
"alpha",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
registry.register_server("beta", McpConnectionStatus::Connected, vec![], vec![], None);
|
||||
let after_create = registry.len();
|
||||
registry.disconnect("alpha");
|
||||
let after_first_remove = registry.len();
|
||||
registry.disconnect("beta");
|
||||
|
||||
// then
|
||||
assert_eq!(after_create, 2);
|
||||
assert_eq!(after_first_remove, 1);
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
}
|
||||
@@ -442,7 +442,7 @@ fn decode_hex(byte: u8) -> Result<u8, String> {
|
||||
b'0'..=b'9' => Ok(byte - b'0'),
|
||||
b'a'..=b'f' => Ok(byte - b'a' + 10),
|
||||
b'A'..=b'F' => Ok(byte - b'A' + 10),
|
||||
_ => Err(format!("invalid percent-encoding byte: {byte}")),
|
||||
_ => Err(format!("invalid percent byte: {byte}")),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
546
rust/crates/runtime/src/permission_enforcer.rs
Normal file
546
rust/crates/runtime/src/permission_enforcer.rs
Normal file
@@ -0,0 +1,546 @@
|
||||
//! Permission enforcement layer that gates tool execution based on the
|
||||
//! active `PermissionPolicy`.
|
||||
|
||||
use crate::permissions::{PermissionMode, PermissionOutcome, PermissionPolicy};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "outcome")]
|
||||
pub enum EnforcementResult {
|
||||
/// Tool execution is allowed.
|
||||
Allowed,
|
||||
/// Tool execution was denied due to insufficient permissions.
|
||||
Denied {
|
||||
tool: String,
|
||||
active_mode: String,
|
||||
required_mode: String,
|
||||
reason: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct PermissionEnforcer {
|
||||
policy: PermissionPolicy,
|
||||
}
|
||||
|
||||
impl PermissionEnforcer {
|
||||
#[must_use]
|
||||
pub fn new(policy: PermissionPolicy) -> Self {
|
||||
Self { policy }
|
||||
}
|
||||
|
||||
/// Check whether a tool can be executed under the current permission policy.
|
||||
/// Auto-denies when prompting is required but no prompter is provided.
|
||||
pub fn check(&self, tool_name: &str, input: &str) -> EnforcementResult {
|
||||
// When the active mode is Prompt, defer to the caller's interactive
|
||||
// prompt flow rather than hard-denying (the enforcer has no prompter).
|
||||
if self.policy.active_mode() == PermissionMode::Prompt {
|
||||
return EnforcementResult::Allowed;
|
||||
}
|
||||
|
||||
let outcome = self.policy.authorize(tool_name, input, None);
|
||||
|
||||
match outcome {
|
||||
PermissionOutcome::Allow => EnforcementResult::Allowed,
|
||||
PermissionOutcome::Deny { reason } => {
|
||||
let active_mode = self.policy.active_mode();
|
||||
let required_mode = self.policy.required_mode_for(tool_name);
|
||||
EnforcementResult::Denied {
|
||||
tool: tool_name.to_owned(),
|
||||
active_mode: active_mode.as_str().to_owned(),
|
||||
required_mode: required_mode.as_str().to_owned(),
|
||||
reason,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_allowed(&self, tool_name: &str, input: &str) -> bool {
|
||||
matches!(self.check(tool_name, input), EnforcementResult::Allowed)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn active_mode(&self) -> PermissionMode {
|
||||
self.policy.active_mode()
|
||||
}
|
||||
|
||||
/// Classify a file operation against workspace boundaries.
|
||||
pub fn check_file_write(&self, path: &str, workspace_root: &str) -> EnforcementResult {
|
||||
let mode = self.policy.active_mode();
|
||||
|
||||
match mode {
|
||||
PermissionMode::ReadOnly => EnforcementResult::Denied {
|
||||
tool: "write_file".to_owned(),
|
||||
active_mode: mode.as_str().to_owned(),
|
||||
required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(),
|
||||
reason: format!("file writes are not allowed in '{}' mode", mode.as_str()),
|
||||
},
|
||||
PermissionMode::WorkspaceWrite => {
|
||||
if is_within_workspace(path, workspace_root) {
|
||||
EnforcementResult::Allowed
|
||||
} else {
|
||||
EnforcementResult::Denied {
|
||||
tool: "write_file".to_owned(),
|
||||
active_mode: mode.as_str().to_owned(),
|
||||
required_mode: PermissionMode::DangerFullAccess.as_str().to_owned(),
|
||||
reason: format!(
|
||||
"path '{}' is outside workspace root '{}'",
|
||||
path, workspace_root
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
// Allow and DangerFullAccess permit all writes
|
||||
PermissionMode::Allow | PermissionMode::DangerFullAccess => EnforcementResult::Allowed,
|
||||
PermissionMode::Prompt => EnforcementResult::Denied {
|
||||
tool: "write_file".to_owned(),
|
||||
active_mode: mode.as_str().to_owned(),
|
||||
required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(),
|
||||
reason: "file write requires confirmation in prompt mode".to_owned(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a bash command should be allowed based on current mode.
|
||||
pub fn check_bash(&self, command: &str) -> EnforcementResult {
|
||||
let mode = self.policy.active_mode();
|
||||
|
||||
match mode {
|
||||
PermissionMode::ReadOnly => {
|
||||
if is_read_only_command(command) {
|
||||
EnforcementResult::Allowed
|
||||
} else {
|
||||
EnforcementResult::Denied {
|
||||
tool: "bash".to_owned(),
|
||||
active_mode: mode.as_str().to_owned(),
|
||||
required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(),
|
||||
reason: format!(
|
||||
"command may modify state; not allowed in '{}' mode",
|
||||
mode.as_str()
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
PermissionMode::Prompt => EnforcementResult::Denied {
|
||||
tool: "bash".to_owned(),
|
||||
active_mode: mode.as_str().to_owned(),
|
||||
required_mode: PermissionMode::DangerFullAccess.as_str().to_owned(),
|
||||
reason: "bash requires confirmation in prompt mode".to_owned(),
|
||||
},
|
||||
// WorkspaceWrite, Allow, DangerFullAccess: permit bash
|
||||
_ => EnforcementResult::Allowed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple workspace boundary check via string prefix.
|
||||
fn is_within_workspace(path: &str, workspace_root: &str) -> bool {
|
||||
let normalized = if path.starts_with('/') {
|
||||
path.to_owned()
|
||||
} else {
|
||||
format!("{workspace_root}/{path}")
|
||||
};
|
||||
|
||||
let root = if workspace_root.ends_with('/') {
|
||||
workspace_root.to_owned()
|
||||
} else {
|
||||
format!("{workspace_root}/")
|
||||
};
|
||||
|
||||
normalized.starts_with(&root) || normalized == workspace_root.trim_end_matches('/')
|
||||
}
|
||||
|
||||
/// Conservative heuristic: is this bash command read-only?
|
||||
fn is_read_only_command(command: &str) -> bool {
|
||||
let first_token = command
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.rsplit('/')
|
||||
.next()
|
||||
.unwrap_or("");
|
||||
|
||||
matches!(
|
||||
first_token,
|
||||
"cat"
|
||||
| "head"
|
||||
| "tail"
|
||||
| "less"
|
||||
| "more"
|
||||
| "wc"
|
||||
| "ls"
|
||||
| "find"
|
||||
| "grep"
|
||||
| "rg"
|
||||
| "awk"
|
||||
| "sed"
|
||||
| "echo"
|
||||
| "printf"
|
||||
| "which"
|
||||
| "where"
|
||||
| "whoami"
|
||||
| "pwd"
|
||||
| "env"
|
||||
| "printenv"
|
||||
| "date"
|
||||
| "cal"
|
||||
| "df"
|
||||
| "du"
|
||||
| "free"
|
||||
| "uptime"
|
||||
| "uname"
|
||||
| "file"
|
||||
| "stat"
|
||||
| "diff"
|
||||
| "sort"
|
||||
| "uniq"
|
||||
| "tr"
|
||||
| "cut"
|
||||
| "paste"
|
||||
| "tee"
|
||||
| "xargs"
|
||||
| "test"
|
||||
| "true"
|
||||
| "false"
|
||||
| "type"
|
||||
| "readlink"
|
||||
| "realpath"
|
||||
| "basename"
|
||||
| "dirname"
|
||||
| "sha256sum"
|
||||
| "md5sum"
|
||||
| "b3sum"
|
||||
| "xxd"
|
||||
| "hexdump"
|
||||
| "od"
|
||||
| "strings"
|
||||
| "tree"
|
||||
| "jq"
|
||||
| "yq"
|
||||
| "python3"
|
||||
| "python"
|
||||
| "node"
|
||||
| "ruby"
|
||||
| "cargo"
|
||||
| "rustc"
|
||||
| "git"
|
||||
| "gh"
|
||||
) && !command.contains("-i ")
|
||||
&& !command.contains("--in-place")
|
||||
&& !command.contains(" > ")
|
||||
&& !command.contains(" >> ")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_enforcer(mode: PermissionMode) -> PermissionEnforcer {
|
||||
let policy = PermissionPolicy::new(mode);
|
||||
PermissionEnforcer::new(policy)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_mode_permits_everything() {
|
||||
let enforcer = make_enforcer(PermissionMode::Allow);
|
||||
assert!(enforcer.is_allowed("bash", ""));
|
||||
assert!(enforcer.is_allowed("write_file", ""));
|
||||
assert!(enforcer.is_allowed("edit_file", ""));
|
||||
assert_eq!(
|
||||
enforcer.check_file_write("/outside/path", "/workspace"),
|
||||
EnforcementResult::Allowed
|
||||
);
|
||||
assert_eq!(enforcer.check_bash("rm -rf /"), EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_denies_writes() {
|
||||
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("read_file", PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("grep_search", PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite);
|
||||
|
||||
let enforcer = PermissionEnforcer::new(policy);
|
||||
assert!(enforcer.is_allowed("read_file", ""));
|
||||
assert!(enforcer.is_allowed("grep_search", ""));
|
||||
|
||||
// write_file requires WorkspaceWrite but we're in ReadOnly
|
||||
let result = enforcer.check("write_file", "");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
|
||||
let result = enforcer.check_file_write("/workspace/file.rs", "/workspace");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_allows_read_commands() {
|
||||
let enforcer = make_enforcer(PermissionMode::ReadOnly);
|
||||
assert_eq!(
|
||||
enforcer.check_bash("cat src/main.rs"),
|
||||
EnforcementResult::Allowed
|
||||
);
|
||||
assert_eq!(
|
||||
enforcer.check_bash("grep -r 'pattern' ."),
|
||||
EnforcementResult::Allowed
|
||||
);
|
||||
assert_eq!(enforcer.check_bash("ls -la"), EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_denies_write_commands() {
|
||||
let enforcer = make_enforcer(PermissionMode::ReadOnly);
|
||||
let result = enforcer.check_bash("rm file.txt");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_write_allows_within_workspace() {
|
||||
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
|
||||
let result = enforcer.check_file_write("/workspace/src/main.rs", "/workspace");
|
||||
assert_eq!(result, EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_write_denies_outside_workspace() {
|
||||
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
|
||||
let result = enforcer.check_file_write("/etc/passwd", "/workspace");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_mode_denies_without_prompter() {
|
||||
let enforcer = make_enforcer(PermissionMode::Prompt);
|
||||
let result = enforcer.check_bash("echo test");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
|
||||
let result = enforcer.check_file_write("/workspace/file.rs", "/workspace");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_boundary_check() {
|
||||
assert!(is_within_workspace("/workspace/src/main.rs", "/workspace"));
|
||||
assert!(is_within_workspace("/workspace", "/workspace"));
|
||||
assert!(!is_within_workspace("/etc/passwd", "/workspace"));
|
||||
assert!(!is_within_workspace("/workspacex/hack", "/workspace"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_command_heuristic() {
|
||||
assert!(is_read_only_command("cat file.txt"));
|
||||
assert!(is_read_only_command("grep pattern file"));
|
||||
assert!(is_read_only_command("git log --oneline"));
|
||||
assert!(!is_read_only_command("rm file.txt"));
|
||||
assert!(!is_read_only_command("echo test > file.txt"));
|
||||
assert!(!is_read_only_command("sed -i 's/a/b/' file"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn active_mode_returns_policy_mode() {
|
||||
// given
|
||||
let modes = [
|
||||
PermissionMode::ReadOnly,
|
||||
PermissionMode::WorkspaceWrite,
|
||||
PermissionMode::DangerFullAccess,
|
||||
PermissionMode::Prompt,
|
||||
PermissionMode::Allow,
|
||||
];
|
||||
|
||||
// when
|
||||
let active_modes: Vec<_> = modes
|
||||
.into_iter()
|
||||
.map(|mode| make_enforcer(mode).active_mode())
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(active_modes, modes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn danger_full_access_permits_file_writes_and_bash() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::DangerFullAccess);
|
||||
|
||||
// when
|
||||
let file_result = enforcer.check_file_write("/outside/workspace/file.txt", "/workspace");
|
||||
let bash_result = enforcer.check_bash("rm -rf /tmp/scratch");
|
||||
|
||||
// then
|
||||
assert_eq!(file_result, EnforcementResult::Allowed);
|
||||
assert_eq!(bash_result, EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_denied_payload_contains_tool_and_modes() {
|
||||
// given
|
||||
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite);
|
||||
let enforcer = PermissionEnforcer::new(policy);
|
||||
|
||||
// when
|
||||
let result = enforcer.check("write_file", "{}");
|
||||
|
||||
// then
|
||||
match result {
|
||||
EnforcementResult::Denied {
|
||||
tool,
|
||||
active_mode,
|
||||
required_mode,
|
||||
reason,
|
||||
} => {
|
||||
assert_eq!(tool, "write_file");
|
||||
assert_eq!(active_mode, "read-only");
|
||||
assert_eq!(required_mode, "workspace-write");
|
||||
assert!(reason.contains("requires workspace-write permission"));
|
||||
}
|
||||
other => panic!("expected denied result, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_write_relative_path_resolved() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_file_write("src/main.rs", "/workspace");
|
||||
|
||||
// then
|
||||
assert_eq!(result, EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_root_with_trailing_slash() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_file_write("/workspace/src/main.rs", "/workspace/");
|
||||
|
||||
// then
|
||||
assert_eq!(result, EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_root_equality() {
|
||||
// given
|
||||
let root = "/workspace/";
|
||||
|
||||
// when
|
||||
let equal_to_root = is_within_workspace("/workspace", root);
|
||||
|
||||
// then
|
||||
assert!(equal_to_root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_full_path_prefix() {
|
||||
// given
|
||||
let full_path_command = "/usr/bin/cat Cargo.toml";
|
||||
let git_path_command = "/usr/local/bin/git status";
|
||||
|
||||
// when
|
||||
let cat_result = is_read_only_command(full_path_command);
|
||||
let git_result = is_read_only_command(git_path_command);
|
||||
|
||||
// then
|
||||
assert!(cat_result);
|
||||
assert!(git_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_redirects_block_read_only_commands() {
|
||||
// given
|
||||
let overwrite = "cat Cargo.toml > out.txt";
|
||||
let append = "echo test >> out.txt";
|
||||
|
||||
// when
|
||||
let overwrite_result = is_read_only_command(overwrite);
|
||||
let append_result = is_read_only_command(append);
|
||||
|
||||
// then
|
||||
assert!(!overwrite_result);
|
||||
assert!(!append_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_in_place_flag_blocks() {
|
||||
// given
|
||||
let interactive_python = "python -i script.py";
|
||||
let in_place_sed = "sed --in-place 's/a/b/' file.txt";
|
||||
|
||||
// when
|
||||
let interactive_result = is_read_only_command(interactive_python);
|
||||
let in_place_result = is_read_only_command(in_place_sed);
|
||||
|
||||
// then
|
||||
assert!(!interactive_result);
|
||||
assert!(!in_place_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_empty_command() {
|
||||
// given
|
||||
let empty = "";
|
||||
let whitespace = " ";
|
||||
|
||||
// when
|
||||
let empty_result = is_read_only_command(empty);
|
||||
let whitespace_result = is_read_only_command(whitespace);
|
||||
|
||||
// then
|
||||
assert!(!empty_result);
|
||||
assert!(!whitespace_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_mode_check_bash_denied_payload_fields() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::Prompt);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_bash("git status");
|
||||
|
||||
// then
|
||||
match result {
|
||||
EnforcementResult::Denied {
|
||||
tool,
|
||||
active_mode,
|
||||
required_mode,
|
||||
reason,
|
||||
} => {
|
||||
assert_eq!(tool, "bash");
|
||||
assert_eq!(active_mode, "prompt");
|
||||
assert_eq!(required_mode, "danger-full-access");
|
||||
assert_eq!(reason, "bash requires confirmation in prompt mode");
|
||||
}
|
||||
other => panic!("expected denied result, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_check_file_write_denied_payload() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::ReadOnly);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_file_write("/workspace/file.txt", "/workspace");
|
||||
|
||||
// then
|
||||
match result {
|
||||
EnforcementResult::Denied {
|
||||
tool,
|
||||
active_mode,
|
||||
required_mode,
|
||||
reason,
|
||||
} => {
|
||||
assert_eq!(tool, "write_file");
|
||||
assert_eq!(active_mode, "read-only");
|
||||
assert_eq!(required_mode, "workspace-write");
|
||||
assert!(reason.contains("file writes are not allowed"));
|
||||
}
|
||||
other => panic!("expected denied result, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -161,7 +161,7 @@ pub fn resolve_sandbox_status(config: &SandboxConfig, cwd: &Path) -> SandboxStat
|
||||
#[must_use]
|
||||
pub fn resolve_sandbox_status_for_request(request: &SandboxRequest, cwd: &Path) -> SandboxStatus {
|
||||
let container = detect_container_environment();
|
||||
let namespace_supported = cfg!(target_os = "linux") && command_exists("unshare");
|
||||
let namespace_supported = cfg!(target_os = "linux") && unshare_user_namespace_works();
|
||||
let network_supported = namespace_supported;
|
||||
let filesystem_active =
|
||||
request.enabled && request.filesystem_mode != FilesystemIsolationMode::Off;
|
||||
@@ -282,6 +282,27 @@ fn command_exists(command: &str) -> bool {
|
||||
.is_some_and(|paths| env::split_paths(&paths).any(|path| path.join(command).exists()))
|
||||
}
|
||||
|
||||
/// Check whether `unshare --user` actually works on this system.
|
||||
/// On some CI environments (e.g. GitHub Actions), the binary exists but
|
||||
/// user namespaces are restricted, causing silent failures.
|
||||
fn unshare_user_namespace_works() -> bool {
|
||||
use std::sync::OnceLock;
|
||||
static RESULT: OnceLock<bool> = OnceLock::new();
|
||||
*RESULT.get_or_init(|| {
|
||||
if !command_exists("unshare") {
|
||||
return false;
|
||||
}
|
||||
std::process::Command::new("unshare")
|
||||
.args(["--user", "--map-root-user", "true"])
|
||||
.stdin(std::process::Stdio::null())
|
||||
.stdout(std::process::Stdio::null())
|
||||
.stderr(std::process::Stdio::null())
|
||||
.status()
|
||||
.map(|s| s.success())
|
||||
.unwrap_or(false)
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
|
||||
449
rust/crates/runtime/src/task_registry.rs
Normal file
449
rust/crates/runtime/src/task_registry.rs
Normal file
@@ -0,0 +1,449 @@
|
||||
//! In-memory task registry for sub-agent task lifecycle management.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TaskStatus {
|
||||
Created,
|
||||
Running,
|
||||
Completed,
|
||||
Failed,
|
||||
Stopped,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TaskStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Created => write!(f, "created"),
|
||||
Self::Running => write!(f, "running"),
|
||||
Self::Completed => write!(f, "completed"),
|
||||
Self::Failed => write!(f, "failed"),
|
||||
Self::Stopped => write!(f, "stopped"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Task {
|
||||
pub task_id: String,
|
||||
pub prompt: String,
|
||||
pub description: Option<String>,
|
||||
pub status: TaskStatus,
|
||||
pub created_at: u64,
|
||||
pub updated_at: u64,
|
||||
pub messages: Vec<TaskMessage>,
|
||||
pub output: String,
|
||||
pub team_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TaskMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
pub timestamp: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TaskRegistry {
|
||||
inner: Arc<Mutex<RegistryInner>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct RegistryInner {
|
||||
tasks: HashMap<String, Task>,
|
||||
counter: u64,
|
||||
}
|
||||
|
||||
fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
impl TaskRegistry {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn create(&self, prompt: &str, description: Option<&str>) -> Task {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
inner.counter += 1;
|
||||
let ts = now_secs();
|
||||
let task_id = format!("task_{:08x}_{}", ts, inner.counter);
|
||||
let task = Task {
|
||||
task_id: task_id.clone(),
|
||||
prompt: prompt.to_owned(),
|
||||
description: description.map(str::to_owned),
|
||||
status: TaskStatus::Created,
|
||||
created_at: ts,
|
||||
updated_at: ts,
|
||||
messages: Vec::new(),
|
||||
output: String::new(),
|
||||
team_id: None,
|
||||
};
|
||||
inner.tasks.insert(task_id, task.clone());
|
||||
task
|
||||
}
|
||||
|
||||
pub fn get(&self, task_id: &str) -> Option<Task> {
|
||||
let inner = self.inner.lock().expect("registry lock poisoned");
|
||||
inner.tasks.get(task_id).cloned()
|
||||
}
|
||||
|
||||
pub fn list(&self, status_filter: Option<TaskStatus>) -> Vec<Task> {
|
||||
let inner = self.inner.lock().expect("registry lock poisoned");
|
||||
inner
|
||||
.tasks
|
||||
.values()
|
||||
.filter(|t| status_filter.map_or(true, |s| t.status == s))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn stop(&self, task_id: &str) -> Result<Task, String> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
|
||||
match task.status {
|
||||
TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Stopped => {
|
||||
return Err(format!(
|
||||
"task {task_id} is already in terminal state: {}",
|
||||
task.status
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
task.status = TaskStatus::Stopped;
|
||||
task.updated_at = now_secs();
|
||||
Ok(task.clone())
|
||||
}
|
||||
|
||||
pub fn update(&self, task_id: &str, message: &str) -> Result<Task, String> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
|
||||
task.messages.push(TaskMessage {
|
||||
role: String::from("user"),
|
||||
content: message.to_owned(),
|
||||
timestamp: now_secs(),
|
||||
});
|
||||
task.updated_at = now_secs();
|
||||
Ok(task.clone())
|
||||
}
|
||||
|
||||
pub fn output(&self, task_id: &str) -> Result<String, String> {
|
||||
let inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
Ok(task.output.clone())
|
||||
}
|
||||
|
||||
pub fn append_output(&self, task_id: &str, output: &str) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
task.output.push_str(output);
|
||||
task.updated_at = now_secs();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_status(&self, task_id: &str, status: TaskStatus) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
task.status = status;
|
||||
task.updated_at = now_secs();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn assign_team(&self, task_id: &str, team_id: &str) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
task.team_id = Some(team_id.to_owned());
|
||||
task.updated_at = now_secs();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn remove(&self, task_id: &str) -> Option<Task> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
inner.tasks.remove(task_id)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
let inner = self.inner.lock().expect("registry lock poisoned");
|
||||
inner.tasks.len()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn creates_and_retrieves_tasks() {
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("Do something", Some("A test task"));
|
||||
assert_eq!(task.status, TaskStatus::Created);
|
||||
assert_eq!(task.prompt, "Do something");
|
||||
assert_eq!(task.description.as_deref(), Some("A test task"));
|
||||
|
||||
let fetched = registry.get(&task.task_id).expect("task should exist");
|
||||
assert_eq!(fetched.task_id, task.task_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lists_tasks_with_optional_filter() {
|
||||
let registry = TaskRegistry::new();
|
||||
registry.create("Task A", None);
|
||||
let task_b = registry.create("Task B", None);
|
||||
registry
|
||||
.set_status(&task_b.task_id, TaskStatus::Running)
|
||||
.expect("set status should succeed");
|
||||
|
||||
let all = registry.list(None);
|
||||
assert_eq!(all.len(), 2);
|
||||
|
||||
let running = registry.list(Some(TaskStatus::Running));
|
||||
assert_eq!(running.len(), 1);
|
||||
assert_eq!(running[0].task_id, task_b.task_id);
|
||||
|
||||
let created = registry.list(Some(TaskStatus::Created));
|
||||
assert_eq!(created.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stops_running_task() {
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("Stoppable", None);
|
||||
registry
|
||||
.set_status(&task.task_id, TaskStatus::Running)
|
||||
.unwrap();
|
||||
|
||||
let stopped = registry.stop(&task.task_id).expect("stop should succeed");
|
||||
assert_eq!(stopped.status, TaskStatus::Stopped);
|
||||
|
||||
// Stopping again should fail
|
||||
let result = registry.stop(&task.task_id);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn updates_task_with_messages() {
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("Messageable", None);
|
||||
let updated = registry
|
||||
.update(&task.task_id, "Here's more context")
|
||||
.expect("update should succeed");
|
||||
assert_eq!(updated.messages.len(), 1);
|
||||
assert_eq!(updated.messages[0].content, "Here's more context");
|
||||
assert_eq!(updated.messages[0].role, "user");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn appends_and_retrieves_output() {
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("Output task", None);
|
||||
registry
|
||||
.append_output(&task.task_id, "line 1\n")
|
||||
.expect("append should succeed");
|
||||
registry
|
||||
.append_output(&task.task_id, "line 2\n")
|
||||
.expect("append should succeed");
|
||||
|
||||
let output = registry.output(&task.task_id).expect("output should exist");
|
||||
assert_eq!(output, "line 1\nline 2\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assigns_team_and_removes_task() {
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("Team task", None);
|
||||
registry
|
||||
.assign_team(&task.task_id, "team_abc")
|
||||
.expect("assign should succeed");
|
||||
|
||||
let fetched = registry.get(&task.task_id).unwrap();
|
||||
assert_eq!(fetched.team_id.as_deref(), Some("team_abc"));
|
||||
|
||||
let removed = registry.remove(&task.task_id);
|
||||
assert!(removed.is_some());
|
||||
assert!(registry.get(&task.task_id).is_none());
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_operations_on_missing_task() {
|
||||
let registry = TaskRegistry::new();
|
||||
assert!(registry.stop("nonexistent").is_err());
|
||||
assert!(registry.update("nonexistent", "msg").is_err());
|
||||
assert!(registry.output("nonexistent").is_err());
|
||||
assert!(registry.append_output("nonexistent", "data").is_err());
|
||||
assert!(registry
|
||||
.set_status("nonexistent", TaskStatus::Running)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn task_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(TaskStatus::Created, "created"),
|
||||
(TaskStatus::Running, "running"),
|
||||
(TaskStatus::Completed, "completed"),
|
||||
(TaskStatus::Failed, "failed"),
|
||||
(TaskStatus::Stopped, "stopped"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("created".to_string(), "created"),
|
||||
("running".to_string(), "running"),
|
||||
("completed".to_string(), "completed"),
|
||||
("failed".to_string(), "failed"),
|
||||
("stopped".to_string(), "stopped"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stop_rejects_completed_task() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("done", None);
|
||||
registry
|
||||
.set_status(&task.task_id, TaskStatus::Completed)
|
||||
.expect("set status should succeed");
|
||||
|
||||
// when
|
||||
let result = registry.stop(&task.task_id);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("completed task should be rejected");
|
||||
assert!(error.contains("already in terminal state"));
|
||||
assert!(error.contains("completed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stop_rejects_failed_task() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("failed", None);
|
||||
registry
|
||||
.set_status(&task.task_id, TaskStatus::Failed)
|
||||
.expect("set status should succeed");
|
||||
|
||||
// when
|
||||
let result = registry.stop(&task.task_id);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("failed task should be rejected");
|
||||
assert!(error.contains("already in terminal state"));
|
||||
assert!(error.contains("failed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stop_succeeds_from_created_state() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("created task", None);
|
||||
|
||||
// when
|
||||
let stopped = registry.stop(&task.task_id).expect("stop should succeed");
|
||||
|
||||
// then
|
||||
assert_eq!(stopped.status, TaskStatus::Stopped);
|
||||
assert!(stopped.updated_at >= task.updated_at);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_registry_is_empty() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let all_tasks = registry.list(None);
|
||||
|
||||
// then
|
||||
assert!(registry.is_empty());
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(all_tasks.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_without_description() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let task = registry.create("Do the thing", None);
|
||||
|
||||
// then
|
||||
assert!(task.task_id.starts_with("task_"));
|
||||
assert_eq!(task.description, None);
|
||||
assert!(task.messages.is_empty());
|
||||
assert!(task.output.is_empty());
|
||||
assert_eq!(task.team_id, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_nonexistent_returns_none() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let removed = registry.remove("missing");
|
||||
|
||||
// then
|
||||
assert!(removed.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assign_team_rejects_missing_task() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.assign_team("missing", "team_123");
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing task should be rejected");
|
||||
assert_eq!(error, "task not found: missing");
|
||||
}
|
||||
}
|
||||
508
rust/crates/runtime/src/team_cron_registry.rs
Normal file
508
rust/crates/runtime/src/team_cron_registry.rs
Normal file
@@ -0,0 +1,508 @@
|
||||
//! In-memory registries for Team and Cron lifecycle management.
|
||||
//!
|
||||
//! Provides TeamCreate/Delete and CronCreate/Delete/List runtime backing
|
||||
//! to replace the stub implementations in the tools crate.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Team {
|
||||
pub team_id: String,
|
||||
pub name: String,
|
||||
pub task_ids: Vec<String>,
|
||||
pub status: TeamStatus,
|
||||
pub created_at: u64,
|
||||
pub updated_at: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TeamStatus {
|
||||
Created,
|
||||
Running,
|
||||
Completed,
|
||||
Deleted,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TeamStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Created => write!(f, "created"),
|
||||
Self::Running => write!(f, "running"),
|
||||
Self::Completed => write!(f, "completed"),
|
||||
Self::Deleted => write!(f, "deleted"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TeamRegistry {
|
||||
inner: Arc<Mutex<TeamInner>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct TeamInner {
|
||||
teams: HashMap<String, Team>,
|
||||
counter: u64,
|
||||
}
|
||||
|
||||
impl TeamRegistry {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn create(&self, name: &str, task_ids: Vec<String>) -> Team {
|
||||
let mut inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.counter += 1;
|
||||
let ts = now_secs();
|
||||
let team_id = format!("team_{:08x}_{}", ts, inner.counter);
|
||||
let team = Team {
|
||||
team_id: team_id.clone(),
|
||||
name: name.to_owned(),
|
||||
task_ids,
|
||||
status: TeamStatus::Created,
|
||||
created_at: ts,
|
||||
updated_at: ts,
|
||||
};
|
||||
inner.teams.insert(team_id, team.clone());
|
||||
team
|
||||
}
|
||||
|
||||
pub fn get(&self, team_id: &str) -> Option<Team> {
|
||||
let inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.teams.get(team_id).cloned()
|
||||
}
|
||||
|
||||
pub fn list(&self) -> Vec<Team> {
|
||||
let inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.teams.values().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn delete(&self, team_id: &str) -> Result<Team, String> {
|
||||
let mut inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
let team = inner
|
||||
.teams
|
||||
.get_mut(team_id)
|
||||
.ok_or_else(|| format!("team not found: {team_id}"))?;
|
||||
team.status = TeamStatus::Deleted;
|
||||
team.updated_at = now_secs();
|
||||
Ok(team.clone())
|
||||
}
|
||||
|
||||
pub fn remove(&self, team_id: &str) -> Option<Team> {
|
||||
let mut inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.teams.remove(team_id)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
let inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.teams.len()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CronEntry {
|
||||
pub cron_id: String,
|
||||
pub schedule: String,
|
||||
pub prompt: String,
|
||||
pub description: Option<String>,
|
||||
pub enabled: bool,
|
||||
pub created_at: u64,
|
||||
pub updated_at: u64,
|
||||
pub last_run_at: Option<u64>,
|
||||
pub run_count: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct CronRegistry {
|
||||
inner: Arc<Mutex<CronInner>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct CronInner {
|
||||
entries: HashMap<String, CronEntry>,
|
||||
counter: u64,
|
||||
}
|
||||
|
||||
impl CronRegistry {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn create(&self, schedule: &str, prompt: &str, description: Option<&str>) -> CronEntry {
|
||||
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner.counter += 1;
|
||||
let ts = now_secs();
|
||||
let cron_id = format!("cron_{:08x}_{}", ts, inner.counter);
|
||||
let entry = CronEntry {
|
||||
cron_id: cron_id.clone(),
|
||||
schedule: schedule.to_owned(),
|
||||
prompt: prompt.to_owned(),
|
||||
description: description.map(str::to_owned),
|
||||
enabled: true,
|
||||
created_at: ts,
|
||||
updated_at: ts,
|
||||
last_run_at: None,
|
||||
run_count: 0,
|
||||
};
|
||||
inner.entries.insert(cron_id, entry.clone());
|
||||
entry
|
||||
}
|
||||
|
||||
pub fn get(&self, cron_id: &str) -> Option<CronEntry> {
|
||||
let inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner.entries.get(cron_id).cloned()
|
||||
}
|
||||
|
||||
pub fn list(&self, enabled_only: bool) -> Vec<CronEntry> {
|
||||
let inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner
|
||||
.entries
|
||||
.values()
|
||||
.filter(|e| !enabled_only || e.enabled)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn delete(&self, cron_id: &str) -> Result<CronEntry, String> {
|
||||
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner
|
||||
.entries
|
||||
.remove(cron_id)
|
||||
.ok_or_else(|| format!("cron not found: {cron_id}"))
|
||||
}
|
||||
|
||||
/// Disable a cron entry without removing it.
|
||||
pub fn disable(&self, cron_id: &str) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
let entry = inner
|
||||
.entries
|
||||
.get_mut(cron_id)
|
||||
.ok_or_else(|| format!("cron not found: {cron_id}"))?;
|
||||
entry.enabled = false;
|
||||
entry.updated_at = now_secs();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Record a cron run.
|
||||
pub fn record_run(&self, cron_id: &str) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
let entry = inner
|
||||
.entries
|
||||
.get_mut(cron_id)
|
||||
.ok_or_else(|| format!("cron not found: {cron_id}"))?;
|
||||
entry.last_run_at = Some(now_secs());
|
||||
entry.run_count += 1;
|
||||
entry.updated_at = now_secs();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
let inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner.entries.len()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── Team tests ──────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn creates_and_retrieves_team() {
|
||||
let registry = TeamRegistry::new();
|
||||
let team = registry.create("Alpha Squad", vec!["task_001".into(), "task_002".into()]);
|
||||
assert_eq!(team.name, "Alpha Squad");
|
||||
assert_eq!(team.task_ids.len(), 2);
|
||||
assert_eq!(team.status, TeamStatus::Created);
|
||||
|
||||
let fetched = registry.get(&team.team_id).expect("team should exist");
|
||||
assert_eq!(fetched.team_id, team.team_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lists_and_deletes_teams() {
|
||||
let registry = TeamRegistry::new();
|
||||
let t1 = registry.create("Team A", vec![]);
|
||||
let t2 = registry.create("Team B", vec![]);
|
||||
|
||||
let all = registry.list();
|
||||
assert_eq!(all.len(), 2);
|
||||
|
||||
let deleted = registry.delete(&t1.team_id).expect("delete should succeed");
|
||||
assert_eq!(deleted.status, TeamStatus::Deleted);
|
||||
|
||||
// Team is still listable (soft delete)
|
||||
let still_there = registry.get(&t1.team_id).unwrap();
|
||||
assert_eq!(still_there.status, TeamStatus::Deleted);
|
||||
|
||||
// Hard remove
|
||||
registry.remove(&t2.team_id);
|
||||
assert_eq!(registry.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_missing_team_operations() {
|
||||
let registry = TeamRegistry::new();
|
||||
assert!(registry.delete("nonexistent").is_err());
|
||||
assert!(registry.get("nonexistent").is_none());
|
||||
}
|
||||
|
||||
// ── Cron tests ──────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn creates_and_retrieves_cron() {
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("0 * * * *", "Check status", Some("hourly check"));
|
||||
assert_eq!(entry.schedule, "0 * * * *");
|
||||
assert_eq!(entry.prompt, "Check status");
|
||||
assert!(entry.enabled);
|
||||
assert_eq!(entry.run_count, 0);
|
||||
assert!(entry.last_run_at.is_none());
|
||||
|
||||
let fetched = registry.get(&entry.cron_id).expect("cron should exist");
|
||||
assert_eq!(fetched.cron_id, entry.cron_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lists_with_enabled_filter() {
|
||||
let registry = CronRegistry::new();
|
||||
let c1 = registry.create("* * * * *", "Task 1", None);
|
||||
let c2 = registry.create("0 * * * *", "Task 2", None);
|
||||
registry
|
||||
.disable(&c1.cron_id)
|
||||
.expect("disable should succeed");
|
||||
|
||||
let all = registry.list(false);
|
||||
assert_eq!(all.len(), 2);
|
||||
|
||||
let enabled_only = registry.list(true);
|
||||
assert_eq!(enabled_only.len(), 1);
|
||||
assert_eq!(enabled_only[0].cron_id, c2.cron_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deletes_cron_entry() {
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("* * * * *", "To delete", None);
|
||||
let deleted = registry
|
||||
.delete(&entry.cron_id)
|
||||
.expect("delete should succeed");
|
||||
assert_eq!(deleted.cron_id, entry.cron_id);
|
||||
assert!(registry.get(&entry.cron_id).is_none());
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn records_cron_runs() {
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("*/5 * * * *", "Recurring", None);
|
||||
registry.record_run(&entry.cron_id).unwrap();
|
||||
registry.record_run(&entry.cron_id).unwrap();
|
||||
|
||||
let fetched = registry.get(&entry.cron_id).unwrap();
|
||||
assert_eq!(fetched.run_count, 2);
|
||||
assert!(fetched.last_run_at.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_missing_cron_operations() {
|
||||
let registry = CronRegistry::new();
|
||||
assert!(registry.delete("nonexistent").is_err());
|
||||
assert!(registry.disable("nonexistent").is_err());
|
||||
assert!(registry.record_run("nonexistent").is_err());
|
||||
assert!(registry.get("nonexistent").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn team_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(TeamStatus::Created, "created"),
|
||||
(TeamStatus::Running, "running"),
|
||||
(TeamStatus::Completed, "completed"),
|
||||
(TeamStatus::Deleted, "deleted"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("created".to_string(), "created"),
|
||||
("running".to_string(), "running"),
|
||||
("completed".to_string(), "completed"),
|
||||
("deleted".to_string(), "deleted"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_team_registry_is_empty() {
|
||||
// given
|
||||
let registry = TeamRegistry::new();
|
||||
|
||||
// when
|
||||
let teams = registry.list();
|
||||
|
||||
// then
|
||||
assert!(registry.is_empty());
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(teams.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn team_remove_nonexistent_returns_none() {
|
||||
// given
|
||||
let registry = TeamRegistry::new();
|
||||
|
||||
// when
|
||||
let removed = registry.remove("missing");
|
||||
|
||||
// then
|
||||
assert!(removed.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn team_len_transitions() {
|
||||
// given
|
||||
let registry = TeamRegistry::new();
|
||||
|
||||
// when
|
||||
let alpha = registry.create("Alpha", vec![]);
|
||||
let beta = registry.create("Beta", vec![]);
|
||||
let after_create = registry.len();
|
||||
registry.remove(&alpha.team_id);
|
||||
let after_first_remove = registry.len();
|
||||
registry.remove(&beta.team_id);
|
||||
|
||||
// then
|
||||
assert_eq!(after_create, 2);
|
||||
assert_eq!(after_first_remove, 1);
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_list_all_disabled_returns_empty_for_enabled_only() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
let first = registry.create("* * * * *", "Task 1", None);
|
||||
let second = registry.create("0 * * * *", "Task 2", None);
|
||||
registry
|
||||
.disable(&first.cron_id)
|
||||
.expect("disable should succeed");
|
||||
registry
|
||||
.disable(&second.cron_id)
|
||||
.expect("disable should succeed");
|
||||
|
||||
// when
|
||||
let enabled_only = registry.list(true);
|
||||
let all_entries = registry.list(false);
|
||||
|
||||
// then
|
||||
assert!(enabled_only.is_empty());
|
||||
assert_eq!(all_entries.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_create_without_description() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
|
||||
// when
|
||||
let entry = registry.create("*/15 * * * *", "Check health", None);
|
||||
|
||||
// then
|
||||
assert!(entry.cron_id.starts_with("cron_"));
|
||||
assert_eq!(entry.description, None);
|
||||
assert!(entry.enabled);
|
||||
assert_eq!(entry.run_count, 0);
|
||||
assert_eq!(entry.last_run_at, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_cron_registry_is_empty() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
|
||||
// when
|
||||
let enabled_only = registry.list(true);
|
||||
let all_entries = registry.list(false);
|
||||
|
||||
// then
|
||||
assert!(registry.is_empty());
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(enabled_only.is_empty());
|
||||
assert!(all_entries.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_record_run_updates_timestamp_and_counter() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("*/5 * * * *", "Recurring", None);
|
||||
|
||||
// when
|
||||
registry
|
||||
.record_run(&entry.cron_id)
|
||||
.expect("first run should succeed");
|
||||
registry
|
||||
.record_run(&entry.cron_id)
|
||||
.expect("second run should succeed");
|
||||
let fetched = registry.get(&entry.cron_id).expect("entry should exist");
|
||||
|
||||
// then
|
||||
assert_eq!(fetched.run_count, 2);
|
||||
assert!(fetched.last_run_at.is_some());
|
||||
assert!(fetched.updated_at >= entry.updated_at);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_disable_updates_timestamp() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("0 0 * * *", "Nightly", None);
|
||||
|
||||
// when
|
||||
registry
|
||||
.disable(&entry.cron_id)
|
||||
.expect("disable should succeed");
|
||||
let fetched = registry.get(&entry.cron_id).expect("entry should exist");
|
||||
|
||||
// then
|
||||
assert!(!fetched.enabled);
|
||||
assert!(fetched.updated_at >= entry.updated_at);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user