1use std::collections::HashSet;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use base64::Engine;
6use chrono::{DateTime, Duration, Utc};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use uuid::Uuid;
10
11use crate::{capability_signer, trust_signer, Signer, BASE64_URL_SAFE};
12
/// A minted capability token.
///
/// `jws` is the compact JWS string produced at mint time, `key_id` is the key id
/// reported by the signing envelope, and `claims` is a decoded copy of the JWS
/// payload. `verify_capability_token` rejects the token when `claims` differs from
/// the payload embedded in `jws`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct CapabilityToken {
    /// Compact JWS serialization: header, payload and signature, each encoded with
    /// `BASE64_URL_SAFE` and joined by `.`.
    pub jws: String,
    /// Key id taken from the signing envelope when the token was minted.
    pub key_id: String,
    /// Claims carried in the JWS payload.
    pub claims: CapabilityClaims,
}
20
21impl CapabilityToken {
22 pub fn token_id(&self) -> Uuid {
23 self.claims.token_id
24 }
25}
26
/// Claims carried as the JWS payload of a capability token.
///
/// The JSON serialization of this struct is exactly what gets signed, so adding,
/// removing or reordering fields (or changing their serde attributes) changes the
/// signed bytes.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct CapabilityClaims {
    /// Token id; `mint_capability_token` assigns a random v4 UUID.
    pub token_id: Uuid,
    /// Time at which the token was minted.
    pub issued_at: DateTime<Utc>,
    /// `issued_at + ttl`, as computed at mint time.
    /// NOTE(review): `verify_capability_token` does not check this value.
    pub expires_at: DateTime<Utc>,
    /// Set to `issued_at` when minted by `mint_capability_token`; omitted from the
    /// payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub not_before: Option<DateTime<Utc>>,
    /// Random nonce (a v4 UUID string when minted here); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nonce: Option<String>,
    /// Who/what the token was issued for.
    pub subject: CapabilityTokenSubject,
    /// What the token grants.
    pub scope: CapabilityTokenScope,
    /// Optional audience list, presumably the intended recipients — confirm with
    /// consumers; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audience: Option<Vec<String>>,
}
42
/// The subject a capability token is issued to: a run id, optionally narrowed by a
/// step id and an attempt number.
/// NOTE(review): presumably a workflow run/step/retry attempt — confirm with callers.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct CapabilityTokenSubject {
    /// Id of the run this token belongs to.
    pub run_id: Uuid,
    /// Optional step id; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub step_id: Option<Uuid>,
    /// Optional attempt number; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub attempt: Option<i32>,
}
51
/// Scope granted by a capability token: a tool, a schema reference, and optional
/// data-domain and budget restrictions.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct CapabilityTokenScope {
    /// Tool the token applies to.
    pub tool: CapabilityToolScope,
    /// Schema the token is bound to.
    pub schema: CapabilitySchemaRef,
    /// Optional list of data domains; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data_domains: Option<Vec<String>>,
    /// Optional budget limits; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub budget: Option<CapabilityBudgetLimits>,
}
61
/// Identifies the tool a capability token applies to.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct CapabilityToolScope {
    /// Tool name (always present).
    pub name: String,
    /// Optional tool id; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Optional tool variant; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub variant: Option<String>,
}
70
/// Reference to the schema a capability token is bound to.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct CapabilitySchemaRef {
    /// Schema hash. NOTE(review): hash algorithm/encoding is not visible here — confirm.
    pub hash: String,
    /// Optional schema version; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version: Option<String>,
}
77
/// Optional budget limits attached to a capability token scope.
///
/// `Default` leaves every limit as `None`; `None` fields are omitted from the
/// serialized payload. Units below follow the field names (assumed, not enforced here).
#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq)]
pub struct CapabilityBudgetLimits {
    /// Token limit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokens: Option<i64>,
    /// Cost limit, presumably in US dollars per the field name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost_usd: Option<f64>,
    /// Duration limit, presumably in milliseconds per the field name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<i64>,
}
87
88pub fn mint_capability_token(
90 subject: CapabilityTokenSubject,
91 scope: CapabilityTokenScope,
92 audience: Option<Vec<String>>,
93 ttl: Duration,
94 signer: &dyn Signer,
95) -> Result<CapabilityToken> {
96 let issued_at = Utc::now();
97 let expires_at = issued_at + ttl;
98 let claims = CapabilityClaims {
99 token_id: Uuid::new_v4(),
100 issued_at,
101 expires_at,
102 not_before: Some(issued_at),
103 nonce: Some(Uuid::new_v4().to_string()),
104 subject,
105 scope,
106 audience,
107 };
108
109 let header = json!({
110 "alg": signer.algorithm().as_str(),
111 "typ": "JWT",
112 "kid": signer.key_id(),
113 });
114 let header_encoded = BASE64_URL_SAFE.encode(serde_json::to_vec(&header)?);
115 let payload_encoded = BASE64_URL_SAFE.encode(serde_json::to_vec(&claims)?);
116 let signing_input = format!("{}.{}", header_encoded, payload_encoded);
117 let envelope = signer.sign(signing_input.as_bytes())?;
118 let signature_encoded = BASE64_URL_SAFE.encode(envelope.signature.clone());
119 let jws = format!(
120 "{}.{}.{}",
121 header_encoded, payload_encoded, signature_encoded
122 );
123
124 Ok(CapabilityToken {
125 jws,
126 key_id: envelope.key_id,
127 claims,
128 })
129}
130
131pub fn verify_capability_token(token: &CapabilityToken) -> Result<()> {
133 let parts: Vec<&str> = token.jws.split('.').collect();
134 if parts.len() != 3 {
135 return Err(anyhow!("capability token must contain three JWS segments"));
136 }
137
138 let header_bytes = BASE64_URL_SAFE
139 .decode(parts[0])
140 .map_err(|err| anyhow!("failed to decode capability token header: {err}"))?;
141 let header: Value = serde_json::from_slice(&header_bytes)
142 .map_err(|err| anyhow!("failed to parse capability token header: {err}"))?;
143 let kid = header
144 .get("kid")
145 .and_then(Value::as_str)
146 .ok_or_else(|| anyhow!("capability token header missing 'kid'"))?;
147 let algorithm = header.get("alg").and_then(Value::as_str).unwrap_or("HS256");
148
149 let payload_bytes = BASE64_URL_SAFE
150 .decode(parts[1])
151 .map_err(|err| anyhow!("failed to decode capability token payload: {err}"))?;
152 let payload_claims: CapabilityClaims = serde_json::from_slice(&payload_bytes)
153 .map_err(|err| anyhow!("failed to parse capability token payload: {err}"))?;
154 if payload_claims != token.claims {
155 return Err(anyhow!("capability token payload mismatch"));
156 }
157
158 let signature_bytes = BASE64_URL_SAFE
159 .decode(parts[2])
160 .map_err(|err| anyhow!("failed to decode capability token signature: {err}"))?;
161 let signing_input = format!("{}.{}", parts[0], parts[1]);
162
163 let mut signers: Vec<Arc<dyn Signer>> = Vec::new();
164 if let Ok(signer) = capability_signer() {
165 signers.push(signer);
166 }
167 signers.push(trust_signer());
168 let mut seen: HashSet<String> = HashSet::new();
169
170 for signer in signers.into_iter() {
171 if !seen.insert(signer.key_id().to_string()) {
172 continue;
173 }
174 if signer.key_id() != kid {
175 continue;
176 }
177 if signer.algorithm().as_str() != algorithm {
178 return Err(anyhow!(
179 "capability token algorithm '{}' does not match signer '{}'",
180 algorithm,
181 signer.algorithm().as_str()
182 ));
183 }
184 if signer.verify(signing_input.as_bytes(), &signature_bytes)? {
185 return Ok(());
186 }
187 }
188
189 Err(anyhow!(
190 "no configured signer matches capability token kid '{}'",
191 kid
192 ))
193}