fleetforge_policy/
lib.rs

1//! Policy evaluation sandbox using Wasmtime and OPA-compiled WebAssembly policies.
2
3pub mod budget;
4pub mod packs;
5pub mod pii;
6pub mod tool_acl;
7
8use std::sync::Arc;
9
10#[cfg(feature = "wasm")]
11use anyhow::anyhow;
12#[cfg(feature = "wasm")]
13use anyhow::Context;
14use anyhow::Result;
15use async_trait::async_trait;
16use serde::{Deserialize, Serialize};
17use serde_json::{Map, Value};
18#[cfg(feature = "wasm")]
19use tracing::debug;
20use tracing::warn;
21use uuid::Uuid;
22
23pub use budget::BudgetCapsPolicy;
24pub use packs::{
25    shared_default_pack, DefaultPack, DefaultPolicyConfig, GovernConfig, MeasureConfig,
26    RegulatedPack, RegulatedVertical,
27};
28pub use pii::{BasicPiiPolicy, PiiMode};
29pub use tool_acl::ToolAclPolicy;
30
31#[cfg(feature = "wasm")]
32use tokio::task::spawn_blocking;
33#[cfg(feature = "wasm")]
34use wasmtime::{Caller, Engine, Instance, Linker, Module, Store, Trap, TypedFunc};
35
36/// Represents a compiled policy module plus optional static data.
37#[cfg_attr(not(feature = "wasm"), allow(dead_code))]
38#[derive(Clone)]
39pub struct PolicyModule {
40    bytes: Arc<Vec<u8>>,
41    entrypoint: Option<String>,
42    data: Option<Value>,
43}
44
45impl PolicyModule {
46    /// Loads policy bytes (Wasm emitted from `opa build -t wasm`).
47    pub fn from_bytes(bytes: impl Into<Vec<u8>>) -> Self {
48        Self {
49            bytes: Arc::new(bytes.into()),
50            entrypoint: None,
51            data: None,
52        }
53    }
54
55    /// Overrides the default entrypoint (defaults to index 0).
56    pub fn with_entrypoint(mut self, entrypoint: impl Into<String>) -> Self {
57        self.entrypoint = Some(entrypoint.into());
58        self
59    }
60
61    /// Attaches static policy data (`opa build --data`).
62    pub fn with_data(mut self, data: Value) -> Self {
63        self.data = Some(data);
64        self
65    }
66}
67
68/// Policy decision produced by evaluating a policy.
69#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
70pub struct Decision {
71    pub effect: DecisionEffect,
72    #[serde(skip_serializing_if = "Option::is_none")]
73    pub reason: Option<String>,
74    #[serde(default, skip_serializing_if = "Vec::is_empty")]
75    pub patches: Vec<Value>,
76}
77
78impl Decision {
79    pub fn allow() -> Self {
80        Self {
81            effect: DecisionEffect::Allow,
82            reason: None,
83            patches: Vec::new(),
84        }
85    }
86
87    /// Merge multiple decisions, prioritising `deny` over `redact` over `allow`
88    /// and concatenating JSON patches.
89    pub fn merge<I>(decisions: I) -> Self
90    where
91        I: IntoIterator<Item = Decision>,
92    {
93        let mut effect = DecisionEffect::Allow;
94        let mut patches: Vec<Value> = Vec::new();
95        let mut reasons: Vec<String> = Vec::new();
96
97        for decision in decisions {
98            if let Some(reason) = decision.reason {
99                if !reason.is_empty() {
100                    reasons.push(reason);
101                }
102            }
103            if !decision.patches.is_empty() {
104                patches.extend(decision.patches);
105            }
106
107            match (effect, decision.effect) {
108                (DecisionEffect::Deny, _) => {}
109                (_, DecisionEffect::Deny) => effect = DecisionEffect::Deny,
110                (DecisionEffect::Redact, _) => {}
111                (_, DecisionEffect::Redact) => effect = DecisionEffect::Redact,
112                _ => {}
113            }
114        }
115
116        let reason = if reasons.is_empty() {
117            None
118        } else {
119            Some(reasons.join("; "))
120        };
121
122        Self {
123            effect,
124            reason,
125            patches,
126        }
127    }
128
129    #[cfg(feature = "wasm")]
130    fn from_value(value: Value) -> Result<Self> {
131        let payload = match value {
132            Value::Array(items) => {
133                let mut chosen = None;
134                for item in items {
135                    match item {
136                        Value::Object(obj) => {
137                            if let Some(res) = obj.get("result") {
138                                chosen = Some(res.clone());
139                                break;
140                            } else if chosen.is_none() {
141                                chosen = Some(Value::Object(obj));
142                            }
143                        }
144                        other if !other.is_null() && chosen.is_none() => {
145                            chosen = Some(other);
146                        }
147                        _ => {}
148                    }
149                }
150                chosen.ok_or_else(|| anyhow!("policy returned empty result set"))?
151            }
152            other => other,
153        };
154
155        let obj = payload
156            .as_object()
157            .ok_or_else(|| anyhow!("policy result must be an object"))?;
158
159        let effect = match obj.get("effect").and_then(Value::as_str).unwrap_or("allow") {
160            "allow" => DecisionEffect::Allow,
161            "deny" => DecisionEffect::Deny,
162            "redact" => DecisionEffect::Redact,
163            other => return Err(anyhow!("unknown decision effect '{}'", other)),
164        };
165
166        let reason = obj
167            .get("reason")
168            .and_then(Value::as_str)
169            .map(|s| s.to_owned());
170        let patches = obj
171            .get("patches")
172            .and_then(Value::as_array)
173            .map(|arr| arr.clone())
174            .unwrap_or_default();
175
176        Ok(Self {
177            effect,
178            reason,
179            patches,
180        })
181    }
182}
183
184/// Effect chosen by the policy engine.
185#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
186#[serde(rename_all = "snake_case")]
187pub enum DecisionEffect {
188    Allow,
189    Deny,
190    Redact,
191}
192
193#[derive(Debug, Clone)]
194pub struct PolicyRequest {
195    run_id: Uuid,
196    step_id: Uuid,
197    context: Value,
198}
199
200impl PolicyRequest {
201    pub fn new(run_id: Uuid, step_id: Uuid, context: Value) -> Self {
202        Self {
203            run_id,
204            step_id,
205            context,
206        }
207    }
208
209    pub fn run_id(&self) -> Uuid {
210        self.run_id
211    }
212
213    pub fn step_id(&self) -> Uuid {
214        self.step_id
215    }
216
217    pub fn context(&self) -> &Value {
218        &self.context
219    }
220
221    pub fn to_payload(&self) -> Value {
222        let mut map = match self.context.clone() {
223            Value::Object(map) => map,
224            other => {
225                let mut map = Map::new();
226                map.insert("context".to_string(), other);
227                map
228            }
229        };
230        map.entry("run_id")
231            .or_insert_with(|| Value::String(self.run_id.to_string()));
232        map.entry("step_id")
233            .or_insert_with(|| Value::String(self.step_id.to_string()));
234        Value::Object(map)
235    }
236}
237
238#[async_trait]
239pub trait PolicyEngine: Send + Sync {
240    async fn evaluate(&self, request: &PolicyRequest) -> Result<Decision>;
241}
242
243#[derive(Debug, Default, Clone)]
244pub struct AllowAllPolicy;
245
246#[async_trait]
247impl PolicyEngine for AllowAllPolicy {
248    async fn evaluate(&self, _request: &PolicyRequest) -> Result<Decision> {
249        Ok(Decision::allow())
250    }
251}
252
253impl AllowAllPolicy {
254    pub fn shared() -> Arc<dyn PolicyEngine> {
255        Arc::new(Self::default())
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Redact decision with one reason and one patch, shared by both tests.
    fn masking_redact() -> Decision {
        Decision {
            effect: DecisionEffect::Redact,
            reason: Some("mask".into()),
            patches: vec![json!({"op": "remove", "path": "/secret"})],
        }
    }

    #[test]
    fn merge_prefers_redact_over_allow() {
        let redact = masking_redact();

        let merged = Decision::merge([Decision::allow(), redact.clone()]);

        assert_eq!(merged.effect, DecisionEffect::Redact);
        let reason = merged.reason.as_deref().expect("reason");
        assert!(reason.contains("mask"));
        assert_eq!(merged.patches, redact.patches);
    }

    #[test]
    fn merge_prefers_deny_over_redact() {
        let redact = masking_redact();
        let deny = Decision {
            effect: DecisionEffect::Deny,
            reason: Some("blocked".into()),
            patches: Vec::new(),
        };

        let merged = Decision::merge([redact.clone(), deny]);

        assert_eq!(merged.effect, DecisionEffect::Deny);
        let reason = merged.reason.expect("reason");
        for fragment in ["blocked", "mask"] {
            assert!(reason.contains(fragment), "reason {reason:?} lacks {fragment:?}");
        }
        assert_eq!(merged.patches, redact.patches);
    }
}
300
/// Per-store host state for policy instances; it currently carries no data.
#[cfg(feature = "wasm")]
#[derive(Default)]
struct HostState;

/// Wasmtime-backed PolicyEngine for OPA Rego policies compiled to Wasm.
///
/// This keeps policy evaluation lightweight while preserving a clear upgrade
/// path to stronger isolation (e.g. Firecracker micro-VM sandboxes) by
/// wrapping the Wasm invocation behind the [`PolicyEngine`] trait.
///
/// The module is compiled once by [`WasmPolicyEngine::new`]; clones share the
/// compiled module and the pre-serialized static data.
#[cfg(feature = "wasm")]
#[derive(Clone)]
pub struct WasmPolicyEngine {
    engine: Engine,
    module: Arc<Module>,
    /// Entrypoint name requested via [`PolicyModule::with_entrypoint`].
    entrypoint: Option<String>,
    /// Static data as supplied via [`PolicyModule::with_data`].
    data: Option<Value>,
    /// `data` serialized to JSON once, at construction time.
    data_bytes: Option<Arc<Vec<u8>>>,
}

#[cfg(feature = "wasm")]
impl WasmPolicyEngine {
    /// Compiles `policy` with a default-configured Wasmtime [`Engine`].
    ///
    /// # Errors
    ///
    /// Fails if Wasmtime cannot compile the bytes, or if the attached static
    /// data cannot be serialized to JSON.
    pub fn new(policy: &PolicyModule) -> Result<Self> {
        let engine = Engine::default();
        let module = Module::from_binary(&engine, &policy.bytes)?;
        // Serialize static data once here rather than on every evaluation.
        let data_bytes = policy
            .data
            .as_ref()
            .map(serde_json::to_vec)
            .transpose()?
            .map(Arc::new);
        Ok(Self {
            engine,
            module: Arc::new(module),
            entrypoint: policy.entrypoint.clone(),
            data: policy.data.clone(),
            data_bytes,
        })
    }
}
340
/// Smallest linear memory (in 64 KiB Wasm pages) the host provides to policy
/// modules that import `env.memory`.
// NOTE(review): conservative default; confirm it is at least the minimum
// declared by the memory import of the policy modules you ship.
#[cfg(feature = "wasm")]
const OPA_MIN_MEMORY_PAGES: u32 = 5;

#[cfg(feature = "wasm")]
#[async_trait]
impl PolicyEngine for WasmPolicyEngine {
    /// Evaluates `request` against the compiled OPA module.
    ///
    /// Wasm execution is synchronous, so each call runs on Tokio's blocking
    /// pool and instantiates the module in a fresh [`Store`]; guest heap state
    /// is never shared between evaluations.
    ///
    /// A result that `Decision::from_value` cannot interpret is logged at
    /// `warn` level and treated as `allow` (the engine fails open).
    ///
    /// # Errors
    ///
    /// Fails if the module cannot be instantiated, a required OPA export is
    /// missing or mistyped, configured static data or entrypoint cannot be
    /// applied, the policy traps or reports a non-zero status, or the blocking
    /// task panics or is cancelled.
    async fn evaluate(&self, request: &PolicyRequest) -> Result<Decision> {
        let engine = self.engine.clone();
        let module = Arc::clone(&self.module);
        let entrypoint = self.entrypoint.clone();
        let data = self.data.clone();
        let data_bytes = self.data_bytes.clone();
        let input = request.to_payload();

        let evaluation = spawn_blocking(move || -> Result<Decision> {
            let mut store = Store::new(&engine, HostState::default());
            let mut linker = build_linker(&engine)?;

            // OPA-compiled modules import their linear memory as `env.memory`
            // instead of exporting it, so the host must provide one. Modules
            // that define their own memory simply ignore this definition.
            let host_memory = wasmtime::Memory::new(
                &mut store,
                wasmtime::MemoryType::new(OPA_MIN_MEMORY_PAGES, None),
            )
            .context("failed to allocate policy memory")?;
            linker.define("env", "memory", host_memory)?;

            let instance = linker
                .instantiate(&mut store, module.as_ref())
                .map_err(|trap| anyhow!("failed to instantiate policy module: {trap}"))?;
            // A module-exported memory wins; otherwise the guest works in the
            // imported host memory.
            let memory = instance
                .get_memory(&mut store, "memory")
                .unwrap_or(host_memory);

            let opa_malloc = typed_func::<i32, i32>(&instance, &mut store, "opa_malloc")?;
            let opa_json_parse =
                typed_func::<(i32, i32), i32>(&instance, &mut store, "opa_json_parse")?;
            let opa_json_dump = typed_func::<i32, i32>(&instance, &mut store, "opa_json_dump")?;
            let opa_eval_ctx_new =
                typed_func::<(), i32>(&instance, &mut store, "opa_eval_ctx_new")?;
            let opa_eval_ctx_set_input =
                typed_func::<(i32, i32), ()>(&instance, &mut store, "opa_eval_ctx_set_input")?;
            let opa_eval_ctx_get_result =
                typed_func::<i32, i32>(&instance, &mut store, "opa_eval_ctx_get_result")?;
            // Context-based evaluation is exported as `eval(ctx)`; `opa_eval` is
            // the separate one-shot API with a seven-argument signature.
            let eval_fn = typed_func::<i32, i32>(&instance, &mut store, "eval")?;

            // Turns an `(addr, len)` JSON buffer in guest memory into an OPA
            // value; a null (0) address means the guest rejected the JSON.
            let parse_json = |store: &mut Store<HostState>,
                              (addr, len): (i32, i32),
                              what: &str|
             -> Result<i32> {
                let value = opa_json_parse.call(store, (addr, len))?;
                if value == 0 {
                    return Err(anyhow!("policy could not parse {what} JSON"));
                }
                Ok(value)
            };

            let input_buf = write_json(&memory, &mut store, &opa_malloc, &input)?;
            let input_value = parse_json(&mut store, input_buf, "input")?;

            let ctx = opa_eval_ctx_new.call(&mut store, ())?;
            opa_eval_ctx_set_input.call(&mut store, (ctx, input_value))?;

            // `data_bytes` is pre-serialized by `new`; `data` is only a fallback.
            let data_buf = match (data_bytes.as_deref(), data.as_ref()) {
                (Some(bytes), _) => {
                    Some(write_json_bytes(&memory, &mut store, &opa_malloc, bytes)?)
                }
                (None, Some(value)) => Some(write_json(&memory, &mut store, &opa_malloc, value)?),
                (None, None) => None,
            };
            if let Some(buf) = data_buf {
                // Configured data must reach the policy; never drop it silently.
                let set_data =
                    typed_func::<(i32, i32), ()>(&instance, &mut store, "opa_eval_ctx_set_data")?;
                let data_value = parse_json(&mut store, buf, "data")?;
                set_data.call(&mut store, (ctx, data_value))?;
            }

            if let Some(name) = entrypoint.as_deref() {
                // Falling back to the default entrypoint would evaluate the
                // wrong rule, so a missing setter is an error.
                let set_entrypoint = typed_func::<(i32, i32), ()>(
                    &instance,
                    &mut store,
                    "opa_eval_ctx_set_entrypoint",
                )?;
                let entrypoint_id =
                    resolve_entrypoint_id(&instance, &mut store, &memory, &opa_json_dump, name)?;
                set_entrypoint.call(&mut store, (ctx, entrypoint_id))?;
            }

            // The ABI reserves `eval`'s return value; anything but 0 is a failure.
            let status = eval_fn.call(&mut store, ctx)?;
            if status != 0 {
                return Err(anyhow!("policy evaluation failed with status {status}"));
            }

            let result_value = opa_eval_ctx_get_result.call(&mut store, ctx)?;
            let result_json = opa_json_dump.call(&mut store, result_value)?;
            let decision_json = read_json(&memory, &mut store, result_json)?;

            Ok(Decision::from_value(decision_json).unwrap_or_else(|err| {
                warn!(error = %err, "policy returned unparsable decision; defaulting to allow");
                Decision::allow()
            }))
        });

        // `?` surfaces a panicked/cancelled task; the inner `Result` is the
        // evaluation outcome and is returned as-is.
        evaluation.await.context("policy evaluation task failed")?
    }
}

/// Resolves the numeric id the compiled module assigned to entrypoint `name`.
///
/// `entrypoints()` returns the address of an OPA value rather than a string,
/// so the mapping is serialized with `opa_json_dump` before it is read.
#[cfg(feature = "wasm")]
fn resolve_entrypoint_id(
    instance: &Instance,
    store: &mut Store<HostState>,
    memory: &wasmtime::Memory,
    opa_json_dump: &TypedFunc<i32, i32>,
    name: &str,
) -> Result<i32> {
    let entrypoints = typed_func::<(), i32>(instance, store, "entrypoints")?;
    let mapping_value = entrypoints
        .call(&mut *store, ())
        .context("failed to call entrypoints function")?;
    let mapping_json = opa_json_dump.call(&mut *store, mapping_value)?;
    let mapping = read_json(memory, store, mapping_json)?;
    let id = mapping
        .get(name)
        .and_then(Value::as_i64)
        .ok_or_else(|| anyhow!("entrypoint '{name}' not defined in policy module"))?;
    i32::try_from(id).with_context(|| format!("entrypoint '{name}' has out-of-range id {id}"))
}
447
448#[cfg(not(feature = "wasm"))]
449#[derive(Clone)]
450pub struct WasmPolicyEngine;
451
452#[cfg(not(feature = "wasm"))]
453impl WasmPolicyEngine {
454    pub fn new(_policy: &PolicyModule) -> Result<Self> {
455        warn!(
456            "Wasm policy engine requested but crate built without 'wasm' feature; returning allow-all stub"
457        );
458        Ok(Self)
459    }
460}
461
462#[cfg(not(feature = "wasm"))]
463#[async_trait]
464impl PolicyEngine for WasmPolicyEngine {
465    async fn evaluate(&self, _request: &PolicyRequest) -> Result<Decision> {
466        Ok(Decision::allow())
467    }
468}
469
/// Builds a [`Linker`] with the `env` host functions OPA-compiled policy
/// modules import: `opa_abort`, `opa_println` and `opa_builtin0`..`opa_builtin4`.
///
/// No builtins are implemented: every `opa_builtinN` call is logged via
/// `log_unimplemented_builtin` and returns `0`.
#[cfg(feature = "wasm")]
fn build_linker(engine: &Engine) -> Result<Linker<HostState>> {
    let mut linker = Linker::new(engine);

    // `opa_abort(addr)`: surface the guest's message as a trap that stops evaluation.
    linker.func_wrap(
        "env",
        "opa_abort",
        |mut caller: Caller<'_, HostState>, addr: i32| -> Result<(), Trap> {
            let message = read_c_string(&mut caller, addr)
                .unwrap_or_else(|_| "<invalid abort message>".to_string());
            Err(Trap::new(format!("policy aborted: {message}")))
        },
    )?;

    // `opa_println(addr)` has no return value in the OPA ABI; declaring one
    // would make the import signature mismatch at instantiation.
    linker.func_wrap(
        "env",
        "opa_println",
        |mut caller: Caller<'_, HostState>, addr: i32| -> Result<(), Trap> {
            let message =
                read_c_string(&mut caller, addr).unwrap_or_else(|_| "<invalid utf8>".to_string());
            // `target:` sets the tracing target; `target = ...` would only add a field.
            debug!(target: "policy::opa", "{message}");
            Ok(())
        },
    )?;

    // The `opa_builtinN` imports receive the builtin id FIRST, then the
    // evaluation context, then the N argument addresses.
    linker.func_wrap(
        "env",
        "opa_builtin0",
        |_caller: Caller<'_, HostState>, builtin: i32, _ctx: i32| -> Result<i32, Trap> {
            log_unimplemented_builtin(builtin, 0);
            Ok(0)
        },
    )?;

    linker.func_wrap(
        "env",
        "opa_builtin1",
        |_caller: Caller<'_, HostState>,
         builtin: i32,
         _ctx: i32,
         _arg1: i32|
         -> Result<i32, Trap> {
            log_unimplemented_builtin(builtin, 1);
            Ok(0)
        },
    )?;

    linker.func_wrap(
        "env",
        "opa_builtin2",
        |_caller: Caller<'_, HostState>,
         builtin: i32,
         _ctx: i32,
         _arg1: i32,
         _arg2: i32|
         -> Result<i32, Trap> {
            log_unimplemented_builtin(builtin, 2);
            Ok(0)
        },
    )?;

    linker.func_wrap(
        "env",
        "opa_builtin3",
        |_caller: Caller<'_, HostState>,
         builtin: i32,
         _ctx: i32,
         _arg1: i32,
         _arg2: i32,
         _arg3: i32|
         -> Result<i32, Trap> {
            log_unimplemented_builtin(builtin, 3);
            Ok(0)
        },
    )?;

    linker.func_wrap(
        "env",
        "opa_builtin4",
        |_caller: Caller<'_, HostState>,
         builtin: i32,
         _ctx: i32,
         _arg1: i32,
         _arg2: i32,
         _arg3: i32,
         _arg4: i32|
         -> Result<i32, Trap> {
            log_unimplemented_builtin(builtin, 4);
            Ok(0)
        },
    )?;

    Ok(linker)
}
564
/// Warns that the guest called OPA builtin `builtin` (with `arity` arguments),
/// which this host does not provide.
#[cfg(feature = "wasm")]
fn log_unimplemented_builtin(builtin: i32, arity: u8) {
    let builtin_id = builtin;
    warn!(
        builtin_id,
        arity,
        "OPA builtin not implemented; returning undefined value"
    );
}
572
/// Looks up entrypoint `name` by calling `entrypoints_fn` and reading the
/// returned address as a NUL-terminated JSON object that maps names to ids.
///
/// NOTE(review): in the OPA Wasm ABI `entrypoints()` returns the address of an
/// OPA *value*, which must be serialized with `opa_json_dump` before it can be
/// read as JSON text; this helper reads the raw address directly. Confirm
/// against the modules in use before relying on it.
///
/// # Errors
///
/// Fails if the call traps, the memory at the address is not valid JSON, the
/// name is absent or not an integer, or the id does not fit in `i32`.
#[cfg(feature = "wasm")]
fn lookup_entrypoint(
    entrypoints_fn: &TypedFunc<(), i32>,
    memory: &wasmtime::Memory,
    store: &mut Store<HostState>,
    name: &str,
) -> Result<i32> {
    // Reborrow: handing `store` itself to the generic `call` would move the
    // `&mut` and make the `read_json` call below a use-after-move.
    let ptr = entrypoints_fn
        .call(&mut *store, ())
        .context("failed to call entrypoints function")?;
    let mapping = read_json(memory, store, ptr)?;
    let entry_id = mapping
        .as_object()
        .and_then(|map| map.get(name))
        .and_then(Value::as_i64)
        .ok_or_else(|| anyhow!("entrypoint '{name}' not defined in policy module"))?;
    // Reject ids that an `as` cast would silently truncate.
    i32::try_from(entry_id)
        .with_context(|| format!("entrypoint '{name}' has out-of-range id {entry_id}"))
}
591
/// Reads the NUL-terminated UTF-8 string at `addr` from the calling module's
/// exported `memory`.
///
/// NOTE(review): OPA-compiled modules may import their memory (`env.memory`)
/// instead of exporting it; for such modules this lookup fails and the host
/// functions in `build_linker` fall back to placeholder text. Confirm which
/// case applies to the modules in use.
#[cfg(feature = "wasm")]
fn read_c_string(caller: &mut Caller<'_, HostState>, addr: i32) -> Result<String, Trap> {
    let memory = match caller
        .get_export("memory")
        .and_then(|export| export.into_memory())
    {
        Some(memory) => memory,
        None => return Err(Trap::new("policy module missing memory export")),
    };

    let guest = memory.data(&caller);
    let start = addr as usize;
    // Length up to (not including) the first NUL at or after `start`.
    let len = guest
        .get(start..)
        .and_then(|tail| tail.iter().position(|&byte| byte == 0))
        .ok_or_else(|| Trap::new("unterminated string in policy memory"))?;

    std::str::from_utf8(&guest[start..start + len])
        .map(str::to_owned)
        .map_err(|_| Trap::new("policy emitted invalid utf-8"))
}
612
/// Looks up the export `name` on `instance` as a function typed `P -> R`.
///
/// # Errors
///
/// Fails, naming the export, whenever the typed lookup fails (for example the
/// export is absent or its signature differs from `P -> R`).
#[cfg(feature = "wasm")]
fn typed_func<P, R>(
    instance: &Instance,
    store: &mut Store<HostState>,
    name: &str,
) -> Result<TypedFunc<P, R>>
where
    P: wasmtime::WasmParams,
    R: wasmtime::WasmResults,
{
    let found = instance.get_typed_func::<P, R>(store, name);
    found.with_context(|| format!("policy module missing function '{name}'"))
}
627
/// Serializes `value` to JSON and copies it into guest memory, returning the
/// `(address, length)` pair produced by `write_json_bytes`.
#[cfg(feature = "wasm")]
fn write_json(
    memory: &wasmtime::Memory,
    store: &mut Store<HostState>,
    opa_malloc: &TypedFunc<i32, i32>,
    value: &Value,
) -> Result<(i32, i32)> {
    serde_json::to_vec(value)
        .map_err(anyhow::Error::from)
        .and_then(|encoded| write_json_bytes(memory, store, opa_malloc, &encoded))
}
638
/// Copies `bytes` into a fresh `opa_malloc` allocation in guest memory.
///
/// Returns `(address, length)`, the pair `opa_json_parse` takes.
///
/// # Errors
///
/// Fails if the payload is longer than `i32::MAX` bytes, if `opa_malloc` traps
/// or returns a negative address, or if the allocation does not lie inside
/// guest memory.
#[cfg(feature = "wasm")]
fn write_json_bytes(
    memory: &wasmtime::Memory,
    store: &mut Store<HostState>,
    opa_malloc: &TypedFunc<i32, i32>,
    bytes: &[u8],
) -> Result<(i32, i32)> {
    // Guest lengths are i32; an `as` cast would silently wrap oversized payloads.
    let len = i32::try_from(bytes.len()).with_context(|| {
        format!(
            "JSON payload of {} bytes is too large for policy memory",
            bytes.len()
        )
    })?;
    // Reborrow so `store` remains usable below: handing the `&mut` itself to
    // the generic `call` would move it.
    let ptr = opa_malloc.call(&mut *store, len)?;
    let start =
        usize::try_from(ptr).map_err(|_| anyhow!("opa_malloc returned invalid address {ptr}"))?;
    // Both operands are at most i32::MAX, so this cannot overflow on 32/64-bit targets.
    let end = start + bytes.len();
    memory
        .data_mut(store)
        .get_mut(start..end)
        .ok_or_else(|| anyhow!("failed to write JSON into policy memory"))?
        .copy_from_slice(bytes);
    Ok((ptr, len))
}
657
/// Reads the NUL-terminated JSON text at `addr` in guest memory and parses it.
///
/// # Errors
///
/// Fails if no NUL terminator follows `addr` inside memory, if the bytes are
/// not UTF-8, or if they are not valid JSON.
#[cfg(feature = "wasm")]
fn read_json(memory: &wasmtime::Memory, store: &mut Store<HostState>, addr: i32) -> Result<Value> {
    let guest = memory.data(store);
    let start = addr as usize;
    // Length up to (not including) the first NUL at or after `start`.
    let terminator = guest
        .get(start..)
        .and_then(|tail| tail.iter().position(|&byte| byte == 0));
    let len = match terminator {
        Some(len) => len,
        None => return Err(anyhow!("unterminated JSON string emitted by policy")),
    };

    let text = std::str::from_utf8(&guest[start..start + len])?;
    Ok(serde_json::from_str(text)?)
}