diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index ab3eb61..22bd852 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -169,7 +169,6 @@ impl GlobalToolRegistry { builtin.chain(plugin).collect() } - #[must_use] pub fn permission_specs( &self, allowed_tools: Option<&BTreeSet>, @@ -648,7 +647,7 @@ fn run_notebook_edit(input: NotebookEditInput) -> Result { } fn run_sleep(input: SleepInput) -> Result { - to_pretty_json(execute_sleep(input)) + to_pretty_json(execute_sleep(input)?) } fn run_brief(input: BriefInput) -> Result { @@ -660,7 +659,7 @@ fn run_config(input: ConfigInput) -> Result { } fn run_structured_output(input: StructuredOutputInput) -> Result { - to_pretty_json(execute_structured_output(input)) + to_pretty_json(execute_structured_output(input)?) } fn run_repl(input: ReplInput) -> Result { @@ -2347,7 +2346,8 @@ fn execute_notebook_edit(input: NotebookEditInput) -> Result { - let resolved_cell_type = resolved_cell_type.expect("insert cell type"); + let resolved_cell_type = resolved_cell_type + .ok_or_else(|| String::from("insert mode requires a cell type"))?; let new_id = make_cell_id(cells.len()); let new_cell = build_notebook_cell(&new_id, resolved_cell_type, &new_source); let insert_at = target_index.map_or(cells.len(), |index| index + 1); @@ -2359,16 +2359,21 @@ fn execute_notebook_edit(input: NotebookEditInput) -> Result { - let removed = cells.remove(target_index.expect("delete target index")); + let idx = target_index + .ok_or_else(|| String::from("delete mode requires a target cell index"))?; + let removed = cells.remove(idx); removed .get("id") .and_then(serde_json::Value::as_str) .map(ToString::to_string) } NotebookEditMode::Replace => { - let resolved_cell_type = resolved_cell_type.expect("replace cell type"); + let resolved_cell_type = resolved_cell_type + .ok_or_else(|| String::from("replace mode requires a cell type"))?; + let idx = target_index + .ok_or_else(|| String::from("replace mode requires a target cell index"))?; let cell = cells - .get_mut(target_index.expect("replace target index")) + .get_mut(idx) .ok_or_else(|| String::from("Cell index out of range"))?; cell["source"] = serde_json::Value::Array(source_lines(&new_source)); cell["cell_type"] = serde_json::Value::String(match resolved_cell_type { @@ -2459,13 +2464,21 @@ fn cell_kind(cell: &serde_json::Value) -> Option { }) } +const MAX_SLEEP_DURATION_MS: u64 = 300_000; + #[allow(clippy::needless_pass_by_value)] -fn execute_sleep(input: SleepInput) -> SleepOutput { +fn execute_sleep(input: SleepInput) -> Result { + if input.duration_ms > MAX_SLEEP_DURATION_MS { + return Err(format!( + "duration_ms {} exceeds maximum allowed sleep of {MAX_SLEEP_DURATION_MS}ms", + input.duration_ms, + )); + } std::thread::sleep(Duration::from_millis(input.duration_ms)); - SleepOutput { + Ok(SleepOutput { duration_ms: input.duration_ms, message: format!("Slept for {}ms", input.duration_ms), - } + }) } fn execute_brief(input: BriefInput) -> Result { @@ -2562,25 +2575,62 @@ fn execute_config(input: ConfigInput) -> Result { } } -fn execute_structured_output(input: StructuredOutputInput) -> StructuredOutputResult { - StructuredOutputResult { +fn execute_structured_output( + input: StructuredOutputInput, +) -> Result { + if input.0.is_empty() { + return Err(String::from("structured output payload must not be empty")); + } + Ok(StructuredOutputResult { data: String::from("Structured output provided successfully"), structured_output: input.0, - } + }) } fn execute_repl(input: ReplInput) -> Result { if input.code.trim().is_empty() { return Err(String::from("code must not be empty")); } - let _ = input.timeout_ms; let runtime = resolve_repl_runtime(&input.language)?; let started = Instant::now(); - let output = Command::new(runtime.program) + let mut process = Command::new(runtime.program); + process .args(runtime.args) .arg(&input.code) - .output() - .map_err(|error| error.to_string())?; + .stdin(std::process::Stdio::null()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()); + + let output = if let Some(timeout_ms) = input.timeout_ms { + let mut child = process.spawn().map_err(|error| error.to_string())?; + loop { + if child + .try_wait() + .map_err(|error| error.to_string())? + .is_some() + { + break child + .wait_with_output() + .map_err(|error| error.to_string())?; + } + if started.elapsed() >= Duration::from_millis(timeout_ms) { + child.kill().map_err(|error| error.to_string())?; + child + .wait_with_output() + .map_err(|error| error.to_string())?; + return Err(format!( + "REPL execution exceeded timeout of {timeout_ms} ms" + )); + } + std::thread::sleep(Duration::from_millis(10)); + } + } else { + process + .spawn() + .map_err(|error| error.to_string())? + .wait_with_output() + .map_err(|error| error.to_string())? + }; Ok(ReplOutput { language: input.language, @@ -4226,6 +4276,21 @@ mod tests { assert!(elapsed >= Duration::from_millis(15)); } + #[test] + fn given_excessive_duration_when_sleep_then_rejects_with_error() { + let result = execute_tool("Sleep", &json!({"duration_ms": 999_999_999_u64})); + let error = result.expect_err("excessive sleep should fail"); + assert!(error.contains("exceeds maximum allowed sleep")); + } + + #[test] + fn given_zero_duration_when_sleep_then_succeeds() { + let result = + execute_tool("Sleep", &json!({"duration_ms": 0})).expect("0ms sleep should succeed"); + let output: serde_json::Value = serde_json::from_str(&result).expect("json"); + assert_eq!(output["duration_ms"], 0); + } + #[test] fn brief_returns_sent_message_and_attachment_metadata() { let attachment = std::env::temp_dir().join(format!( @@ -4330,6 +4395,13 @@ mod tests { assert_eq!(output["structured_output"]["items"][1], 2); } + #[test] + fn given_empty_payload_when_structured_output_then_rejects_with_error() { + let result = execute_tool("StructuredOutput", &json!({})); + let error = result.expect_err("empty payload should fail"); + assert!(error.contains("must not be empty")); + } + #[test] fn repl_executes_python_code() { let result = execute_tool( @@ -4343,6 +4415,37 @@ mod tests { assert!(output["stdout"].as_str().expect("stdout").contains('2')); } + #[test] + fn given_empty_code_when_repl_then_rejects_with_error() { + let result = execute_tool("REPL", &json!({"language": "python", "code": " "})); + + let error = result.expect_err("empty REPL code should fail"); + assert!(error.contains("code must not be empty")); + } + + #[test] + fn given_unsupported_language_when_repl_then_rejects_with_error() { + let result = execute_tool("REPL", &json!({"language": "ruby", "code": "puts 1"})); + + let error = result.expect_err("unsupported REPL language should fail"); + assert!(error.contains("unsupported REPL language: ruby")); + } + + #[test] + fn given_timeout_ms_when_repl_blocks_then_returns_timeout_error() { + let result = execute_tool( + "REPL", + &json!({ + "language": "python", + "code": "import time\ntime.sleep(1)", + "timeout_ms": 10 + }), + ); + + let error = result.expect_err("timed out REPL execution should fail"); + assert!(error.contains("REPL execution exceeded timeout of 10 ms")); + } + #[test] fn powershell_runs_via_stub_shell() { let _guard = env_lock()