mirror of
https://github.com/instructkr/claw-code.git
synced 2026-04-05 23:54:50 +08:00
Add runtime OAuth primitives for PKCE generation, authorization URL building, token exchange request shaping, and refresh request shaping. Wire the API client to a real auth-source abstraction so future OAuth tokens can flow into Anthropic requests without bespoke header code. This keeps the slice bounded to foundations: no browser flow, callback listener, or token persistence. The API client still behaves compatibly for current API-key users while gaining explicit bearer-token and combined auth modeling. Constraint: Must keep the slice minimal and real while preserving current API client behavior Constraint: Repo verification requires fmt, tests, and clippy to pass cleanly Rejected: Implement full OAuth browser/listener flow now | too broad for the current parity-unblocking slice Rejected: Keep auth handling as ad hoc env reads only | blocks reuse by future OAuth integration paths Confidence: high Scope-risk: moderate Reversibility: clean Directive: Extend OAuth behavior by composing these request/auth primitives before adding session or storage orchestration Tested: cargo fmt --all; cargo clippy -p runtime -p api --all-targets -- -D warnings; cargo test -p runtime; cargo test -p api --tests Not-tested: live OAuth token exchange; callback listener flow; workspace-wide tests outside runtime/api
339 lines
10 KiB
Rust
339 lines
10 KiB
Rust
use std::collections::BTreeMap;
|
|
use std::fs::File;
|
|
use std::io::{self, Read};
|
|
|
|
use sha2::{Digest, Sha256};
|
|
|
|
use crate::config::OAuthConfig;
|
|
|
|
/// A bundle of OAuth tokens plus their metadata.
///
/// Presumably this mirrors what a token endpoint returns; it is not populated
/// anywhere in this module.
///
/// NOTE(review): the derived `Debug` prints `access_token` and `refresh_token`
/// verbatim, so avoid logging values of this type.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OAuthTokenSet {
    /// Bearer token used to authorize API requests.
    pub access_token: String,
    /// Token used to obtain a new access token, if one was issued.
    pub refresh_token: Option<String>,
    /// Expiry instant, if known. Presumably a Unix timestamp in seconds — TODO confirm.
    pub expires_at: Option<u64>,
    /// Scopes granted to `access_token`.
    pub scopes: Vec<String>,
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct PkceCodePair {
|
|
pub verifier: String,
|
|
pub challenge: String,
|
|
pub challenge_method: PkceChallengeMethod,
|
|
}
|
|
|
|
/// PKCE code-challenge transformation method (RFC 7636 §4.2).
///
/// Only `S256` is modelled; the `plain` method from the RFC is not represented.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PkceChallengeMethod {
    /// SHA-256 of the verifier, base64url-encoded.
    S256,
}

