fleetforge_runtime/
memory.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use async_trait::async_trait;
6use fleetforge_trust::{Trust, TrustBoundary, TrustOrigin, TrustSource};
7use serde_json::{json, Value};
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use crate::guardrails::PolicyBundle;
12use crate::model::RunId;
13use crate::policy::DecisionEffect;
14
15/// Shared contract for run-scoped ephemeral state.
16#[async_trait]
17pub trait MemoryAdapter: Send + Sync {
18    async fn put(
19        &self,
20        run_id: RunId,
21        namespace: &str,
22        key: &str,
23        record: MemoryRecord,
24    ) -> Result<()>;
25    async fn get(&self, run_id: RunId, namespace: &str, key: &str) -> Result<Option<MemoryRecord>>;
26    async fn delete(&self, run_id: RunId, namespace: &str, key: &str) -> Result<()>;
27    async fn list(&self, run_id: RunId, namespace: &str) -> Result<Vec<String>>;
28}
29
30/// Optional extension for vector-aware memory backends (e.g., embeddings DB).
31#[async_trait]
32pub trait VectorMemoryAdapter: MemoryAdapter {
33    async fn upsert_vector(
34        &self,
35        run_id: RunId,
36        namespace: &str,
37        key: &str,
38        embedding: Vec<f32>,
39        metadata: Value,
40    ) -> Result<()>;
41
42    async fn query_vector(
43        &self,
44        run_id: RunId,
45        namespace: &str,
46        embedding: &[f32],
47        top_k: usize,
48    ) -> Result<Vec<VectorMatch>>;
49}
50
51/// Match returned by vector queries.
52#[derive(Debug, Clone)]
53pub struct VectorMatch {
54    pub key: String,
55    pub score: f32,
56    pub metadata: Value,
57}
58
59/// Stored memory payload with trust metadata.
60#[derive(Debug, Clone)]
61pub struct MemoryRecord {
62    pub value: Value,
63    pub trust: Trust,
64    pub trust_origin: Option<TrustOrigin>,
65}
66
67impl MemoryRecord {
68    pub fn new(value: Value) -> Self {
69        Self {
70            value,
71            trust: Trust::Untrusted,
72            trust_origin: None,
73        }
74    }
75
76    pub fn with_trust(mut self, trust: Trust) -> Self {
77        self.trust = trust;
78        self
79    }
80
81    pub fn with_origin(mut self, origin: TrustOrigin) -> Self {
82        self.trust_origin = Some(origin);
83        self
84    }
85}
86
87fn memory_source(namespace: &str, key: &str) -> TrustSource {
88    TrustSource::Memory {
89        namespace: namespace.to_string(),
90        key: key.to_string(),
91    }
92}
93
94fn ensure_memory_origin(
95    record: &mut MemoryRecord,
96    boundary: TrustBoundary,
97    run_id: RunId,
98    namespace: &str,
99    key: &str,
100    derived: bool,
101) {
102    let mut origin = record
103        .trust_origin
104        .clone()
105        .unwrap_or_else(|| TrustOrigin::new(boundary.clone()));
106    origin.boundary = boundary;
107    if origin.run_id.is_none() {
108        origin.run_id = Some(Uuid::from(run_id));
109    }
110    if origin.source.is_none() {
111        origin.source = Some(memory_source(namespace, key));
112    }
113    record.trust_origin = Some(origin.clone());
114    if derived {
115        record.trust = Trust::derived(origin);
116    }
117}
118
119/// Simple process-local memory adapter backed by a `RwLock<HashMap<...>>`.
120#[derive(Clone)]
121pub struct InMemoryAdapter {
122    store: Arc<RwLock<HashMap<(RunId, String, String), MemoryRecord>>>,
123    guardrails: Arc<PolicyBundle>,
124}
125
126#[async_trait]
127impl MemoryAdapter for InMemoryAdapter {
128    async fn put(
129        &self,
130        run_id: RunId,
131        namespace: &str,
132        key: &str,
133        record: MemoryRecord,
134    ) -> Result<()> {
135        let mut entry = record;
136        if !self.guardrails.is_empty() {
137            let payload = json!({
138                "action": "memory_write",
139                "namespace": namespace,
140                "key": key,
141                "value": entry.value.clone(),
142                "trust": entry.trust.clone(),
143            });
144            let outcome = self
145                .guardrails
146                .evaluate(TrustBoundary::IngressMemory, Some(run_id), None, payload)
147                .await?;
148            match outcome.effect {
149                DecisionEffect::Allow => {
150                    ensure_memory_origin(
151                        &mut entry,
152                        TrustBoundary::IngressMemory,
153                        run_id,
154                        namespace,
155                        key,
156                        false,
157                    );
158                }
159                DecisionEffect::Deny => {
160                    return Err(anyhow!(
161                        "guardrails denied memory write: {}",
162                        outcome.summary()
163                    ));
164                }
165                DecisionEffect::Redact => {
166                    if let Some(value) = outcome.value.get("value") {
167                        entry.value = value.clone();
168                    }
169                    ensure_memory_origin(
170                        &mut entry,
171                        TrustBoundary::IngressMemory,
172                        run_id,
173                        namespace,
174                        key,
175                        true,
176                    );
177                }
178            }
179        }
180        if entry.trust_origin.is_none() {
181            ensure_memory_origin(
182                &mut entry,
183                TrustBoundary::IngressMemory,
184                run_id,
185                namespace,
186                key,
187                matches!(entry.trust, Trust::Derived { .. }),
188            );
189        }
190        let mut guard = self.store.write().await;
191        guard.insert((run_id, namespace.to_owned(), key.to_owned()), entry);
192        if !self.guardrails.is_empty() {
193            self.guardrails.drain_events();
194        }
195        Ok(())
196    }
197
198    async fn get(&self, run_id: RunId, namespace: &str, key: &str) -> Result<Option<MemoryRecord>> {
199        let guard = self.store.read().await;
200        let record = guard
201            .get(&(run_id, namespace.to_owned(), key.to_owned()))
202            .cloned();
203        drop(guard);
204
205        let Some(mut record) = record else {
206            return Ok(None);
207        };
208
209        if matches!(record.trust, Trust::Trusted) {
210            ensure_memory_origin(
211                &mut record,
212                TrustBoundary::EgressMemory,
213                run_id,
214                namespace,
215                key,
216                false,
217            );
218            return Ok(Some(record));
219        }
220
221        if !self.guardrails.is_empty() {
222            let payload = json!({
223                "action": "memory_read",
224                "namespace": namespace,
225                "key": key,
226                "value": record.value.clone(),
227                "trust": record.trust.clone(),
228            });
229            let outcome = self
230                .guardrails
231                .evaluate(TrustBoundary::EgressMemory, Some(run_id), None, payload)
232                .await?;
233            match outcome.effect {
234                DecisionEffect::Allow => {
235                    ensure_memory_origin(
236                        &mut record,
237                        TrustBoundary::EgressMemory,
238                        run_id,
239                        namespace,
240                        key,
241                        false,
242                    );
243                }
244                DecisionEffect::Deny => {
245                    return Err(anyhow!(
246                        "guardrails denied memory read: {}",
247                        outcome.summary()
248                    ));
249                }
250                DecisionEffect::Redact => {
251                    if let Some(value) = outcome.value.get("value") {
252                        record.value = value.clone();
253                    }
254                    ensure_memory_origin(
255                        &mut record,
256                        TrustBoundary::EgressMemory,
257                        run_id,
258                        namespace,
259                        key,
260                        true,
261                    );
262                }
263            }
264        } else {
265            ensure_memory_origin(
266                &mut record,
267                TrustBoundary::EgressMemory,
268                run_id,
269                namespace,
270                key,
271                matches!(record.trust, Trust::Derived { .. }),
272            );
273        }
274
275        if !self.guardrails.is_empty() {
276            self.guardrails.drain_events();
277        }
278
279        Ok(Some(record))
280    }
281
282    async fn delete(&self, run_id: RunId, namespace: &str, key: &str) -> Result<()> {
283        let mut guard = self.store.write().await;
284        guard.remove(&(run_id, namespace.to_owned(), key.to_owned()));
285        Ok(())
286    }
287
288    async fn list(&self, run_id: RunId, namespace: &str) -> Result<Vec<String>> {
289        let guard = self.store.read().await;
290        let keys = guard
291            .keys()
292            .filter(|(id, ns, _)| *id == run_id && ns == namespace)
293            .map(|(_, _, key)| key.clone())
294            .collect();
295        Ok(keys)
296    }
297}
298
299impl InMemoryAdapter {
300    /// Creates a new adapter that optionally enforces the supplied guardrail policies.
301    pub fn new(guardrails: Arc<PolicyBundle>) -> Self {
302        Self {
303            store: Arc::new(RwLock::new(HashMap::new())),
304            guardrails,
305        }
306    }
307}
308
309impl Default for InMemoryAdapter {
310    fn default() -> Self {
311        Self::new(Arc::new(PolicyBundle::empty()))
312    }
313}
314
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use serde_json::json;
    use uuid::Uuid;

    use super::*;
    use crate::policy::{PolicyDecision, PolicyEngine, PolicyRequest};

    /// Fresh run id so each test operates in its own keyspace.
    fn new_run() -> RunId {
        RunId(Uuid::new_v4())
    }

    #[tokio::test]
    async fn in_memory_adapter_round_trip() {
        let memory = InMemoryAdapter::default();
        let run = new_run();

        memory
            .put(run, "session", "foo", MemoryRecord::new(json!({"value": 1})))
            .await
            .unwrap();

        let stored = memory
            .get(run, "session", "foo")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(stored.value["value"], 1);

        let listed = memory.list(run, "session").await.unwrap();
        assert_eq!(listed, vec!["foo".to_string()]);

        memory.delete(run, "session", "foo").await.unwrap();
        let after_delete = memory.get(run, "session", "foo").await.unwrap();
        assert!(after_delete.is_none());
    }

    /// Policy engine that allows everything and counts how often it is consulted.
    struct CountingPolicy(Arc<AtomicUsize>);

    #[async_trait]
    impl PolicyEngine for CountingPolicy {
        async fn evaluate(&self, _request: &PolicyRequest) -> Result<PolicyDecision> {
            self.0.fetch_add(1, Ordering::SeqCst);
            Ok(PolicyDecision::allow())
        }
    }

    #[tokio::test]
    async fn trusted_memory_read_skips_guardrails() {
        let calls = Arc::new(AtomicUsize::new(0));
        let bundle = PolicyBundle::new(vec![Arc::new(CountingPolicy(Arc::clone(&calls)))]);
        let memory = InMemoryAdapter::new(Arc::new(bundle));
        let run = new_run();

        let trusted = MemoryRecord::new(Value::String("hello".into())).with_trust(Trust::Trusted);
        memory.put(run, "session", "trusted", trusted).await.unwrap();

        let calls_before_read = calls.load(Ordering::SeqCst);
        let fetched = memory
            .get(run, "session", "trusted")
            .await
            .unwrap()
            .unwrap();
        let calls_after_read = calls.load(Ordering::SeqCst);

        assert_eq!(
            calls_before_read, calls_after_read,
            "guardrails should not run on trusted read"
        );
        assert!(matches!(fetched.trust, Trust::Trusted));
        let origin = fetched
            .trust_origin
            .expect("trusted record should have origin");
        assert_eq!(origin.boundary, TrustBoundary::EgressMemory);
    }

    /// Policy engine that redacts every request by replacing its `/value` field.
    struct RedactPolicy;

    #[async_trait]
    impl PolicyEngine for RedactPolicy {
        async fn evaluate(&self, _request: &PolicyRequest) -> Result<PolicyDecision> {
            let patch = json!({
                "op": "replace",
                "path": "/value",
                "value": "[memory-redacted]"
            });
            Ok(PolicyDecision {
                effect: DecisionEffect::Redact,
                reason: Some("memory_redacted".into()),
                patches: vec![patch],
            })
        }
    }

    #[tokio::test]
    async fn redacted_memory_sets_derived_origin() {
        let bundle = PolicyBundle::new(vec![Arc::new(RedactPolicy)]);
        let memory = InMemoryAdapter::new(Arc::new(bundle));
        let run = new_run();

        let secret = MemoryRecord::new(Value::String("secret".into()));
        memory.put(run, "session", "note", secret).await.unwrap();

        let record = memory.get(run, "session", "note").await.unwrap().unwrap();
        assert_eq!(record.value, Value::String("[memory-redacted]".into()));

        let Trust::Derived { ref origin } = record.trust else {
            panic!("expected derived trust after redaction, got {:?}", record.trust);
        };
        assert_eq!(origin.boundary, TrustBoundary::EgressMemory);
        assert!(matches!(origin.source, Some(TrustSource::Memory { .. })));
    }
}