1pub mod budget;
4pub mod packs;
5pub mod pii;
6pub mod tool_acl;
7
8use std::sync::Arc;
9
10#[cfg(feature = "wasm")]
11use anyhow::anyhow;
12#[cfg(feature = "wasm")]
13use anyhow::Context;
14use anyhow::Result;
15use async_trait::async_trait;
16use serde::{Deserialize, Serialize};
17use serde_json::{Map, Value};
18#[cfg(feature = "wasm")]
19use tracing::debug;
20use tracing::warn;
21use uuid::Uuid;
22
23pub use budget::BudgetCapsPolicy;
24pub use packs::{
25 shared_default_pack, DefaultPack, DefaultPolicyConfig, GovernConfig, MeasureConfig,
26 RegulatedPack, RegulatedVertical,
27};
28pub use pii::{BasicPiiPolicy, PiiMode};
29pub use tool_acl::ToolAclPolicy;
30
31#[cfg(feature = "wasm")]
32use tokio::task::spawn_blocking;
33#[cfg(feature = "wasm")]
34use wasmtime::{Caller, Engine, Instance, Linker, Module, Store, Trap, TypedFunc};
35
36#[cfg_attr(not(feature = "wasm"), allow(dead_code))]
38#[derive(Clone)]
39pub struct PolicyModule {
40 bytes: Arc<Vec<u8>>,
41 entrypoint: Option<String>,
42 data: Option<Value>,
43}
44
45impl PolicyModule {
46 pub fn from_bytes(bytes: impl Into<Vec<u8>>) -> Self {
48 Self {
49 bytes: Arc::new(bytes.into()),
50 entrypoint: None,
51 data: None,
52 }
53 }
54
55 pub fn with_entrypoint(mut self, entrypoint: impl Into<String>) -> Self {
57 self.entrypoint = Some(entrypoint.into());
58 self
59 }
60
61 pub fn with_data(mut self, data: Value) -> Self {
63 self.data = Some(data);
64 self
65 }
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
70pub struct Decision {
71 pub effect: DecisionEffect,
72 #[serde(skip_serializing_if = "Option::is_none")]
73 pub reason: Option<String>,
74 #[serde(default, skip_serializing_if = "Vec::is_empty")]
75 pub patches: Vec<Value>,
76}
77
78impl Decision {
79 pub fn allow() -> Self {
80 Self {
81 effect: DecisionEffect::Allow,
82 reason: None,
83 patches: Vec::new(),
84 }
85 }
86
87 pub fn merge<I>(decisions: I) -> Self
90 where
91 I: IntoIterator<Item = Decision>,
92 {
93 let mut effect = DecisionEffect::Allow;
94 let mut patches: Vec<Value> = Vec::new();
95 let mut reasons: Vec<String> = Vec::new();
96
97 for decision in decisions {
98 if let Some(reason) = decision.reason {
99 if !reason.is_empty() {
100 reasons.push(reason);
101 }
102 }
103 if !decision.patches.is_empty() {
104 patches.extend(decision.patches);
105 }
106
107 match (effect, decision.effect) {
108 (DecisionEffect::Deny, _) => {}
109 (_, DecisionEffect::Deny) => effect = DecisionEffect::Deny,
110 (DecisionEffect::Redact, _) => {}
111 (_, DecisionEffect::Redact) => effect = DecisionEffect::Redact,
112 _ => {}
113 }
114 }
115
116 let reason = if reasons.is_empty() {
117 None
118 } else {
119 Some(reasons.join("; "))
120 };
121
122 Self {
123 effect,
124 reason,
125 patches,
126 }
127 }
128
129 #[cfg(feature = "wasm")]
130 fn from_value(value: Value) -> Result<Self> {
131 let payload = match value {
132 Value::Array(items) => {
133 let mut chosen = None;
134 for item in items {
135 match item {
136 Value::Object(obj) => {
137 if let Some(res) = obj.get("result") {
138 chosen = Some(res.clone());
139 break;
140 } else if chosen.is_none() {
141 chosen = Some(Value::Object(obj));
142 }
143 }
144 other if !other.is_null() && chosen.is_none() => {
145 chosen = Some(other);
146 }
147 _ => {}
148 }
149 }
150 chosen.ok_or_else(|| anyhow!("policy returned empty result set"))?
151 }
152 other => other,
153 };
154
155 let obj = payload
156 .as_object()
157 .ok_or_else(|| anyhow!("policy result must be an object"))?;
158
159 let effect = match obj.get("effect").and_then(Value::as_str).unwrap_or("allow") {
160 "allow" => DecisionEffect::Allow,
161 "deny" => DecisionEffect::Deny,
162 "redact" => DecisionEffect::Redact,
163 other => return Err(anyhow!("unknown decision effect '{}'", other)),
164 };
165
166 let reason = obj
167 .get("reason")
168 .and_then(Value::as_str)
169 .map(|s| s.to_owned());
170 let patches = obj
171 .get("patches")
172 .and_then(Value::as_array)
173 .map(|arr| arr.clone())
174 .unwrap_or_default();
175
176 Ok(Self {
177 effect,
178 reason,
179 patches,
180 })
181 }
182}
183
/// Verdict carried by a [`Decision`].
///
/// Serialized in `snake_case` (`"allow"`, `"deny"`, `"redact"`). When decisions
/// are combined by `Decision::merge`, `Deny` outranks `Redact`, which outranks
/// `Allow`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DecisionEffect {
    Allow,
    Deny,
    Redact,
}
192
193#[derive(Debug, Clone)]
194pub struct PolicyRequest {
195 run_id: Uuid,
196 step_id: Uuid,
197 context: Value,
198}
199
200impl PolicyRequest {
201 pub fn new(run_id: Uuid, step_id: Uuid, context: Value) -> Self {
202 Self {
203 run_id,
204 step_id,
205 context,
206 }
207 }
208
209 pub fn run_id(&self) -> Uuid {
210 self.run_id
211 }
212
213 pub fn step_id(&self) -> Uuid {
214 self.step_id
215 }
216
217 pub fn context(&self) -> &Value {
218 &self.context
219 }
220
221 pub fn to_payload(&self) -> Value {
222 let mut map = match self.context.clone() {
223 Value::Object(map) => map,
224 other => {
225 let mut map = Map::new();
226 map.insert("context".to_string(), other);
227 map
228 }
229 };
230 map.entry("run_id")
231 .or_insert_with(|| Value::String(self.run_id.to_string()));
232 map.entry("step_id")
233 .or_insert_with(|| Value::String(self.step_id.to_string()));
234 Value::Object(map)
235 }
236}
237
/// Asynchronous policy evaluator.
///
/// Implementors must be `Send + Sync`, so an engine can be shared between
/// threads (for example behind an `Arc<dyn PolicyEngine>`).
#[async_trait]
pub trait PolicyEngine: Send + Sync {
    /// Evaluates `request` and returns the resulting [`Decision`], or an error
    /// if the engine could not produce one.
    async fn evaluate(&self, request: &PolicyRequest) -> Result<Decision>;
}
242
243#[derive(Debug, Default, Clone)]
244pub struct AllowAllPolicy;
245
246#[async_trait]
247impl PolicyEngine for AllowAllPolicy {
248 async fn evaluate(&self, _request: &PolicyRequest) -> Result<Decision> {
249 Ok(Decision::allow())
250 }
251}
252
253impl AllowAllPolicy {
254 pub fn shared() -> Arc<dyn PolicyEngine> {
255 Arc::new(Self::default())
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Redact decision with one patch, shared by the merge tests.
    fn masking_decision() -> Decision {
        Decision {
            effect: DecisionEffect::Redact,
            reason: Some("mask".into()),
            patches: vec![json!({"op": "remove", "path": "/secret"})],
        }
    }

    #[test]
    fn merge_prefers_redact_over_allow() {
        let redact = masking_decision();
        let expected_patches = redact.patches.clone();

        let merged = Decision::merge(vec![Decision::allow(), redact]);

        assert_eq!(merged.effect, DecisionEffect::Redact);
        let reason = merged.reason.expect("reason");
        assert!(reason.contains("mask"));
        assert_eq!(merged.patches, expected_patches);
    }

    #[test]
    fn merge_prefers_deny_over_redact() {
        let redact = masking_decision();
        let expected_patches = redact.patches.clone();
        let deny = Decision {
            effect: DecisionEffect::Deny,
            reason: Some("blocked".into()),
            patches: Vec::new(),
        };

        let merged = Decision::merge(vec![redact, deny]);

        assert_eq!(merged.effect, DecisionEffect::Deny);
        // Both reasons must survive the merge, not only the winning one.
        let reason = merged.reason.expect("reason");
        assert!(reason.contains("blocked"));
        assert!(reason.contains("mask"));
        assert_eq!(merged.patches, expected_patches);
    }
}
300
/// Data type of the wasmtime `Store` used for policy evaluation.
///
/// It is a unit struct: the host functions registered in this file keep no
/// per-store state.
#[cfg(feature = "wasm")]
#[derive(Default)]
struct HostState;
309
/// Policy engine that evaluates a compiled module with wasmtime.
#[cfg(feature = "wasm")]
#[derive(Clone)]
pub struct WasmPolicyEngine {
    /// Engine the module was compiled for.
    engine: Engine,
    /// Compiled module, shared between clones.
    module: Arc<Module>,
    /// Entrypoint name copied from the `PolicyModule`, if any.
    entrypoint: Option<String>,
    /// Data document copied from the `PolicyModule`, if any.
    data: Option<Value>,
    /// `data` serialized to JSON once at construction, so evaluations can reuse it.
    data_bytes: Option<Arc<Vec<u8>>>,
}

#[cfg(feature = "wasm")]
impl WasmPolicyEngine {
    /// Compiles `policy` with a default-configured engine.
    ///
    /// # Errors
    /// Fails when the module bytes cannot be compiled or when the data document
    /// cannot be serialized to JSON.
    pub fn new(policy: &PolicyModule) -> Result<Self> {
        let engine = Engine::default();
        let compiled = Module::from_binary(&engine, &policy.bytes)?;

        let data_bytes = match policy.data.as_ref() {
            Some(document) => Some(Arc::new(serde_json::to_vec(document)?)),
            None => None,
        };

        Ok(Self {
            engine,
            module: Arc::new(compiled),
            entrypoint: policy.entrypoint.clone(),
            data: policy.data.clone(),
            data_bytes,
        })
    }
}
340
#[cfg(feature = "wasm")]
#[async_trait]
impl PolicyEngine for WasmPolicyEngine {
    /// Evaluates the compiled policy against `request` on a blocking worker thread.
    ///
    /// Each call builds a fresh linker, store, and module instance, then drives the
    /// module's `opa_*` exports: write and parse the input, optionally set the data
    /// document and entrypoint, run `opa_eval`, and dump the result as JSON.
    ///
    /// # Errors
    /// Fails when the module cannot be instantiated, lacks the `memory` export or a
    /// required `opa_*` export, when a guest call traps, when `opa_eval` returns a
    /// non-zero status, when the result cannot be read back as JSON, or when the
    /// blocking task itself fails to complete.
    ///
    /// NOTE(review): a result that `Decision::from_value` rejects is logged at debug
    /// level and treated as *allow* (fail-open). Confirm this is intended for a policy
    /// gate.
    async fn evaluate(&self, request: &PolicyRequest) -> Result<Decision> {
        // Everything moved into the blocking closure must be owned.
        let module = self.module.clone();
        let engine = self.engine.clone();
        let input = request.to_payload();
        let data = self.data.clone();
        let data_bytes = self.data_bytes.clone();
        let entrypoint = self.entrypoint.clone();

        spawn_blocking(move || {
            let mut linker = build_linker(&engine)?;
            let mut store = Store::new(&engine, HostState::default());
            let instance = linker
                .instantiate(&mut store, module.as_ref())
                .map_err(|trap| anyhow!("failed to instantiate policy module: {trap}"))?;

            let memory = instance
                .get_memory(&mut store, "memory")
                .context("policy module missing exported memory")?;

            // Required exports fail loudly via `typed_func`; the data and entrypoint
            // setters are optional and simply skipped when absent.
            let opa_malloc = typed_func::<i32, i32>(&instance, &mut store, "opa_malloc")?;
            let opa_json_parse =
                typed_func::<(i32, i32), i32>(&instance, &mut store, "opa_json_parse")?;
            let opa_eval_ctx_new =
                typed_func::<(), i32>(&instance, &mut store, "opa_eval_ctx_new")?;
            let opa_eval_ctx_set_input =
                typed_func::<(i32, i32), ()>(&instance, &mut store, "opa_eval_ctx_set_input")?;
            let opa_eval_ctx_set_data = instance
                .get_typed_func::<(i32, i32), ()>(&mut store, "opa_eval_ctx_set_data")
                .ok();
            let opa_eval_ctx_set_entrypoint = instance
                .get_typed_func::<(i32, i32), ()>(&mut store, "opa_eval_ctx_set_entrypoint")
                .ok();
            let opa_eval_ctx_get_result =
                typed_func::<i32, i32>(&instance, &mut store, "opa_eval_ctx_get_result")?;
            let opa_json_dump = typed_func::<i32, i32>(&instance, &mut store, "opa_json_dump")?;
            let opa_eval = typed_func::<i32, i32>(&instance, &mut store, "opa_eval")?;
            let opa_heap_ptr_get =
                typed_func::<(), i32>(&instance, &mut store, "opa_heap_ptr_get")?;
            let opa_heap_ptr_set =
                typed_func::<i32, ()>(&instance, &mut store, "opa_heap_ptr_set")?;

            // Remember the initial heap pointer so it can be restored after evaluation.
            let heap_start = opa_heap_ptr_get.call(&mut store, ())?;
            opa_heap_ptr_set
                .call(&mut store, heap_start)
                .context("failed to reset policy heap pointer")?;

            let (input_addr, input_len) = write_json(&memory, &mut store, &opa_malloc, &input)?;
            let input_val = opa_json_parse.call(&mut store, (input_addr, input_len))?;

            let ctx = opa_eval_ctx_new.call(&mut store, ())?;
            opa_eval_ctx_set_input.call(&mut store, (ctx, input_val))?;

            // Prefer the JSON bytes cached by the constructor; otherwise serialize `data`
            // on the spot.
            if let (Some(bytes), Some(ref data_fn)) =
                (data_bytes.as_ref(), opa_eval_ctx_set_data.as_ref())
            {
                let (data_addr, data_len) =
                    write_json_bytes(&memory, &mut store, &opa_malloc, bytes.as_ref())?;
                let data_parsed = opa_json_parse.call(&mut store, (data_addr, data_len))?;
                data_fn.call(&mut store, (ctx, data_parsed))?;
            } else if let (Some(data_value), Some(ref data_fn)) =
                (data.as_ref(), opa_eval_ctx_set_data.as_ref())
            {
                let (data_addr, data_len) =
                    write_json(&memory, &mut store, &opa_malloc, data_value)?;
                let data_parsed = opa_json_parse.call(&mut store, (data_addr, data_len))?;
                data_fn.call(&mut store, (ctx, data_parsed))?;
            }

            if let Some(ref set_entrypoint) = opa_eval_ctx_set_entrypoint {
                if let Some(ref entrypoint_name) = entrypoint {
                    let entrypoints_fn = instance
                        .get_typed_func::<(), i32>(&mut store, "entrypoints")
                        .context("policy module missing entrypoints function")?;
                    let entrypoint_id =
                        lookup_entrypoint(&entrypoints_fn, &memory, &mut store, entrypoint_name)?;
                    set_entrypoint.call(&mut store, (ctx, entrypoint_id))?;
                }
            }

            let status = opa_eval.call(&mut store, ctx)?;
            if status != 0 {
                // Best-effort restore; the evaluation error is what gets reported.
                let _ = opa_heap_ptr_set.call(&mut store, heap_start);
                return Err(anyhow!("policy evaluation failed with status {status}"));
            }

            let result_addr = opa_eval_ctx_get_result.call(&mut store, ctx)?;
            let json_addr = opa_json_dump.call(&mut store, result_addr)?;
            // Read before restoring the heap pointer, but surface a read error only
            // after the restore has been attempted.
            let decision_json = read_json(&memory, &mut store, json_addr);
            opa_heap_ptr_set
                .call(&mut store, heap_start)
                .context("failed to restore policy heap pointer")?;
            let decision_json = decision_json?;

            let decision = Decision::from_value(decision_json).unwrap_or_else(|err| {
                debug!(error = %err, "policy returned unparsable decision; defaulting to allow");
                Decision::allow()
            });

            Ok(decision)
        })
        // A single `?` unwraps only the `JoinError` layer; the closure's own
        // `Result<Decision>` is this function's return value. The previous `.await??`
        // unwrapped both and left a bare `Decision`, which does not type-check.
        .await?
    }
}
447
448#[cfg(not(feature = "wasm"))]
449#[derive(Clone)]
450pub struct WasmPolicyEngine;
451
452#[cfg(not(feature = "wasm"))]
453impl WasmPolicyEngine {
454 pub fn new(_policy: &PolicyModule) -> Result<Self> {
455 warn!(
456 "Wasm policy engine requested but crate built without 'wasm' feature; returning allow-all stub"
457 );
458 Ok(Self)
459 }
460}
461
462#[cfg(not(feature = "wasm"))]
463#[async_trait]
464impl PolicyEngine for WasmPolicyEngine {
465 async fn evaluate(&self, _request: &PolicyRequest) -> Result<Decision> {
466 Ok(Decision::allow())
467 }
468}
469
/// Creates a linker exposing the `env` host imports used by the policy module.
///
/// * `opa_abort` traps with the guest-supplied message.
/// * `opa_println` forwards the guest message to the debug log and returns 0.
/// * `opa_builtin0` through `opa_builtin4` are stubs that log a warning and return 0.
///
/// NOTE(review): this relies on `Trap::new`, which is only available in older
/// wasmtime releases. Confirm against the pinned wasmtime version.
#[cfg(feature = "wasm")]
fn build_linker(engine: &Engine) -> Result<Linker<HostState>> {
    let mut linker = Linker::new(engine);

    // Registers one unimplemented-builtin stub. Every stub takes `(ctx, builtin_id)`
    // followed by `arity` further i32 arguments, all of which are ignored.
    macro_rules! stub_builtin {
        ($import:literal, $arity:literal $(, $arg:ident)*) => {
            linker.func_wrap(
                "env",
                $import,
                |_caller: Caller<'_, HostState>, _ctx: i32, builtin: i32 $(, $arg: i32)*|
                 -> Result<i32, Trap> {
                    log_unimplemented_builtin(builtin, $arity);
                    Ok(0)
                },
            )?;
        };
    }

    linker.func_wrap(
        "env",
        "opa_abort",
        |mut caller: Caller<'_, HostState>, addr: i32| -> Result<(), Trap> {
            let message = match read_c_string(&mut caller, addr) {
                Ok(text) => text,
                Err(_) => "<invalid abort message>".to_string(),
            };
            Err(Trap::new(format!("policy aborted: {message}")))
        },
    )?;

    linker.func_wrap(
        "env",
        "opa_println",
        |mut caller: Caller<'_, HostState>, addr: i32| -> Result<i32, Trap> {
            let message = match read_c_string(&mut caller, addr) {
                Ok(text) => text,
                Err(_) => "<invalid utf8>".to_string(),
            };
            debug!(target = "policy::opa", "{message}");
            Ok(0)
        },
    )?;

    stub_builtin!("opa_builtin0", 0);
    stub_builtin!("opa_builtin1", 1, _arg1);
    stub_builtin!("opa_builtin2", 2, _arg1, _arg2);
    stub_builtin!("opa_builtin3", 3, _arg1, _arg2, _arg3);
    stub_builtin!("opa_builtin4", 4, _arg1, _arg2, _arg3, _arg4);

    Ok(linker)
}
564
/// Emits a warning that the policy called a builtin the host does not implement.
///
/// `builtin` is the id passed by the guest and `arity` is the number of arguments
/// of the stub that received the call.
#[cfg(feature = "wasm")]
fn log_unimplemented_builtin(builtin: i32, arity: u8) {
    warn!(
        builtin_id = builtin,
        arity = arity,
        "OPA builtin not implemented; returning undefined value"
    );
}
572
/// Resolves an entrypoint name to the numeric id the policy module expects.
///
/// Calls `entrypoints_fn`, reads the JSON document at the returned address, and
/// looks `name` up in it as an integer.
///
/// # Errors
/// Fails when the guest call traps, when the mapping cannot be read as JSON, when
/// `name` is missing or not an integer, or when the id does not fit in an `i32`.
#[cfg(feature = "wasm")]
fn lookup_entrypoint(
    entrypoints_fn: &TypedFunc<(), i32>,
    memory: &wasmtime::Memory,
    store: &mut Store<HostState>,
    name: &str,
) -> Result<i32> {
    // Imported locally so this also builds on editions whose prelude lacks `TryFrom`.
    use std::convert::TryFrom;

    // Reborrow explicitly: `call` is generic over its store argument, so passing
    // `store` itself would move the reference and leave nothing for `read_json`.
    let ptr = entrypoints_fn
        .call(&mut *store, ())
        .context("failed to call entrypoints function")?;
    let mapping = read_json(memory, store, ptr)?;
    let entry_id = mapping
        .as_object()
        .and_then(|map| map.get(name))
        .and_then(Value::as_i64)
        .ok_or_else(|| anyhow!("entrypoint '{name}' not defined in policy module"))?;

    // A checked conversion rejects ids outside the i32 range instead of wrapping them.
    i32::try_from(entry_id)
        .map_err(|_| anyhow!("entrypoint '{name}' has out-of-range id {entry_id}"))
}
591
/// Reads a NUL-terminated UTF-8 string from the caller's exported `memory`.
///
/// # Errors
/// Traps when the module has no `memory` export, when `addr` lies outside the
/// memory or no NUL byte follows it, or when the bytes are not valid UTF-8.
#[cfg(feature = "wasm")]
fn read_c_string(caller: &mut Caller<'_, HostState>, addr: i32) -> Result<String, Trap> {
    let memory = caller
        .get_export("memory")
        .and_then(|export| export.into_memory())
        .ok_or_else(|| Trap::new("policy module missing memory export"))?;

    let data = memory.data(&*caller);

    // Guest addresses are unsigned 32-bit offsets that arrive as `i32`. Reinterpret
    // the bits rather than sign-extending, so addresses at or above 2 GiB stay valid.
    let start = addr as u32 as usize;
    let tail = data
        .get(start..)
        .ok_or_else(|| Trap::new("unterminated string in policy memory"))?;
    let len = tail
        .iter()
        .position(|&byte| byte == 0)
        .ok_or_else(|| Trap::new("unterminated string in policy memory"))?;

    String::from_utf8(tail[..len].to_vec()).map_err(|_| Trap::new("policy emitted invalid utf-8"))
}
612
/// Looks up the export `name` on `instance` with the signature `P -> R`.
///
/// # Errors
/// Wraps the lookup failure with a message naming the missing function.
#[cfg(feature = "wasm")]
fn typed_func<P, R>(
    instance: &Instance,
    store: &mut Store<HostState>,
    name: &str,
) -> Result<TypedFunc<P, R>>
where
    P: wasmtime::WasmParams,
    R: wasmtime::WasmResults,
{
    let export = instance.get_typed_func::<P, R>(store, name);
    export.with_context(|| format!("policy module missing function '{name}'"))
}
627
/// Serializes `value` to JSON and copies it into guest memory.
///
/// Returns the `(address, length)` pair reported by `write_json_bytes`.
///
/// # Errors
/// Fails when serialization fails or when the bytes cannot be written.
#[cfg(feature = "wasm")]
fn write_json(
    memory: &wasmtime::Memory,
    store: &mut Store<HostState>,
    opa_malloc: &TypedFunc<i32, i32>,
    value: &Value,
) -> Result<(i32, i32)> {
    let encoded = serde_json::to_vec(value)?;
    write_json_bytes(memory, store, opa_malloc, encoded.as_slice())
}
638
/// Allocates guest memory through `opa_malloc` and copies `bytes` into it.
///
/// Returns the guest address and the length, both as `i32`.
///
/// # Errors
/// Fails when `bytes` is longer than `i32::MAX`, when `opa_malloc` traps, or when
/// the returned region does not lie inside the module's memory.
#[cfg(feature = "wasm")]
fn write_json_bytes(
    memory: &wasmtime::Memory,
    store: &mut Store<HostState>,
    opa_malloc: &TypedFunc<i32, i32>,
    bytes: &[u8],
) -> Result<(i32, i32)> {
    // Imported locally so this also builds on editions whose prelude lacks `TryFrom`.
    use std::convert::TryFrom;

    // The allocator takes an `i32` length; refuse anything that would be truncated.
    let len = i32::try_from(bytes.len()).map_err(|_| {
        anyhow!(
            "JSON payload of {} bytes is too large for policy memory",
            bytes.len()
        )
    })?;

    // Reborrow explicitly: `call` is generic over its store argument, so passing
    // `store` itself would move the reference and leave nothing for `data_mut`.
    let ptr = opa_malloc.call(&mut *store, len)?;

    // Guest addresses are unsigned 32-bit offsets that arrive as `i32`. Reinterpret
    // the bits rather than sign-extending them.
    let start = ptr as u32 as usize;
    let end = start
        .checked_add(bytes.len())
        .ok_or_else(|| anyhow!("failed to write JSON into policy memory"))?;

    memory
        .data_mut(store)
        .get_mut(start..end)
        .ok_or_else(|| anyhow!("failed to write JSON into policy memory"))?
        .copy_from_slice(bytes);

    Ok((ptr, len))
}
657
/// Reads a NUL-terminated JSON document from guest memory starting at `addr`.
///
/// # Errors
/// Fails when `addr` lies outside the memory or no NUL byte follows it, when the
/// bytes are not valid UTF-8, or when they do not parse as JSON.
#[cfg(feature = "wasm")]
fn read_json(memory: &wasmtime::Memory, store: &mut Store<HostState>, addr: i32) -> Result<Value> {
    let data = memory.data(store);

    // Guest addresses are unsigned 32-bit offsets that arrive as `i32`. Reinterpret
    // the bits rather than sign-extending, so addresses at or above 2 GiB stay valid.
    let start = addr as u32 as usize;
    let tail = data
        .get(start..)
        .ok_or_else(|| anyhow!("unterminated JSON string emitted by policy"))?;
    let len = tail
        .iter()
        .position(|&byte| byte == 0)
        .ok_or_else(|| anyhow!("unterminated JSON string emitted by policy"))?;

    let json_str = std::str::from_utf8(&tail[..len])?;
    Ok(serde_json::from_str(json_str)?)
}