diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index 5678af8..7d23f35 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -38,6 +38,7 @@ mod session; #[cfg(test)] mod session_control; mod sse; +pub mod stale_base; pub mod stale_branch; pub mod summary_compression; pub mod task_packet; @@ -150,6 +151,10 @@ pub use session::{ SessionFork, SessionPromptEntry, }; pub use sse::{IncrementalSseParser, SseEvent}; +pub use stale_base::{ + check_base_commit, format_stale_base_warning, read_claw_base_file, resolve_expected_base, + BaseCommitSource, BaseCommitState, +}; pub use stale_branch::{ apply_policy, check_freshness, BranchFreshness, StaleBranchAction, StaleBranchEvent, StaleBranchPolicy, diff --git a/rust/crates/runtime/src/stale_base.rs b/rust/crates/runtime/src/stale_base.rs new file mode 100644 index 0000000..b432d30 --- /dev/null +++ b/rust/crates/runtime/src/stale_base.rs @@ -0,0 +1,429 @@ +#![allow(clippy::must_use_candidate)] +use std::path::Path; +use std::process::Command; + +/// Outcome of comparing the worktree HEAD against the expected base commit. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BaseCommitState { + /// HEAD matches the expected base commit. + Matches, + /// HEAD has diverged from the expected base. + Diverged { expected: String, actual: String }, + /// No expected base was supplied (neither flag nor file). + NoExpectedBase, + /// The working directory is not inside a git repository. + NotAGitRepo, +} + +/// Where the expected base commit originated from. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BaseCommitSource { + Flag(String), + File(String), +} + +/// Read the `.claw-base` file from the given directory and return the trimmed +/// commit hash, or `None` when the file is absent or empty. +pub fn read_claw_base_file(cwd: &Path) -> Option { + let path = cwd.join(".claw-base"); + let content = std::fs::read_to_string(path).ok()?; + let trimmed = content.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } +} + +/// Resolve the expected base commit: prefer the `--base-commit` flag value, +/// fall back to reading `.claw-base` from `cwd`. +pub fn resolve_expected_base(flag_value: Option<&str>, cwd: &Path) -> Option { + if let Some(value) = flag_value { + let trimmed = value.trim(); + if !trimmed.is_empty() { + return Some(BaseCommitSource::Flag(trimmed.to_string())); + } + } + read_claw_base_file(cwd).map(BaseCommitSource::File) +} + +/// Verify that the worktree HEAD matches `expected_base`. +/// +/// Returns [`BaseCommitState::NoExpectedBase`] when no expected commit is +/// provided (the check is effectively a no-op in that case). +pub fn check_base_commit(cwd: &Path, expected_base: Option<&BaseCommitSource>) -> BaseCommitState { + let Some(source) = expected_base else { + return BaseCommitState::NoExpectedBase; + }; + let expected_raw = match source { + BaseCommitSource::Flag(value) | BaseCommitSource::File(value) => value.as_str(), + }; + + let Some(head_sha) = resolve_head_sha(cwd) else { + return BaseCommitState::NotAGitRepo; + }; + + let Some(expected_sha) = resolve_rev(cwd, expected_raw) else { + // If the expected ref cannot be resolved, compare raw strings as a + // best-effort fallback (e.g. partial SHA provided by the caller). + return if head_sha.starts_with(expected_raw) || expected_raw.starts_with(&head_sha) { + BaseCommitState::Matches + } else { + BaseCommitState::Diverged { + expected: expected_raw.to_string(), + actual: head_sha, + } + }; + }; + + if head_sha == expected_sha { + BaseCommitState::Matches + } else { + BaseCommitState::Diverged { + expected: expected_sha, + actual: head_sha, + } + } +} + +/// Format a human-readable warning when the base commit has diverged. +/// +/// Returns `None` for non-warning states (`Matches`, `NoExpectedBase`). +pub fn format_stale_base_warning(state: &BaseCommitState) -> Option { + match state { + BaseCommitState::Diverged { expected, actual } => Some(format!( + "warning: worktree HEAD ({actual}) does not match expected base commit ({expected}). \ + Session may run against a stale codebase." + )), + BaseCommitState::NotAGitRepo => { + Some("warning: stale-base check skipped — not inside a git repository.".to_string()) + } + BaseCommitState::Matches | BaseCommitState::NoExpectedBase => None, + } +} + +fn resolve_head_sha(cwd: &Path) -> Option { + resolve_rev(cwd, "HEAD") +} + +fn resolve_rev(cwd: &Path, rev: &str) -> Option { + let output = Command::new("git") + .args(["rev-parse", rev]) + .current_dir(cwd) + .output() + .ok()?; + if !output.status.success() { + return None; + } + let sha = String::from_utf8(output.stdout).ok()?; + let trimmed = sha.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::process::Command; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn temp_dir() -> std::path::PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("runtime-stale-base-{nanos}")) + } + + fn init_repo(path: &std::path::Path) { + fs::create_dir_all(path).expect("create repo dir"); + run(path, &["init", "--quiet", "-b", "main"]); + run(path, &["config", "user.email", "tests@example.com"]); + run(path, &["config", "user.name", "Stale Base Tests"]); + fs::write(path.join("init.txt"), "initial\n").expect("write init file"); + run(path, &["add", "."]); + run(path, &["commit", "-m", "initial commit", "--quiet"]); + } + + fn run(cwd: &std::path::Path, args: &[&str]) { + let status = Command::new("git") + .args(args) + .current_dir(cwd) + .status() + .unwrap_or_else(|e| panic!("git {} failed to execute: {e}", args.join(" "))); + assert!( + status.success(), + "git {} exited with {status}", + args.join(" ") + ); + } + + fn commit_file(repo: &std::path::Path, name: &str, msg: &str) { + fs::write(repo.join(name), format!("{msg}\n")).expect("write file"); + run(repo, &["add", name]); + run(repo, &["commit", "-m", msg, "--quiet"]); + } + + fn head_sha(repo: &std::path::Path) -> String { + let output = Command::new("git") + .args(["rev-parse", "HEAD"]) + .current_dir(repo) + .output() + .expect("git rev-parse HEAD"); + String::from_utf8(output.stdout) + .expect("valid utf8") + .trim() + .to_string() + } + + #[test] + fn matches_when_head_equals_expected_base() { + // given + let root = temp_dir(); + init_repo(&root); + let sha = head_sha(&root); + let source = BaseCommitSource::Flag(sha); + + // when + let state = check_base_commit(&root, Some(&source)); + + // then + assert_eq!(state, BaseCommitState::Matches); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn diverged_when_head_moved_past_expected_base() { + // given + let root = temp_dir(); + init_repo(&root); + let old_sha = head_sha(&root); + commit_file(&root, "extra.txt", "move head forward"); + let new_sha = head_sha(&root); + let source = BaseCommitSource::Flag(old_sha.clone()); + + // when + let state = check_base_commit(&root, Some(&source)); + + // then + assert_eq!( + state, + BaseCommitState::Diverged { + expected: old_sha, + actual: new_sha, + } + ); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn no_expected_base_when_source_is_none() { + // given + let root = temp_dir(); + init_repo(&root); + + // when + let state = check_base_commit(&root, None); + + // then + assert_eq!(state, BaseCommitState::NoExpectedBase); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn not_a_git_repo_when_outside_repo() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + let source = BaseCommitSource::Flag("abc1234".to_string()); + + // when + let state = check_base_commit(&root, Some(&source)); + + // then + assert_eq!(state, BaseCommitState::NotAGitRepo); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn reads_claw_base_file() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + fs::write(root.join(".claw-base"), "abc1234def5678\n").expect("write .claw-base"); + + // when + let value = read_claw_base_file(&root); + + // then + assert_eq!(value, Some("abc1234def5678".to_string())); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn returns_none_for_missing_claw_base_file() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + + // when + let value = read_claw_base_file(&root); + + // then + assert!(value.is_none()); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn returns_none_for_empty_claw_base_file() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + fs::write(root.join(".claw-base"), " \n").expect("write empty .claw-base"); + + // when + let value = read_claw_base_file(&root); + + // then + assert!(value.is_none()); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn resolve_expected_base_prefers_flag_over_file() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + fs::write(root.join(".claw-base"), "from_file\n").expect("write .claw-base"); + + // when + let source = resolve_expected_base(Some("from_flag"), &root); + + // then + assert_eq!( + source, + Some(BaseCommitSource::Flag("from_flag".to_string())) + ); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn resolve_expected_base_falls_back_to_file() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + fs::write(root.join(".claw-base"), "from_file\n").expect("write .claw-base"); + + // when + let source = resolve_expected_base(None, &root); + + // then + assert_eq!( + source, + Some(BaseCommitSource::File("from_file".to_string())) + ); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn resolve_expected_base_returns_none_when_nothing_available() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + + // when + let source = resolve_expected_base(None, &root); + + // then + assert!(source.is_none()); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn format_warning_returns_message_for_diverged() { + // given + let state = BaseCommitState::Diverged { + expected: "abc1234".to_string(), + actual: "def5678".to_string(), + }; + + // when + let warning = format_stale_base_warning(&state); + + // then + let message = warning.expect("should produce warning"); + assert!(message.contains("abc1234")); + assert!(message.contains("def5678")); + assert!(message.contains("stale codebase")); + } + + #[test] + fn format_warning_returns_none_for_matches() { + // given + let state = BaseCommitState::Matches; + + // when + let warning = format_stale_base_warning(&state); + + // then + assert!(warning.is_none()); + } + + #[test] + fn format_warning_returns_none_for_no_expected_base() { + // given + let state = BaseCommitState::NoExpectedBase; + + // when + let warning = format_stale_base_warning(&state); + + // then + assert!(warning.is_none()); + } + + #[test] + fn matches_with_claw_base_file_in_real_repo() { + // given + let root = temp_dir(); + init_repo(&root); + let sha = head_sha(&root); + fs::write(root.join(".claw-base"), format!("{sha}\n")).expect("write .claw-base"); + let source = resolve_expected_base(None, &root); + + // when + let state = check_base_commit(&root, source.as_ref()); + + // then + assert_eq!(state, BaseCommitState::Matches); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn diverged_with_claw_base_file_after_new_commit() { + // given + let root = temp_dir(); + init_repo(&root); + let old_sha = head_sha(&root); + fs::write(root.join(".claw-base"), format!("{old_sha}\n")).expect("write .claw-base"); + commit_file(&root, "new.txt", "advance head"); + let new_sha = head_sha(&root); + let source = resolve_expected_base(None, &root); + + // when + let state = check_base_commit(&root, source.as_ref()); + + // then + assert_eq!( + state, + BaseCommitState::Diverged { + expected: old_sha, + actual: new_sha, + } + ); + fs::remove_dir_all(&root).expect("cleanup"); + } +} diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index c1b77cd..ada8995 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -42,12 +42,13 @@ use init::initialize_repo; use plugins::{PluginHooks, PluginManager, PluginManagerConfig, PluginRegistry}; use render::{MarkdownStreamState, Spinner, TerminalRenderer}; use runtime::{ - clear_oauth_credentials, format_usd, generate_pkce_pair, generate_state, - load_oauth_credentials, load_system_prompt, parse_oauth_callback_request_target, - pricing_for_model, resolve_sandbox_status, save_oauth_credentials, ApiClient, ApiRequest, - AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, - ConversationMessage, ConversationRuntime, McpServer, McpServerManager, McpServerSpec, McpTool, - MessageRole, ModelPricing, OAuthAuthorizationRequest, OAuthConfig, OAuthTokenExchangeRequest, + check_base_commit, clear_oauth_credentials, format_stale_base_warning, format_usd, + generate_pkce_pair, generate_state, load_oauth_credentials, load_system_prompt, + parse_oauth_callback_request_target, pricing_for_model, resolve_expected_base, + resolve_sandbox_status, save_oauth_credentials, ApiClient, ApiRequest, AssistantEvent, + CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, ConversationMessage, + ConversationRuntime, McpServer, McpServerManager, McpServerSpec, McpTool, MessageRole, + ModelPricing, OAuthAuthorizationRequest, OAuthConfig, OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, PromptCacheEvent, ResolvedPermissionMode, RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, }; @@ -88,6 +89,7 @@ const CLI_OPTION_SUGGESTIONS: &[&str] = &[ "--resume", "--print", "--compact", + "--base-commit", "-p", ]; @@ -198,7 +200,9 @@ fn run() -> Result<(), Box> { allowed_tools, permission_mode, compact: _, + base_commit, } => { + run_stale_base_preflight(base_commit.as_deref()); let stdin_context = read_piped_stdin(); let effective_prompt = merge_prompt_with_stdin(&prompt, stdin_context.as_deref()); LiveCli::new(model, true, allowed_tools, permission_mode)? @@ -217,7 +221,8 @@ fn run() -> Result<(), Box> { model, allowed_tools, permission_mode, - } => run_repl(model, allowed_tools, permission_mode)?, + base_commit, + } => run_repl(model, allowed_tools, permission_mode, base_commit)?, CliAction::HelpTopic(topic) => print_help_topic(topic), CliAction::Help { output_format } => print_help(output_format)?, } @@ -277,6 +282,7 @@ enum CliAction { allowed_tools: Option, permission_mode: PermissionMode, compact: bool, + base_commit: Option, }, Login { output_format: CliOutputFormat, @@ -299,6 +305,7 @@ enum CliAction { model: String, allowed_tools: Option, permission_mode: PermissionMode, + base_commit: Option, }, HelpTopic(LocalHelpTopic), // prompt-mode formatting is only supported for non-interactive runs @@ -341,6 +348,7 @@ fn parse_args(args: &[String]) -> Result { let mut wants_version = false; let mut allowed_tool_values = Vec::new(); let mut compact = false; + let mut base_commit: Option = None; let mut rest = Vec::new(); let mut index = 0; @@ -395,6 +403,17 @@ fn parse_args(args: &[String]) -> Result { compact = true; index += 1; } + "--base-commit" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --base-commit".to_string())?; + base_commit = Some(value.clone()); + index += 2; + } + flag if flag.starts_with("--base-commit=") => { + base_commit = Some(flag[14..].to_string()); + index += 1; + } "-p" => { // Claw Code compat: -p "prompt" = one-shot prompt let prompt = args[index + 1..].join(" "); @@ -409,6 +428,7 @@ fn parse_args(args: &[String]) -> Result { permission_mode: permission_mode_override .unwrap_or_else(default_permission_mode), compact, + base_commit: base_commit.clone(), }); } "--print" => { @@ -466,6 +486,7 @@ fn parse_args(args: &[String]) -> Result { model, allowed_tools, permission_mode, + base_commit, }); } if rest.first().map(String::as_str) == Some("--resume") { @@ -503,6 +524,7 @@ fn parse_args(args: &[String]) -> Result { allowed_tools, permission_mode, compact, + base_commit, }), SkillSlashDispatch::Local => Ok(CliAction::Skills { args, @@ -527,6 +549,7 @@ fn parse_args(args: &[String]) -> Result { allowed_tools, permission_mode, compact, + base_commit: base_commit.clone(), }) } other if other.starts_with('/') => parse_direct_slash_cli_action( @@ -536,6 +559,7 @@ fn parse_args(args: &[String]) -> Result { allowed_tools, permission_mode, compact, + base_commit, ), _other => Ok(CliAction::Prompt { prompt: rest.join(" "), @@ -544,6 +568,7 @@ fn parse_args(args: &[String]) -> Result { allowed_tools, permission_mode, compact, + base_commit, }), } } @@ -635,6 +660,7 @@ fn parse_direct_slash_cli_action( allowed_tools: Option, permission_mode: PermissionMode, compact: bool, + base_commit: Option, ) -> Result { let raw = rest.join(" "); match SlashCommand::parse(&raw) { @@ -661,6 +687,7 @@ fn parse_direct_slash_cli_action( allowed_tools, permission_mode, compact, + base_commit, }), SkillSlashDispatch::Local => Ok(CliAction::Skills { args, @@ -2665,11 +2692,28 @@ fn run_resume_command( } } +/// Stale-base preflight: verify the worktree HEAD matches the expected base +/// commit (from `--base-commit` flag or `.claw-base` file). Emits a warning to +/// stderr when the HEAD has diverged. +fn run_stale_base_preflight(flag_value: Option<&str>) { + let cwd = match env::current_dir() { + Ok(cwd) => cwd, + Err(_) => return, + }; + let source = resolve_expected_base(flag_value, &cwd); + let state = check_base_commit(&cwd, source.as_ref()); + if let Some(warning) = format_stale_base_warning(&state) { + eprintln!("{warning}"); + } +} + fn run_repl( model: String, allowed_tools: Option, permission_mode: PermissionMode, + base_commit: Option, ) -> Result<(), Box> { + run_stale_base_preflight(base_commit.as_deref()); let resolved_model = resolve_repl_model(model); let mut cli = LiveCli::new(resolved_model, true, allowed_tools, permission_mode)?; let mut editor = @@ -7980,6 +8024,7 @@ mod tests { model: DEFAULT_MODEL.to_string(), allowed_tools: None, permission_mode: PermissionMode::DangerFullAccess, + base_commit: None, } ); } @@ -8142,6 +8187,7 @@ mod tests { allowed_tools: None, permission_mode: PermissionMode::DangerFullAccess, compact: false, + base_commit: None, } ); }