mirror of
https://github.com/instructkr/claw-code.git
synced 2026-04-06 16:14:49 +08:00
feat: ultraclaw session outputs — registry tests, MCP bridge, PARITY.md, cleanup
Ultraclaw mode results from 10 parallel opencode sessions: - PARITY.md: Updated both copies with all 9 landed lanes, commit hashes, line counts, and test counts. All checklist items marked complete. - MCP bridge: McpToolRegistry.call_tool now wired to real McpServerManager via async JSON-RPC (discover_tools -> tools/call -> shutdown) - Registry tests: Added coverage for TaskRegistry, TeamRegistry, CronRegistry, PermissionEnforcer, LspRegistry (branch-focused tests) - Permissions refactor: Simplified authorize_with_context, extracted helpers, added characterization tests (185 runtime tests pass) - AI slop cleanup: Removed redundant comments, unused_self suppressions, tightened unreachable branches - CLI fixes: Minor adjustments in main.rs and hooks.rs All 363+ tests pass. Workspace compiles clean.
This commit is contained in:
@@ -847,7 +847,7 @@ mod tests {
|
||||
AssistantEvent::MessageStop,
|
||||
])
|
||||
}
|
||||
_ => Err(RuntimeError::new("unexpected extra API call")),
|
||||
_ => unreachable!("extra API call"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1156,7 +1156,7 @@ mod tests {
|
||||
AssistantEvent::MessageStop,
|
||||
])
|
||||
}
|
||||
_ => Err(RuntimeError::new("unexpected extra API call")),
|
||||
_ => unreachable!("extra API call"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1231,7 +1231,7 @@ mod tests {
|
||||
AssistantEvent::MessageStop,
|
||||
])
|
||||
}
|
||||
_ => Err(RuntimeError::new("unexpected extra API call")),
|
||||
_ => unreachable!("extra API call"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1545,7 +1545,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn auto_compaction_threshold_defaults_and_parses_values() {
|
||||
// given / when / then
|
||||
assert_eq!(
|
||||
parse_auto_compaction_threshold(None),
|
||||
DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD
|
||||
|
||||
@@ -37,7 +37,6 @@ impl LspAction {
|
||||
}
|
||||
}
|
||||
|
||||
/// A diagnostic entry from an LSP server.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspDiagnostic {
|
||||
pub path: String,
|
||||
@@ -48,7 +47,6 @@ pub struct LspDiagnostic {
|
||||
pub source: Option<String>,
|
||||
}
|
||||
|
||||
/// A location result (definition, references).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspLocation {
|
||||
pub path: String,
|
||||
@@ -59,14 +57,12 @@ pub struct LspLocation {
|
||||
pub preview: Option<String>,
|
||||
}
|
||||
|
||||
/// A hover result.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspHoverResult {
|
||||
pub content: String,
|
||||
pub language: Option<String>,
|
||||
}
|
||||
|
||||
/// A completion item.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspCompletionItem {
|
||||
pub label: String,
|
||||
@@ -75,7 +71,6 @@ pub struct LspCompletionItem {
|
||||
pub insert_text: Option<String>,
|
||||
}
|
||||
|
||||
/// A document symbol.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspSymbol {
|
||||
pub name: String,
|
||||
@@ -85,7 +80,6 @@ pub struct LspSymbol {
|
||||
pub character: u32,
|
||||
}
|
||||
|
||||
/// Connection status.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LspServerStatus {
|
||||
@@ -106,7 +100,6 @@ impl std::fmt::Display for LspServerStatus {
|
||||
}
|
||||
}
|
||||
|
||||
/// Tracked state of an LSP server.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspServerState {
|
||||
pub language: String,
|
||||
@@ -116,7 +109,6 @@ pub struct LspServerState {
|
||||
pub diagnostics: Vec<LspDiagnostic>,
|
||||
}
|
||||
|
||||
/// Thread-safe LSP server registry.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct LspRegistry {
|
||||
inner: Arc<Mutex<RegistryInner>>,
|
||||
@@ -133,7 +125,6 @@ impl LspRegistry {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Register an LSP server for a language.
|
||||
pub fn register(
|
||||
&self,
|
||||
language: &str,
|
||||
@@ -154,7 +145,6 @@ impl LspRegistry {
|
||||
);
|
||||
}
|
||||
|
||||
/// Get server state by language.
|
||||
pub fn get(&self, language: &str) -> Option<LspServerState> {
|
||||
let inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner.servers.get(language).cloned()
|
||||
@@ -435,4 +425,326 @@ mod tests {
|
||||
assert!(removed.is_some());
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lsp_action_from_str_all_aliases() {
|
||||
// given
|
||||
let cases = [
|
||||
("diagnostics", Some(LspAction::Diagnostics)),
|
||||
("hover", Some(LspAction::Hover)),
|
||||
("definition", Some(LspAction::Definition)),
|
||||
("goto_definition", Some(LspAction::Definition)),
|
||||
("references", Some(LspAction::References)),
|
||||
("find_references", Some(LspAction::References)),
|
||||
("completion", Some(LspAction::Completion)),
|
||||
("completions", Some(LspAction::Completion)),
|
||||
("symbols", Some(LspAction::Symbols)),
|
||||
("document_symbols", Some(LspAction::Symbols)),
|
||||
("format", Some(LspAction::Format)),
|
||||
("formatting", Some(LspAction::Format)),
|
||||
("unknown", None),
|
||||
];
|
||||
|
||||
// when
|
||||
let resolved: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(input, expected)| (input, LspAction::from_str(input), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
for (input, actual, expected) in resolved {
|
||||
assert_eq!(actual, expected, "unexpected action resolution for {input}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lsp_server_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(LspServerStatus::Connected, "connected"),
|
||||
(LspServerStatus::Disconnected, "disconnected"),
|
||||
(LspServerStatus::Starting, "starting"),
|
||||
(LspServerStatus::Error, "error"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("connected".to_string(), "connected"),
|
||||
("disconnected".to_string(), "disconnected"),
|
||||
("starting".to_string(), "starting"),
|
||||
("error".to_string(), "error"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_diagnostics_without_path_aggregates() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry.register("python", LspServerStatus::Connected, None, vec![]);
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"rust",
|
||||
vec![LspDiagnostic {
|
||||
path: "src/lib.rs".into(),
|
||||
line: 1,
|
||||
character: 0,
|
||||
severity: "warning".into(),
|
||||
message: "unused import".into(),
|
||||
source: Some("rust-analyzer".into()),
|
||||
}],
|
||||
)
|
||||
.expect("rust diagnostics should add");
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"python",
|
||||
vec![LspDiagnostic {
|
||||
path: "script.py".into(),
|
||||
line: 2,
|
||||
character: 4,
|
||||
severity: "error".into(),
|
||||
message: "undefined name".into(),
|
||||
source: Some("pyright".into()),
|
||||
}],
|
||||
)
|
||||
.expect("python diagnostics should add");
|
||||
|
||||
// when
|
||||
let result = registry
|
||||
.dispatch("diagnostics", None, None, None, None)
|
||||
.expect("aggregate diagnostics should work");
|
||||
|
||||
// then
|
||||
assert_eq!(result["action"], "diagnostics");
|
||||
assert_eq!(result["count"], 2);
|
||||
assert_eq!(result["diagnostics"].as_array().map(Vec::len), Some(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_non_diagnostics_requires_path() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.dispatch("hover", None, Some(1), Some(0), None);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
result.expect_err("path should be required"),
|
||||
"path is required for this LSP action"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_no_server_for_path_errors() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.dispatch("hover", Some("notes.md"), Some(1), Some(0), None);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing server should fail");
|
||||
assert!(error.contains("no LSP server available for path: notes.md"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_disconnected_server_error_payload() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("typescript", LspServerStatus::Disconnected, None, vec![]);
|
||||
|
||||
// when
|
||||
let result = registry.dispatch("hover", Some("src/index.ts"), Some(3), Some(2), None);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("disconnected server should fail");
|
||||
assert!(error.contains("typescript"));
|
||||
assert!(error.contains("disconnected"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_server_for_all_extensions() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
for language in [
|
||||
"rust",
|
||||
"typescript",
|
||||
"javascript",
|
||||
"python",
|
||||
"go",
|
||||
"java",
|
||||
"c",
|
||||
"cpp",
|
||||
"ruby",
|
||||
"lua",
|
||||
] {
|
||||
registry.register(language, LspServerStatus::Connected, None, vec![]);
|
||||
}
|
||||
let cases = [
|
||||
("src/main.rs", "rust"),
|
||||
("src/index.ts", "typescript"),
|
||||
("src/view.tsx", "typescript"),
|
||||
("src/app.js", "javascript"),
|
||||
("src/app.jsx", "javascript"),
|
||||
("script.py", "python"),
|
||||
("main.go", "go"),
|
||||
("Main.java", "java"),
|
||||
("native.c", "c"),
|
||||
("native.h", "c"),
|
||||
("native.cpp", "cpp"),
|
||||
("native.hpp", "cpp"),
|
||||
("native.cc", "cpp"),
|
||||
("script.rb", "ruby"),
|
||||
("script.lua", "lua"),
|
||||
];
|
||||
|
||||
// when
|
||||
let resolved: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(path, expected)| {
|
||||
(
|
||||
path,
|
||||
registry
|
||||
.find_server_for_path(path)
|
||||
.map(|server| server.language),
|
||||
expected,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// then
|
||||
for (path, actual, expected) in resolved {
|
||||
assert_eq!(
|
||||
actual.as_deref(),
|
||||
Some(expected),
|
||||
"unexpected mapping for {path}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_server_for_path_no_extension() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
|
||||
// when
|
||||
let result = registry.find_server_for_path("Makefile");
|
||||
|
||||
// then
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_servers_with_multiple() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry.register("typescript", LspServerStatus::Starting, None, vec![]);
|
||||
registry.register("python", LspServerStatus::Error, None, vec![]);
|
||||
|
||||
// when
|
||||
let servers = registry.list_servers();
|
||||
|
||||
// then
|
||||
assert_eq!(servers.len(), 3);
|
||||
assert!(servers.iter().any(|server| server.language == "rust"));
|
||||
assert!(servers.iter().any(|server| server.language == "typescript"));
|
||||
assert!(servers.iter().any(|server| server.language == "python"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_missing_server_returns_none() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let server = registry.get("missing");
|
||||
|
||||
// then
|
||||
assert!(server.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_diagnostics_missing_language_errors() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.add_diagnostics("missing", vec![]);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing language should fail");
|
||||
assert!(error.contains("LSP server not found for language: missing"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_diagnostics_across_servers() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
let shared_path = "shared/file.txt";
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry.register("python", LspServerStatus::Connected, None, vec![]);
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"rust",
|
||||
vec![LspDiagnostic {
|
||||
path: shared_path.into(),
|
||||
line: 4,
|
||||
character: 1,
|
||||
severity: "warning".into(),
|
||||
message: "warn".into(),
|
||||
source: None,
|
||||
}],
|
||||
)
|
||||
.expect("rust diagnostics should add");
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"python",
|
||||
vec![LspDiagnostic {
|
||||
path: shared_path.into(),
|
||||
line: 8,
|
||||
character: 3,
|
||||
severity: "error".into(),
|
||||
message: "err".into(),
|
||||
source: None,
|
||||
}],
|
||||
)
|
||||
.expect("python diagnostics should add");
|
||||
|
||||
// when
|
||||
let diagnostics = registry.get_diagnostics(shared_path);
|
||||
|
||||
// then
|
||||
assert_eq!(diagnostics.len(), 2);
|
||||
assert!(diagnostics
|
||||
.iter()
|
||||
.any(|diagnostic| diagnostic.message == "warn"));
|
||||
assert!(diagnostics
|
||||
.iter()
|
||||
.any(|diagnostic| diagnostic.message == "err"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clear_diagnostics_missing_language_errors() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.clear_diagnostics("missing");
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing language should fail");
|
||||
assert!(error.contains("LSP server not found for language: missing"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@
|
||||
//! connect to MCP servers and invoke their capabilities.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
use crate::mcp::mcp_tool_name;
|
||||
use crate::mcp_stdio::McpServerManager;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Status of a managed MCP server connection.
|
||||
@@ -64,6 +66,7 @@ pub struct McpServerState {
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct McpToolRegistry {
|
||||
inner: Arc<Mutex<HashMap<String, McpServerState>>>,
|
||||
manager: Arc<OnceLock<Arc<Mutex<McpServerManager>>>>,
|
||||
}
|
||||
|
||||
impl McpToolRegistry {
|
||||
@@ -72,6 +75,13 @@ impl McpToolRegistry {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn set_manager(
|
||||
&self,
|
||||
manager: Arc<Mutex<McpServerManager>>,
|
||||
) -> Result<(), Arc<Mutex<McpServerManager>>> {
|
||||
self.manager.set(manager)
|
||||
}
|
||||
|
||||
/// Register or update an MCP server connection.
|
||||
pub fn register_server(
|
||||
&self,
|
||||
@@ -163,8 +173,66 @@ impl McpToolRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
/// Call a tool on a specific server (returns placeholder for now;
|
||||
/// actual execution is handled by `McpServerManager::call_tool`).
|
||||
fn spawn_tool_call(
|
||||
manager: Arc<Mutex<McpServerManager>>,
|
||||
qualified_tool_name: String,
|
||||
arguments: Option<serde_json::Value>,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let join_handle = std::thread::Builder::new()
|
||||
.name(format!("mcp-tool-call-{qualified_tool_name}"))
|
||||
.spawn(move || {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(|error| format!("failed to create MCP tool runtime: {error}"))?;
|
||||
|
||||
runtime.block_on(async move {
|
||||
let response = {
|
||||
let mut manager = manager
|
||||
.lock()
|
||||
.map_err(|_| "mcp server manager lock poisoned".to_string())?;
|
||||
manager.discover_tools().await.map_err(|error| error.to_string())?;
|
||||
let response = manager
|
||||
.call_tool(&qualified_tool_name, arguments)
|
||||
.await
|
||||
.map_err(|error| error.to_string());
|
||||
let shutdown = manager.shutdown().await.map_err(|error| error.to_string());
|
||||
|
||||
match (response, shutdown) {
|
||||
(Ok(response), Ok(())) => Ok(response),
|
||||
(Err(error), Ok(())) | (Err(error), Err(_)) => Err(error),
|
||||
(Ok(_), Err(error)) => Err(error),
|
||||
}
|
||||
}?;
|
||||
|
||||
if let Some(error) = response.error {
|
||||
return Err(format!(
|
||||
"MCP server returned JSON-RPC error for tools/call: {} ({})",
|
||||
error.message, error.code
|
||||
));
|
||||
}
|
||||
|
||||
let result = response.result.ok_or_else(|| {
|
||||
"MCP server returned no result for tools/call".to_string()
|
||||
})?;
|
||||
|
||||
serde_json::to_value(result)
|
||||
.map_err(|error| format!("failed to serialize MCP tool result: {error}"))
|
||||
})
|
||||
})
|
||||
.map_err(|error| format!("failed to spawn MCP tool call thread: {error}"))?;
|
||||
|
||||
join_handle.join().map_err(|panic_payload| {
|
||||
if let Some(message) = panic_payload.downcast_ref::<&str>() {
|
||||
format!("MCP tool call thread panicked: {message}")
|
||||
} else if let Some(message) = panic_payload.downcast_ref::<String>() {
|
||||
format!("MCP tool call thread panicked: {message}")
|
||||
} else {
|
||||
"MCP tool call thread panicked".to_string()
|
||||
}
|
||||
})?
|
||||
}
|
||||
|
||||
pub fn call_tool(
|
||||
&self,
|
||||
server_name: &str,
|
||||
@@ -190,15 +258,19 @@ impl McpToolRegistry {
|
||||
));
|
||||
}
|
||||
|
||||
// Return structured acknowledgment — actual execution is delegated
|
||||
// to the McpServerManager which handles the JSON-RPC call.
|
||||
Ok(serde_json::json!({
|
||||
"server": server_name,
|
||||
"tool": tool_name,
|
||||
"arguments": arguments,
|
||||
"status": "dispatched",
|
||||
"message": "Tool call dispatched to MCP server"
|
||||
}))
|
||||
drop(inner);
|
||||
|
||||
let manager = self
|
||||
.manager
|
||||
.get()
|
||||
.cloned()
|
||||
.ok_or_else(|| "MCP server manager is not configured".to_string())?;
|
||||
|
||||
Self::spawn_tool_call(
|
||||
manager,
|
||||
mcp_tool_name(server_name, tool_name),
|
||||
(!arguments.is_null()).then(|| arguments.clone()),
|
||||
)
|
||||
}
|
||||
|
||||
/// Set auth status for a server.
|
||||
@@ -236,7 +308,151 @@ impl McpToolRegistry {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeMap;
|
||||
use std::fs;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use super::*;
|
||||
use crate::config::{
|
||||
ConfigSource, McpServerConfig, McpStdioServerConfig, ScopedMcpServerConfig,
|
||||
};
|
||||
|
||||
fn temp_dir() -> PathBuf {
|
||||
static NEXT_TEMP_DIR_ID: AtomicU64 = AtomicU64::new(0);
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should be after epoch")
|
||||
.as_nanos();
|
||||
let unique_id = NEXT_TEMP_DIR_ID.fetch_add(1, Ordering::Relaxed);
|
||||
std::env::temp_dir().join(format!("runtime-mcp-tool-bridge-{nanos}-{unique_id}"))
|
||||
}
|
||||
|
||||
fn cleanup_script(script_path: &Path) {
|
||||
if let Some(root) = script_path.parent() {
|
||||
let _ = fs::remove_dir_all(root);
|
||||
}
|
||||
}
|
||||
|
||||
fn write_bridge_mcp_server_script() -> PathBuf {
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("temp dir");
|
||||
let script_path = root.join("bridge-mcp-server.py");
|
||||
let script = [
|
||||
"#!/usr/bin/env python3",
|
||||
"import json, os, sys",
|
||||
"LABEL = os.environ.get('MCP_SERVER_LABEL', 'server')",
|
||||
"LOG_PATH = os.environ.get('MCP_LOG_PATH')",
|
||||
"",
|
||||
"def log(method):",
|
||||
" if LOG_PATH:",
|
||||
" with open(LOG_PATH, 'a', encoding='utf-8') as handle:",
|
||||
" handle.write(f'{method}\\n')",
|
||||
"",
|
||||
"def read_message():",
|
||||
" header = b''",
|
||||
r" while not header.endswith(b'\r\n\r\n'):",
|
||||
" chunk = sys.stdin.buffer.read(1)",
|
||||
" if not chunk:",
|
||||
" return None",
|
||||
" header += chunk",
|
||||
" length = 0",
|
||||
r" for line in header.decode().split('\r\n'):",
|
||||
r" if line.lower().startswith('content-length:'):",
|
||||
r" length = int(line.split(':', 1)[1].strip())",
|
||||
" payload = sys.stdin.buffer.read(length)",
|
||||
" return json.loads(payload.decode())",
|
||||
"",
|
||||
"def send_message(message):",
|
||||
" payload = json.dumps(message).encode()",
|
||||
r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)",
|
||||
" sys.stdout.buffer.flush()",
|
||||
"",
|
||||
"while True:",
|
||||
" request = read_message()",
|
||||
" if request is None:",
|
||||
" break",
|
||||
" method = request['method']",
|
||||
" log(method)",
|
||||
" if method == 'initialize':",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'protocolVersion': request['params']['protocolVersion'],",
|
||||
" 'capabilities': {'tools': {}},",
|
||||
" 'serverInfo': {'name': LABEL, 'version': '1.0.0'}",
|
||||
" }",
|
||||
" })",
|
||||
" elif method == 'tools/list':",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'tools': [",
|
||||
" {",
|
||||
" 'name': 'echo',",
|
||||
" 'description': f'Echo tool for {LABEL}',",
|
||||
" 'inputSchema': {",
|
||||
" 'type': 'object',",
|
||||
" 'properties': {'text': {'type': 'string'}},",
|
||||
" 'required': ['text']",
|
||||
" }",
|
||||
" }",
|
||||
" ]",
|
||||
" }",
|
||||
" })",
|
||||
" elif method == 'tools/call':",
|
||||
" args = request['params'].get('arguments') or {}",
|
||||
" text = args.get('text', '')",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'content': [{'type': 'text', 'text': f'{LABEL}:{text}'}],",
|
||||
" 'structuredContent': {'server': LABEL, 'echoed': text},",
|
||||
" 'isError': False",
|
||||
" }",
|
||||
" })",
|
||||
" else:",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'error': {'code': -32601, 'message': f'unknown method: {method}'},",
|
||||
" })",
|
||||
"",
|
||||
]
|
||||
.join("\n");
|
||||
fs::write(&script_path, script).expect("write script");
|
||||
let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
|
||||
permissions.set_mode(0o755);
|
||||
fs::set_permissions(&script_path, permissions).expect("chmod");
|
||||
script_path
|
||||
}
|
||||
|
||||
fn manager_server_config(
|
||||
script_path: &Path,
|
||||
server_name: &str,
|
||||
log_path: &Path,
|
||||
) -> ScopedMcpServerConfig {
|
||||
ScopedMcpServerConfig {
|
||||
scope: ConfigSource::Local,
|
||||
config: McpServerConfig::Stdio(McpStdioServerConfig {
|
||||
command: "python3".to_string(),
|
||||
args: vec![script_path.to_string_lossy().into_owned()],
|
||||
env: BTreeMap::from([
|
||||
("MCP_SERVER_LABEL".to_string(), server_name.to_string()),
|
||||
(
|
||||
"MCP_LOG_PATH".to_string(),
|
||||
log_path.to_string_lossy().into_owned(),
|
||||
),
|
||||
]),
|
||||
tool_call_timeout_ms: Some(1_000),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registers_and_retrieves_server() {
|
||||
@@ -323,7 +539,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn calls_tool_on_connected_server() {
|
||||
fn given_connected_server_without_manager_when_calling_tool_then_it_errors() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
@@ -337,10 +553,10 @@ mod tests {
|
||||
None,
|
||||
);
|
||||
|
||||
let result = registry
|
||||
let error = registry
|
||||
.call_tool("srv", "greet", &serde_json::json!({"name": "world"}))
|
||||
.expect("should dispatch");
|
||||
assert_eq!(result["status"], "dispatched");
|
||||
.expect_err("should require a configured manager");
|
||||
assert!(error.contains("MCP server manager is not configured"));
|
||||
|
||||
// Unknown tool should fail
|
||||
assert!(registry
|
||||
@@ -348,6 +564,63 @@ mod tests {
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_connected_server_with_manager_when_calling_tool_then_it_returns_live_result() {
|
||||
let script_path = write_bridge_mcp_server_script();
|
||||
let root = script_path.parent().expect("script parent");
|
||||
let log_path = root.join("bridge.log");
|
||||
let servers = BTreeMap::from([(
|
||||
"alpha".to_string(),
|
||||
manager_server_config(&script_path, "alpha", &log_path),
|
||||
)]);
|
||||
let manager = Arc::new(Mutex::new(McpServerManager::from_servers(&servers)));
|
||||
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"alpha",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "echo".into(),
|
||||
description: Some("Echo tool for alpha".into()),
|
||||
input_schema: Some(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"]
|
||||
})),
|
||||
}],
|
||||
vec![],
|
||||
Some("bridge test server".into()),
|
||||
);
|
||||
registry
|
||||
.set_manager(Arc::clone(&manager))
|
||||
.expect("manager should only be set once");
|
||||
|
||||
let result = registry
|
||||
.call_tool("alpha", "echo", &serde_json::json!({"text": "hello"}))
|
||||
.expect("should return live MCP result");
|
||||
|
||||
assert_eq!(
|
||||
result["structuredContent"]["server"],
|
||||
serde_json::json!("alpha")
|
||||
);
|
||||
assert_eq!(
|
||||
result["structuredContent"]["echoed"],
|
||||
serde_json::json!("hello")
|
||||
);
|
||||
assert_eq!(
|
||||
result["content"][0]["text"],
|
||||
serde_json::json!("alpha:hello")
|
||||
);
|
||||
|
||||
let log = fs::read_to_string(&log_path).expect("read log");
|
||||
assert_eq!(
|
||||
log.lines().collect::<Vec<_>>(),
|
||||
vec!["initialize", "tools/list", "tools/call"]
|
||||
);
|
||||
|
||||
cleanup_script(&script_path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_tool_call_on_disconnected_server() {
|
||||
let registry = McpToolRegistry::new();
|
||||
@@ -403,4 +676,239 @@ mod tests {
|
||||
.set_auth_status("missing", McpConnectionStatus::Connected)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mcp_connection_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(McpConnectionStatus::Disconnected, "disconnected"),
|
||||
(McpConnectionStatus::Connecting, "connecting"),
|
||||
(McpConnectionStatus::Connected, "connected"),
|
||||
(McpConnectionStatus::AuthRequired, "auth_required"),
|
||||
(McpConnectionStatus::Error, "error"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("disconnected".to_string(), "disconnected"),
|
||||
("connecting".to_string(), "connecting"),
|
||||
("connected".to_string(), "connected"),
|
||||
("auth_required".to_string(), "auth_required"),
|
||||
("error".to_string(), "error"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_servers_returns_all_registered() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"alpha",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
registry.register_server(
|
||||
"beta",
|
||||
McpConnectionStatus::Connecting,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let servers = registry.list_servers();
|
||||
|
||||
// then
|
||||
assert_eq!(servers.len(), 2);
|
||||
assert!(servers.iter().any(|server| server.server_name == "alpha"));
|
||||
assert!(servers.iter().any(|server| server.server_name == "beta"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tools_from_connected_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "inspect".into(),
|
||||
description: Some("Inspect data".into()),
|
||||
input_schema: Some(serde_json::json!({"type": "object"})),
|
||||
}],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let tools = registry.list_tools("srv").expect("tools should list");
|
||||
|
||||
// then
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].name, "inspect");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tools_rejects_disconnected_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::AuthRequired,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let result = registry.list_tools("srv");
|
||||
|
||||
// then
|
||||
let error = result.expect_err("non-connected server should fail");
|
||||
assert!(error.contains("not connected"));
|
||||
assert!(error.contains("auth_required"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tools_rejects_missing_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.list_tools("missing");
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
result.expect_err("missing server should fail"),
|
||||
"server 'missing' not found"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_server_returns_none_for_missing() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
let server = registry.get_server("missing");
|
||||
|
||||
// then
|
||||
assert!(server.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_tool_payload_structure() {
|
||||
let script_path = write_bridge_mcp_server_script();
|
||||
let root = script_path.parent().expect("script parent");
|
||||
let log_path = root.join("payload.log");
|
||||
let servers = BTreeMap::from([(
|
||||
"srv".to_string(),
|
||||
manager_server_config(&script_path, "srv", &log_path),
|
||||
)]);
|
||||
let registry = McpToolRegistry::new();
|
||||
let arguments = serde_json::json!({"text": "world"});
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "echo".into(),
|
||||
description: Some("Echo tool for srv".into()),
|
||||
input_schema: Some(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"]
|
||||
})),
|
||||
}],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
registry
|
||||
.set_manager(Arc::new(Mutex::new(McpServerManager::from_servers(&servers))))
|
||||
.expect("manager should only be set once");
|
||||
|
||||
let result = registry
|
||||
.call_tool("srv", "echo", &arguments)
|
||||
.expect("tool should return live payload");
|
||||
|
||||
assert_eq!(result["structuredContent"]["server"], "srv");
|
||||
assert_eq!(result["structuredContent"]["echoed"], "world");
|
||||
assert_eq!(result["content"][0]["text"], "srv:world");
|
||||
|
||||
cleanup_script(&script_path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upsert_overwrites_existing_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server("srv", McpConnectionStatus::Connecting, vec![], vec![], None);
|
||||
|
||||
// when
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "inspect".into(),
|
||||
description: None,
|
||||
input_schema: None,
|
||||
}],
|
||||
vec![],
|
||||
Some("Inspector".into()),
|
||||
);
|
||||
let state = registry.get_server("srv").expect("server should exist");
|
||||
|
||||
// then
|
||||
assert_eq!(state.status, McpConnectionStatus::Connected);
|
||||
assert_eq!(state.tools.len(), 1);
|
||||
assert_eq!(state.server_info.as_deref(), Some("Inspector"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disconnect_missing_returns_none() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
let removed = registry.disconnect("missing");
|
||||
|
||||
// then
|
||||
assert!(removed.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn len_and_is_empty_transitions() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
registry.register_server(
|
||||
"alpha",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
registry.register_server("beta", McpConnectionStatus::Connected, vec![], vec![], None);
|
||||
let after_create = registry.len();
|
||||
registry.disconnect("alpha");
|
||||
let after_first_remove = registry.len();
|
||||
registry.disconnect("beta");
|
||||
|
||||
// then
|
||||
assert_eq!(after_create, 2);
|
||||
assert_eq!(after_first_remove, 1);
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -442,7 +442,7 @@ fn decode_hex(byte: u8) -> Result<u8, String> {
|
||||
b'0'..=b'9' => Ok(byte - b'0'),
|
||||
b'a'..=b'f' => Ok(byte - b'a' + 10),
|
||||
b'A'..=b'F' => Ok(byte - b'A' + 10),
|
||||
_ => Err(format!("invalid percent-encoding byte: {byte}")),
|
||||
_ => Err(format!("invalid percent byte: {byte}")),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
use crate::permissions::{PermissionMode, PermissionOutcome, PermissionPolicy};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Result of a permission check before tool execution.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "outcome")]
|
||||
pub enum EnforcementResult {
|
||||
@@ -23,8 +22,7 @@ pub enum EnforcementResult {
|
||||
},
|
||||
}
|
||||
|
||||
/// Permission enforcer that gates tool execution through the permission policy.
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct PermissionEnforcer {
|
||||
policy: PermissionPolicy,
|
||||
}
|
||||
@@ -55,13 +53,11 @@ impl PermissionEnforcer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a tool is allowed (returns true for Allow, false for Deny).
|
||||
#[must_use]
|
||||
pub fn is_allowed(&self, tool_name: &str, input: &str) -> bool {
|
||||
matches!(self.check(tool_name, input), EnforcementResult::Allowed)
|
||||
}
|
||||
|
||||
/// Get the active permission mode.
|
||||
#[must_use]
|
||||
pub fn active_mode(&self) -> PermissionMode {
|
||||
self.policy.active_mode()
|
||||
@@ -337,4 +333,212 @@ mod tests {
|
||||
assert!(!is_read_only_command("echo test > file.txt"));
|
||||
assert!(!is_read_only_command("sed -i 's/a/b/' file"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn active_mode_returns_policy_mode() {
|
||||
// given
|
||||
let modes = [
|
||||
PermissionMode::ReadOnly,
|
||||
PermissionMode::WorkspaceWrite,
|
||||
PermissionMode::DangerFullAccess,
|
||||
PermissionMode::Prompt,
|
||||
PermissionMode::Allow,
|
||||
];
|
||||
|
||||
// when
|
||||
let active_modes: Vec<_> = modes
|
||||
.into_iter()
|
||||
.map(|mode| make_enforcer(mode).active_mode())
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(active_modes, modes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn danger_full_access_permits_file_writes_and_bash() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::DangerFullAccess);
|
||||
|
||||
// when
|
||||
let file_result = enforcer.check_file_write("/outside/workspace/file.txt", "/workspace");
|
||||
let bash_result = enforcer.check_bash("rm -rf /tmp/scratch");
|
||||
|
||||
// then
|
||||
assert_eq!(file_result, EnforcementResult::Allowed);
|
||||
assert_eq!(bash_result, EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_denied_payload_contains_tool_and_modes() {
|
||||
// given
|
||||
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite);
|
||||
let enforcer = PermissionEnforcer::new(policy);
|
||||
|
||||
// when
|
||||
let result = enforcer.check("write_file", "{}");
|
||||
|
||||
// then
|
||||
match result {
|
||||
EnforcementResult::Denied {
|
||||
tool,
|
||||
active_mode,
|
||||
required_mode,
|
||||
reason,
|
||||
} => {
|
||||
assert_eq!(tool, "write_file");
|
||||
assert_eq!(active_mode, "read-only");
|
||||
assert_eq!(required_mode, "workspace-write");
|
||||
assert!(reason.contains("requires workspace-write permission"));
|
||||
}
|
||||
other => panic!("expected denied result, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_write_relative_path_resolved() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_file_write("src/main.rs", "/workspace");
|
||||
|
||||
// then
|
||||
assert_eq!(result, EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_root_with_trailing_slash() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_file_write("/workspace/src/main.rs", "/workspace/");
|
||||
|
||||
// then
|
||||
assert_eq!(result, EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_root_equality() {
|
||||
// given
|
||||
let root = "/workspace/";
|
||||
|
||||
// when
|
||||
let equal_to_root = is_within_workspace("/workspace", root);
|
||||
|
||||
// then
|
||||
assert!(equal_to_root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_full_path_prefix() {
|
||||
// given
|
||||
let full_path_command = "/usr/bin/cat Cargo.toml";
|
||||
let git_path_command = "/usr/local/bin/git status";
|
||||
|
||||
// when
|
||||
let cat_result = is_read_only_command(full_path_command);
|
||||
let git_result = is_read_only_command(git_path_command);
|
||||
|
||||
// then
|
||||
assert!(cat_result);
|
||||
assert!(git_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_redirects_block_read_only_commands() {
|
||||
// given
|
||||
let overwrite = "cat Cargo.toml > out.txt";
|
||||
let append = "echo test >> out.txt";
|
||||
|
||||
// when
|
||||
let overwrite_result = is_read_only_command(overwrite);
|
||||
let append_result = is_read_only_command(append);
|
||||
|
||||
// then
|
||||
assert!(!overwrite_result);
|
||||
assert!(!append_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_in_place_flag_blocks() {
|
||||
// given
|
||||
let interactive_python = "python -i script.py";
|
||||
let in_place_sed = "sed --in-place 's/a/b/' file.txt";
|
||||
|
||||
// when
|
||||
let interactive_result = is_read_only_command(interactive_python);
|
||||
let in_place_result = is_read_only_command(in_place_sed);
|
||||
|
||||
// then
|
||||
assert!(!interactive_result);
|
||||
assert!(!in_place_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_empty_command() {
|
||||
// given
|
||||
let empty = "";
|
||||
let whitespace = " ";
|
||||
|
||||
// when
|
||||
let empty_result = is_read_only_command(empty);
|
||||
let whitespace_result = is_read_only_command(whitespace);
|
||||
|
||||
// then
|
||||
assert!(!empty_result);
|
||||
assert!(!whitespace_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_mode_check_bash_denied_payload_fields() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::Prompt);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_bash("git status");
|
||||
|
||||
// then
|
||||
match result {
|
||||
EnforcementResult::Denied {
|
||||
tool,
|
||||
active_mode,
|
||||
required_mode,
|
||||
reason,
|
||||
} => {
|
||||
assert_eq!(tool, "bash");
|
||||
assert_eq!(active_mode, "prompt");
|
||||
assert_eq!(required_mode, "danger-full-access");
|
||||
assert_eq!(reason, "bash requires confirmation in prompt mode");
|
||||
}
|
||||
other => panic!("expected denied result, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_check_file_write_denied_payload() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::ReadOnly);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_file_write("/workspace/file.txt", "/workspace");
|
||||
|
||||
// then
|
||||
match result {
|
||||
EnforcementResult::Denied {
|
||||
tool,
|
||||
active_mode,
|
||||
required_mode,
|
||||
reason,
|
||||
} => {
|
||||
assert_eq!(tool, "write_file");
|
||||
assert_eq!(active_mode, "read-only");
|
||||
assert_eq!(required_mode, "workspace-write");
|
||||
assert!(reason.contains("file writes are not allowed"));
|
||||
}
|
||||
other => panic!("expected denied result, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -332,4 +332,139 @@ mod tests {
|
||||
.set_status("nonexistent", TaskStatus::Running)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn task_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(TaskStatus::Created, "created"),
|
||||
(TaskStatus::Running, "running"),
|
||||
(TaskStatus::Completed, "completed"),
|
||||
(TaskStatus::Failed, "failed"),
|
||||
(TaskStatus::Stopped, "stopped"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("created".to_string(), "created"),
|
||||
("running".to_string(), "running"),
|
||||
("completed".to_string(), "completed"),
|
||||
("failed".to_string(), "failed"),
|
||||
("stopped".to_string(), "stopped"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stop_rejects_completed_task() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("done", None);
|
||||
registry
|
||||
.set_status(&task.task_id, TaskStatus::Completed)
|
||||
.expect("set status should succeed");
|
||||
|
||||
// when
|
||||
let result = registry.stop(&task.task_id);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("completed task should be rejected");
|
||||
assert!(error.contains("already in terminal state"));
|
||||
assert!(error.contains("completed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stop_rejects_failed_task() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("failed", None);
|
||||
registry
|
||||
.set_status(&task.task_id, TaskStatus::Failed)
|
||||
.expect("set status should succeed");
|
||||
|
||||
// when
|
||||
let result = registry.stop(&task.task_id);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("failed task should be rejected");
|
||||
assert!(error.contains("already in terminal state"));
|
||||
assert!(error.contains("failed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stop_succeeds_from_created_state() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("created task", None);
|
||||
|
||||
// when
|
||||
let stopped = registry.stop(&task.task_id).expect("stop should succeed");
|
||||
|
||||
// then
|
||||
assert_eq!(stopped.status, TaskStatus::Stopped);
|
||||
assert!(stopped.updated_at >= task.updated_at);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_registry_is_empty() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let all_tasks = registry.list(None);
|
||||
|
||||
// then
|
||||
assert!(registry.is_empty());
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(all_tasks.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_without_description() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let task = registry.create("Do the thing", None);
|
||||
|
||||
// then
|
||||
assert!(task.task_id.starts_with("task_"));
|
||||
assert_eq!(task.description, None);
|
||||
assert!(task.messages.is_empty());
|
||||
assert!(task.output.is_empty());
|
||||
assert_eq!(task.team_id, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_nonexistent_returns_none() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let removed = registry.remove("missing");
|
||||
|
||||
// then
|
||||
assert!(removed.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assign_team_rejects_missing_task() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.assign_team("missing", "team_123");
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing task should be rejected");
|
||||
assert_eq!(error, "task not found: missing");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,11 +16,6 @@ fn now_secs() -> u64 {
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────
|
||||
// Team registry
|
||||
// ─────────────────────────────────────────────
|
||||
|
||||
/// A team groups multiple tasks for parallel execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Team {
|
||||
pub team_id: String,
|
||||
@@ -51,7 +46,6 @@ impl std::fmt::Display for TeamStatus {
|
||||
}
|
||||
}
|
||||
|
||||
/// Thread-safe team registry.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TeamRegistry {
|
||||
inner: Arc<Mutex<TeamInner>>,
|
||||
@@ -69,7 +63,6 @@ impl TeamRegistry {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create a new team with the given name and task IDs.
|
||||
pub fn create(&self, name: &str, task_ids: Vec<String>) -> Team {
|
||||
let mut inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.counter += 1;
|
||||
@@ -87,19 +80,16 @@ impl TeamRegistry {
|
||||
team
|
||||
}
|
||||
|
||||
/// Get a team by ID.
|
||||
pub fn get(&self, team_id: &str) -> Option<Team> {
|
||||
let inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.teams.get(team_id).cloned()
|
||||
}
|
||||
|
||||
/// List all teams.
|
||||
pub fn list(&self) -> Vec<Team> {
|
||||
let inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.teams.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Delete a team.
|
||||
pub fn delete(&self, team_id: &str) -> Result<Team, String> {
|
||||
let mut inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
let team = inner
|
||||
@@ -111,7 +101,6 @@ impl TeamRegistry {
|
||||
Ok(team.clone())
|
||||
}
|
||||
|
||||
/// Remove a team entirely from the registry.
|
||||
pub fn remove(&self, team_id: &str) -> Option<Team> {
|
||||
let mut inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.teams.remove(team_id)
|
||||
@@ -129,11 +118,6 @@ impl TeamRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────
|
||||
// Cron registry
|
||||
// ─────────────────────────────────────────────
|
||||
|
||||
/// A cron entry schedules a prompt to run on a recurring schedule.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CronEntry {
|
||||
pub cron_id: String,
|
||||
@@ -147,7 +131,6 @@ pub struct CronEntry {
|
||||
pub run_count: u64,
|
||||
}
|
||||
|
||||
/// Thread-safe cron registry.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct CronRegistry {
|
||||
inner: Arc<Mutex<CronInner>>,
|
||||
@@ -165,7 +148,6 @@ impl CronRegistry {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create a new cron entry.
|
||||
pub fn create(&self, schedule: &str, prompt: &str, description: Option<&str>) -> CronEntry {
|
||||
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner.counter += 1;
|
||||
@@ -186,13 +168,11 @@ impl CronRegistry {
|
||||
entry
|
||||
}
|
||||
|
||||
/// Get a cron entry by ID.
|
||||
pub fn get(&self, cron_id: &str) -> Option<CronEntry> {
|
||||
let inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner.entries.get(cron_id).cloned()
|
||||
}
|
||||
|
||||
/// List all cron entries, optionally filtered to enabled only.
|
||||
pub fn list(&self, enabled_only: bool) -> Vec<CronEntry> {
|
||||
let inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner
|
||||
@@ -203,7 +183,6 @@ impl CronRegistry {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Delete (remove) a cron entry.
|
||||
pub fn delete(&self, cron_id: &str) -> Result<CronEntry, String> {
|
||||
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner
|
||||
@@ -360,4 +339,170 @@ mod tests {
|
||||
assert!(registry.record_run("nonexistent").is_err());
|
||||
assert!(registry.get("nonexistent").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn team_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(TeamStatus::Created, "created"),
|
||||
(TeamStatus::Running, "running"),
|
||||
(TeamStatus::Completed, "completed"),
|
||||
(TeamStatus::Deleted, "deleted"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("created".to_string(), "created"),
|
||||
("running".to_string(), "running"),
|
||||
("completed".to_string(), "completed"),
|
||||
("deleted".to_string(), "deleted"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_team_registry_is_empty() {
|
||||
// given
|
||||
let registry = TeamRegistry::new();
|
||||
|
||||
// when
|
||||
let teams = registry.list();
|
||||
|
||||
// then
|
||||
assert!(registry.is_empty());
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(teams.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn team_remove_nonexistent_returns_none() {
|
||||
// given
|
||||
let registry = TeamRegistry::new();
|
||||
|
||||
// when
|
||||
let removed = registry.remove("missing");
|
||||
|
||||
// then
|
||||
assert!(removed.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn team_len_transitions() {
|
||||
// given
|
||||
let registry = TeamRegistry::new();
|
||||
|
||||
// when
|
||||
let alpha = registry.create("Alpha", vec![]);
|
||||
let beta = registry.create("Beta", vec![]);
|
||||
let after_create = registry.len();
|
||||
registry.remove(&alpha.team_id);
|
||||
let after_first_remove = registry.len();
|
||||
registry.remove(&beta.team_id);
|
||||
|
||||
// then
|
||||
assert_eq!(after_create, 2);
|
||||
assert_eq!(after_first_remove, 1);
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_list_all_disabled_returns_empty_for_enabled_only() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
let first = registry.create("* * * * *", "Task 1", None);
|
||||
let second = registry.create("0 * * * *", "Task 2", None);
|
||||
registry
|
||||
.disable(&first.cron_id)
|
||||
.expect("disable should succeed");
|
||||
registry
|
||||
.disable(&second.cron_id)
|
||||
.expect("disable should succeed");
|
||||
|
||||
// when
|
||||
let enabled_only = registry.list(true);
|
||||
let all_entries = registry.list(false);
|
||||
|
||||
// then
|
||||
assert!(enabled_only.is_empty());
|
||||
assert_eq!(all_entries.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_create_without_description() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
|
||||
// when
|
||||
let entry = registry.create("*/15 * * * *", "Check health", None);
|
||||
|
||||
// then
|
||||
assert!(entry.cron_id.starts_with("cron_"));
|
||||
assert_eq!(entry.description, None);
|
||||
assert!(entry.enabled);
|
||||
assert_eq!(entry.run_count, 0);
|
||||
assert_eq!(entry.last_run_at, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_cron_registry_is_empty() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
|
||||
// when
|
||||
let enabled_only = registry.list(true);
|
||||
let all_entries = registry.list(false);
|
||||
|
||||
// then
|
||||
assert!(registry.is_empty());
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(enabled_only.is_empty());
|
||||
assert!(all_entries.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_record_run_updates_timestamp_and_counter() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("*/5 * * * *", "Recurring", None);
|
||||
|
||||
// when
|
||||
registry
|
||||
.record_run(&entry.cron_id)
|
||||
.expect("first run should succeed");
|
||||
registry
|
||||
.record_run(&entry.cron_id)
|
||||
.expect("second run should succeed");
|
||||
let fetched = registry.get(&entry.cron_id).expect("entry should exist");
|
||||
|
||||
// then
|
||||
assert_eq!(fetched.run_count, 2);
|
||||
assert!(fetched.last_run_at.is_some());
|
||||
assert!(fetched.updated_at >= entry.updated_at);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_disable_updates_timestamp() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("0 0 * * *", "Nightly", None);
|
||||
|
||||
// when
|
||||
registry
|
||||
.disable(&entry.cron_id)
|
||||
.expect("disable should succeed");
|
||||
let fetched = registry.get(&entry.cron_id).expect("entry should exist");
|
||||
|
||||
// then
|
||||
assert!(!fetched.enabled);
|
||||
assert!(fetched.updated_at >= entry.updated_at);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user