impl PkceChallengeMethod {
    /// Wire name sent as the `code_challenge_method` parameter.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            PkceChallengeMethod::S256 => "S256",
        }
    }
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct OAuthAuthorizationRequest {
|
|
pub authorize_url: String,
|
|
pub client_id: String,
|
|
pub redirect_uri: String,
|
|
pub scopes: Vec<String>,
|
|
pub state: String,
|
|
pub code_challenge: String,
|
|
pub code_challenge_method: PkceChallengeMethod,
|
|
pub extra_params: BTreeMap<String, String>,
|
|
}
|
|
|
|
/// Form fields for trading an authorization code for tokens.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OAuthTokenExchangeRequest {
    /// OAuth grant type; `from_config` sets it to `"authorization_code"`.
    pub grant_type: &'static str,
    /// Authorization code received on the redirect.
    pub code: String,
    /// Redirect URI; it must match the one used in the authorization request.
    pub redirect_uri: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// PKCE verifier whose challenge was sent during authorization.
    pub code_verifier: String,
    /// `state` value from the authorization round-trip.
    pub state: String,
}
|
|
|
|
/// Form fields for exchanging a refresh token for a new access token.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OAuthRefreshRequest {
    /// OAuth grant type; `from_config` sets it to `"refresh_token"`.
    pub grant_type: &'static str,
    /// Refresh token previously issued by the server.
    pub refresh_token: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Scopes to request; joined with spaces for the `scope` form field.
    pub scopes: Vec<String>,
}
|
|
|
|
impl OAuthAuthorizationRequest {
|
|
#[must_use]
|
|
pub fn from_config(
|
|
config: &OAuthConfig,
|
|
redirect_uri: impl Into<String>,
|
|
state: impl Into<String>,
|
|
pkce: &PkceCodePair,
|
|
) -> Self {
|
|
Self {
|
|
authorize_url: config.authorize_url.clone(),
|
|
client_id: config.client_id.clone(),
|
|
redirect_uri: redirect_uri.into(),
|
|
scopes: config.scopes.clone(),
|
|
state: state.into(),
|
|
code_challenge: pkce.challenge.clone(),
|
|
code_challenge_method: pkce.challenge_method,
|
|
extra_params: BTreeMap::new(),
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn with_extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
|
self.extra_params.insert(key.into(), value.into());
|
|
self
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn build_url(&self) -> String {
|
|
let mut params = vec![
|
|
("response_type", "code".to_string()),
|
|
("client_id", self.client_id.clone()),
|
|
("redirect_uri", self.redirect_uri.clone()),
|
|
("scope", self.scopes.join(" ")),
|
|
("state", self.state.clone()),
|
|
("code_challenge", self.code_challenge.clone()),
|
|
(
|
|
"code_challenge_method",
|
|
self.code_challenge_method.as_str().to_string(),
|
|
),
|
|
];
|
|
params.extend(
|
|
self.extra_params
|
|
.iter()
|
|
.map(|(key, value)| (key.as_str(), value.clone())),
|
|
);
|
|
let query = params
|
|
.into_iter()
|
|
.map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value)))
|
|
.collect::<Vec<_>>()
|
|
.join("&");
|
|
format!(
|
|
"{}{}{}",
|
|
self.authorize_url,
|
|
if self.authorize_url.contains('?') {
|
|
'&'
|
|
} else {
|
|
'?'
|
|
},
|
|
query
|
|
)
|
|
}
|
|
}
|
|
|
|
impl OAuthTokenExchangeRequest {
|
|
#[must_use]
|
|
pub fn from_config(
|
|
config: &OAuthConfig,
|
|
code: impl Into<String>,
|
|
state: impl Into<String>,
|
|
verifier: impl Into<String>,
|
|
redirect_uri: impl Into<String>,
|
|
) -> Self {
|
|
let _ = config;
|
|
Self {
|
|
grant_type: "authorization_code",
|
|
code: code.into(),
|
|
redirect_uri: redirect_uri.into(),
|
|
client_id: config.client_id.clone(),
|
|
code_verifier: verifier.into(),
|
|
state: state.into(),
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn form_params(&self) -> BTreeMap<&str, String> {
|
|
BTreeMap::from([
|
|
("grant_type", self.grant_type.to_string()),
|
|
("code", self.code.clone()),
|
|
("redirect_uri", self.redirect_uri.clone()),
|
|
("client_id", self.client_id.clone()),
|
|
("code_verifier", self.code_verifier.clone()),
|
|
("state", self.state.clone()),
|
|
])
|
|
}
|
|
}
|
|
|
|
impl OAuthRefreshRequest {
|
|
#[must_use]
|
|
pub fn from_config(
|
|
config: &OAuthConfig,
|
|
refresh_token: impl Into<String>,
|
|
scopes: Option<Vec<String>>,
|
|
) -> Self {
|
|
Self {
|
|
grant_type: "refresh_token",
|
|
refresh_token: refresh_token.into(),
|
|
client_id: config.client_id.clone(),
|
|
scopes: scopes.unwrap_or_else(|| config.scopes.clone()),
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn form_params(&self) -> BTreeMap<&str, String> {
|
|
BTreeMap::from([
|
|
("grant_type", self.grant_type.to_string()),
|
|
("refresh_token", self.refresh_token.clone()),
|
|
("client_id", self.client_id.clone()),
|
|
("scope", self.scopes.join(" ")),
|
|
])
|
|
}
|
|
}
|
|
|
|
pub fn generate_pkce_pair() -> io::Result<PkceCodePair> {
|
|
let verifier = generate_random_token(32)?;
|
|
Ok(PkceCodePair {
|
|
challenge: code_challenge_s256(&verifier),
|
|
verifier,
|
|
challenge_method: PkceChallengeMethod::S256,
|
|
})
|
|
}
|
|
|
|
pub fn generate_state() -> io::Result<String> {
|
|
generate_random_token(32)
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn code_challenge_s256(verifier: &str) -> String {
|
|
let digest = Sha256::digest(verifier.as_bytes());
|
|
base64url_encode(&digest)
|
|
}
|
|
|
|
/// Builds the `localhost` callback URI for a loopback redirect on `port`.
#[must_use]
pub fn loopback_redirect_uri(port: u16) -> String {
    let mut uri = String::from("http://localhost:");
    uri.push_str(&port.to_string());
    uri.push_str("/callback");
    uri
}
|
|
|
|
fn generate_random_token(bytes: usize) -> io::Result<String> {
|
|
let mut buffer = vec![0_u8; bytes];
|
|
File::open("/dev/urandom")?.read_exact(&mut buffer)?;
|
|
Ok(base64url_encode(&buffer))
|
|
}
|
|
|
|
/// Encodes `bytes` as base64url (RFC 4648 §5) without `=` padding.
fn base64url_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    let mut encoded = String::new();
    for chunk in bytes.chunks(3) {
        // Pack up to three bytes big-endian into a 24-bit group. Bytes missing
        // from a short final chunk stay zero.
        let group = chunk
            .iter()
            .enumerate()
            .fold(0_u32, |acc, (offset, &byte)| {
                acc | (u32::from(byte) << (16 - 8 * offset))
            });
        // n input bytes carry 8n bits, i.e. n + 1 six-bit symbols (no padding).
        for sextet in 0..=chunk.len() {
            let shift = 18 - 6 * sextet;
            encoded.push(char::from(ALPHABET[((group >> shift) & 0x3F) as usize]));
        }
    }
    encoded
}
|
|
|
|
/// Percent-encodes every byte of `value` outside the RFC 3986 unreserved set.
///
/// The unreserved set is `A-Z a-z 0-9 - _ . ~`. Escapes use uppercase hex, and
/// a space becomes `%20`.
fn percent_encode(value: &str) -> String {
    use std::fmt::Write as _;

    value
        .bytes()
        .fold(String::with_capacity(value.len()), |mut out, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                out.push(char::from(byte));
            } else {
                // Writing into a `String` cannot fail.
                let _ = write!(out, "%{byte:02X}");
            }
            out
        })
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::{
        code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
        OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest,
    };

    /// A fully populated config pointing at a fake `console.test` host.
    fn sample_config() -> OAuthConfig {
        OAuthConfig {
            client_id: String::from("runtime-client"),
            authorize_url: String::from("https://console.test/oauth/authorize"),
            token_url: String::from("https://console.test/oauth/token"),
            callback_port: Some(4545),
            manual_redirect_url: Some(String::from("https://console.test/oauth/callback")),
            scopes: ["org:read", "user:write"]
                .iter()
                .map(|scope| (*scope).to_string())
                .collect(),
        }
    }

    #[test]
    fn s256_challenge_matches_expected_vector() {
        // Verifier/challenge pair from RFC 7636, Appendix B.
        let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        let challenge = code_challenge_s256(verifier);
        assert_eq!(challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM");
    }

    #[test]
    fn generates_pkce_pair_and_state() {
        let pair = generate_pkce_pair().expect("pkce pair");
        let state = generate_state().expect("state");
        for (label, value) in [
            ("verifier", &pair.verifier),
            ("challenge", &pair.challenge),
            ("state", &state),
        ] {
            assert!(!value.is_empty(), "{label} should not be empty");
        }
    }

    #[test]
    fn builds_authorize_url_and_form_requests() {
        let config = sample_config();
        let pair = generate_pkce_pair().expect("pkce");
        let redirect_uri = loopback_redirect_uri(4545);

        let url = OAuthAuthorizationRequest::from_config(
            &config,
            redirect_uri.clone(),
            "state-123",
            &pair,
        )
        .with_extra_param("login_hint", "user@example.com")
        .build_url();
        assert!(url.starts_with("https://console.test/oauth/authorize?"));
        for fragment in [
            "response_type=code",
            "client_id=runtime-client",
            "scope=org%3Aread%20user%3Awrite",
            "login_hint=user%40example.com",
        ] {
            assert!(url.contains(fragment), "`{fragment}` missing from {url}");
        }

        let exchange = OAuthTokenExchangeRequest::from_config(
            &config,
            "auth-code",
            "state-123",
            pair.verifier,
            redirect_uri,
        );
        let exchange_params = exchange.form_params();
        assert_eq!(
            exchange_params.get("grant_type").map(String::as_str),
            Some("authorization_code")
        );

        let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None);
        let refresh_params = refresh.form_params();
        assert_eq!(
            refresh_params.get("scope").map(String::as_str),
            Some("org:read user:write")
        );
    }
}
|