1use std::collections::BTreeMap;
2
3use anyhow::{Context, Result};
4use futures::stream::TryStreamExt;
5use serde_json::{Map, Value};
6use sha2::{Digest, Sha256};
7use sqlx::{types::Json, Executor, PgPool, Row};
8use uuid::Uuid;
9
10pub trait PgExec<'a>: Executor<'a, Database = sqlx::Postgres> {}
11impl<'a, T> PgExec<'a> for T where T: Executor<'a, Database = sqlx::Postgres> {}
12
13fn canonicalize(value: &Value) -> Value {
14 match value {
15 Value::Object(map) => {
16 let mut ordered = BTreeMap::new();
17 for (key, val) in map {
18 ordered.insert(key.clone(), canonicalize(val));
19 }
20 let mut canonical = Map::with_capacity(ordered.len());
21 for (key, value) in ordered {
22 canonical.insert(key, value);
23 }
24 Value::Object(canonical)
25 }
26 Value::Array(items) => Value::Array(items.iter().map(canonicalize).collect()),
27 _ => value.clone(),
28 }
29}
30
31fn canonical_bytes(value: &Value) -> Result<Vec<u8>> {
32 let canonical = canonicalize(value);
33 serde_json::to_vec(&canonical).context("failed to serialise audit resource")
34}
35
36async fn fetch_prev_hash<'a, E>(exec: E) -> Result<Option<Vec<u8>>>
37where
38 E: PgExec<'a>,
39{
40 let prev = sqlx::query_scalar(
41 r#"
42 select hash
43 from audit_log
44 order by id desc
45 limit 1
46 for update
47 "#,
48 )
49 .fetch_optional(exec)
50 .await?;
51
52 Ok(prev)
53}
54
55#[allow(clippy::too_many_arguments)]
56async fn insert_audit_row<'a, E>(
57 exec: E,
58 actor: &str,
59 tenant_id: Option<Uuid>,
60 action: &str,
61 resource: &Value,
62 prev_hash: Option<Vec<u8>>,
63 hash: Vec<u8>,
64) -> Result<i64>
65where
66 E: PgExec<'a>,
67{
68 let id = sqlx::query_scalar(
69 r#"
70 insert into audit_log (actor, tenant_id, action, resource, prev_hash, hash)
71 values ($1, $2, $3, $4, $5, $6)
72 returning id
73 "#,
74 )
75 .bind(actor)
76 .bind(tenant_id)
77 .bind(action)
78 .bind(Json(resource.clone()))
79 .bind(prev_hash)
80 .bind(hash)
81 .fetch_one(exec)
82 .await?;
83
84 Ok(id)
85}
86
87#[derive(Clone)]
88pub struct Audit {
89 pool: PgPool,
90}
91
92impl Audit {
93 pub fn new(pool: PgPool) -> Self {
94 Self { pool }
95 }
96
97 pub fn pool(&self) -> &PgPool {
98 &self.pool
99 }
100
101 async fn append_inner<'a>(
102 &self,
103 mut tx: sqlx::Transaction<'a, sqlx::Postgres>,
104 actor: &str,
105 tenant_id: Option<Uuid>,
106 action: &str,
107 resource: &Value,
108 ) -> Result<(sqlx::Transaction<'a, sqlx::Postgres>, i64)> {
109 let prev = fetch_prev_hash(tx.as_mut()).await?;
110
111 let mut hasher = Sha256::new();
112 if let Some(prev_hash) = &prev {
113 hasher.update(prev_hash);
114 }
115 let bytes = canonical_bytes(resource)?;
116 hasher.update(&bytes);
117 let hash = hasher.finalize().to_vec();
118
119 let id =
120 insert_audit_row(tx.as_mut(), actor, tenant_id, action, resource, prev, hash).await?;
121
122 Ok((tx, id))
123 }
124
125 pub async fn append(
126 &self,
127 actor: &str,
128 tenant_id: Option<Uuid>,
129 action: &str,
130 resource: &Value,
131 ) -> Result<i64> {
132 let tx = self.pool.begin().await?;
133 let (tx, id) = self
134 .append_inner(tx, actor, tenant_id, action, resource)
135 .await?;
136 tx.commit().await?;
137 Ok(id)
138 }
139
140 pub async fn verify_chain(&self) -> Result<bool> {
141 let mut rows = sqlx::query(
142 r#"
143 select id, resource, prev_hash, hash
144 from audit_log
145 order by id asc
146 "#,
147 )
148 .fetch(&self.pool);
149
150 let mut expected_prev: Option<Vec<u8>> = None;
151
152 while let Some(row) = rows.try_next().await? {
153 let id: i64 = row.get("id");
154 let stored_prev: Option<Vec<u8>> = row.get("prev_hash");
155 let stored_hash: Vec<u8> = row.get("hash");
156 let resource: Json<Value> = row.get("resource");
157
158 if stored_prev.as_ref() != expected_prev.as_ref() {
159 tracing::warn!(id, "audit chain previous hash mismatch");
160 return Ok(false);
161 }
162
163 let mut hasher = Sha256::new();
164 if let Some(prev_hash) = &expected_prev {
165 hasher.update(prev_hash);
166 }
167 let bytes = canonical_bytes(&resource.0)?;
168 hasher.update(&bytes);
169 let expected_hash = hasher.finalize().to_vec();
170
171 if expected_hash != stored_hash {
172 tracing::warn!(id, "audit chain hash mismatch");
173 return Ok(false);
174 }
175
176 expected_prev = Some(stored_hash);
177 }
178
179 Ok(true)
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use sqlx::postgres::PgPoolOptions;
    use testcontainers::clients::Cli;
    use testcontainers::images::postgres::Postgres;
    use uuid::Uuid;

    /// Heuristic: Docker counts as reachable if the default socket exists or `DOCKER_HOST` is set.
    fn docker_available() -> bool {
        let has_default_socket = std::fs::metadata("/var/run/docker.sock").is_ok();
        has_default_socket || std::env::var("DOCKER_HOST").is_ok()
    }

    /// Creates the `audit_log` table and its `(tenant_id, ts)` index unless they already exist.
    async fn create_schema(pool: &PgPool) -> Result<()> {
        const CREATE_TABLE: &str = r#"
            create table if not exists audit_log (
                id bigserial primary key,
                ts timestamptz not null default now(),
                actor text not null,
                tenant_id uuid,
                action text not null,
                resource jsonb not null,
                prev_hash bytea,
                hash bytea not null
            )
            "#;
        const CREATE_INDEX: &str = r#"
            create index if not exists idx_audit_log_tenant_ts
                on audit_log (tenant_id, ts)
            "#;

        for ddl in [CREATE_TABLE, CREATE_INDEX] {
            sqlx::query(ddl).execute(pool).await?;
        }
        Ok(())
    }

    #[tokio::test]
    async fn appends_and_verifies_chain() -> Result<()> {
        if !docker_available() {
            eprintln!("Skipping audit tests because Docker is unavailable");
            return Ok(());
        }

        let docker = Cli::default();
        let container = docker.run(Postgres::default());
        let port = container.get_host_port_ipv4(5432);
        let database_url = format!("postgres://postgres:postgres@127.0.0.1:{port}/postgres");

        let pool = PgPoolOptions::new()
            .max_connections(2)
            .acquire_timeout(std::time::Duration::from_secs(5))
            .connect(&database_url)
            .await?;
        create_schema(&pool).await?;

        let audit = Audit::new(pool.clone());
        let tenant = Uuid::new_v4();

        // Three sequential events for one run; they become rows 1, 2 and 3 of the chain.
        let events = [
            ("run.submit", json!({ "run_id": "1" })),
            ("step.started", json!({ "run_id": "1", "step_id": "a" })),
            ("step.succeeded", json!({ "run_id": "1", "step_id": "a" })),
        ];
        for (action, resource) in events {
            audit
                .append("tester", Some(tenant), action, &resource)
                .await?;
        }

        assert!(audit.verify_chain().await?);

        // Corrupt the middle row's stored hash; verification must now detect the break.
        sqlx::query("update audit_log set hash = $1::bytea where id = $2")
            .bind(vec![0u8; 32])
            .bind(2_i64)
            .execute(&pool)
            .await?;

        assert!(!audit.verify_chain().await?);

        // The container borrows `docker`, so tear it down explicitly before returning.
        drop(container);
        Ok(())
    }
}