fleetforge_signer_azure_kv/
lib.rs

1use anyhow::{anyhow, Context, Result};
2use azure_identity::DefaultAzureCredential;
3use base64::engine::general_purpose::STANDARD as BASE64;
4use base64::Engine as _;
5use fleetforge_trust::{
6    digest_for_algorithm, normalize_ecdsa_signature, Jwk, SignatureEnvelope, Signer,
7    SigningAlgorithm,
8};
9use once_cell::sync::OnceCell;
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use std::sync::{Arc, Mutex};
13use tokio::runtime::Runtime;
14
/// Token scope requested from the Azure credential before each Key Vault request.
const SCOPE: &str = "https://vault.azure.net/.default";
16
/// Configuration for [`AzureKvSdkSigner`].
#[derive(Clone, Debug)]
pub struct AzureKvSdkSignerConfig {
    /// Key Vault key identifier. The signer appends REST paths (`/sign`,
    /// `/verify`) and an `api-version` query to it, so it is presumably the
    /// full key URL (`https://<vault>.vault.azure.net/keys/<name>/<version>`)
    /// — TODO confirm with callers.
    pub key_id: String,
    /// Signing algorithm; its `as_str()` form is sent to Key Vault as `alg`.
    pub algorithm: SigningAlgorithm,
    /// Optional pre-supplied public key. When set, the signer returns it
    /// instead of fetching the key from Key Vault.
    pub public_key: Option<Jwk>,
}
23
24impl AzureKvSdkSignerConfig {
25    pub fn kid(&self) -> String {
26        format!("azure-kv:{}", self.key_id)
27    }
28}
29
/// A [`Signer`] backed by an Azure Key Vault key, accessed through the Key
/// Vault REST API with a `DefaultAzureCredential` bearer token.
///
/// The [`Signer`] methods are synchronous; each remote call is driven to
/// completion on a tokio runtime owned by the signer.
pub struct AzureKvSdkSigner {
    /// HTTP client used for all Key Vault requests.
    client: Client,
    /// Credential used to obtain an access token before each request.
    credential: Arc<DefaultAzureCredential>,
    /// Private runtime used by `block_on`. The lock is held for the whole
    /// `block_on` call, so concurrent calls on one signer are serialized.
    runtime: Arc<Mutex<Runtime>>,
    /// Key Vault key identifier used as the base of every request URL.
    key_id: String,
    /// `azure-kv:`-prefixed identifier reported via `Signer::key_id` and
    /// stamped into signature envelopes.
    kid: String,
    /// Algorithm used to compute digests and reported to callers.
    algorithm: SigningAlgorithm,
    /// Algorithm name sent to Key Vault as `alg` (taken from `algorithm.as_str()`).
    azure_algorithm: String,
    /// Public key supplied in the config; takes precedence over a fetched key.
    public_key: Option<Jwk>,
    /// Public key fetched from Key Vault, cached after the first successful fetch.
    fetched_public_key: OnceCell<Jwk>,
}
41
42impl AzureKvSdkSigner {
43    pub fn new(config: AzureKvSdkSignerConfig) -> Result<Self> {
44        let runtime = Arc::new(Mutex::new(
45            Runtime::new().context("failed to create tokio runtime")?,
46        ));
47        let credential = Arc::new(
48            DefaultAzureCredential::create()
49                .context("failed to create Azure DefaultAzureCredential")?,
50        );
51        let client = Client::builder()
52            .build()
53            .context("failed to build reqwest client")?;
54        Ok(Self {
55            client,
56            credential,
57            runtime,
58            key_id: config.key_id.clone(),
59            kid: config.kid(),
60            algorithm: config.algorithm,
61            azure_algorithm: config.algorithm.as_str().to_string(),
62            public_key: config.public_key,
63            fetched_public_key: OnceCell::new(),
64        })
65    }
66
67    fn block_on<F, T>(&self, fut: F) -> T
68    where
69        F: std::future::Future<Output = T>,
70    {
71        self.runtime.lock().expect("runtime poisoned").block_on(fut)
72    }
73
74    fn auth_header(&self) -> Result<String> {
75        let credential = self.credential.clone();
76        let token = self
77            .block_on(async move { credential.get_token(SCOPE).await })
78            .context("failed to fetch Azure access token")?;
79        Ok(format!("Bearer {}", token.token.secret()))
80    }
81
82    fn fetch_public_key(&self) -> Result<Jwk> {
83        #[derive(Deserialize)]
84        struct KeyResponse {
85            key: Jwk,
86        }
87        let url = format!("{}?api-version=7.4", self.key_id);
88        let token = self.auth_header()?;
89        let resp = self.block_on(async {
90            let resp = self
91                .client
92                .get(&url)
93                .header("Authorization", &token)
94                .send()
95                .await
96                .context("azure key fetch request failed")?;
97            if !resp.status().is_success() {
98                let body = resp.text().await.unwrap_or_default();
99                return Err(anyhow!("azure key fetch failed: {}", body));
100            }
101            resp.json::<KeyResponse>()
102                .await
103                .context("failed to parse azure key response")
104        })?;
105        Ok(resp.key)
106    }
107
108    fn azure_sign(&self, payload: &[u8]) -> Result<Vec<u8>> {
109        #[derive(Serialize)]
110        struct SignRequest<'a> {
111            alg: &'a str,
112            value: String,
113        }
114        #[derive(Deserialize)]
115        struct SignResponse {
116            value: String,
117        }
118        let digest = digest_for_algorithm(&self.algorithm, payload)?;
119        let url = format!("{}/sign?api-version=7.4", self.key_id);
120        let token = self.auth_header()?;
121        let body = SignRequest {
122            alg: &self.azure_algorithm,
123            value: BASE64.encode(digest),
124        };
125        let resp = self.block_on(async {
126            let resp = self
127                .client
128                .post(&url)
129                .header("Authorization", &token)
130                .json(&body)
131                .send()
132                .await
133                .context("azure sign request failed")?;
134            if !resp.status().is_success() {
135                let body = resp.text().await.unwrap_or_default();
136                return Err(anyhow!("azure sign failed: {}", body));
137            }
138            resp.json::<SignResponse>()
139                .await
140                .context("failed to parse azure sign response")
141        })?;
142        BASE64
143            .decode(resp.value.as_bytes())
144            .context("invalid azure signature output")
145    }
146
147    fn azure_verify(&self, payload: &[u8], signature: &[u8]) -> Result<bool> {
148        #[derive(Serialize)]
149        struct VerifyRequest<'a> {
150            alg: &'a str,
151            digest: String,
152            value: String,
153        }
154        #[derive(Deserialize)]
155        struct VerifyResponse {
156            value: bool,
157        }
158        let digest = digest_for_algorithm(&self.algorithm, payload)?;
159        let url = format!("{}/verify?api-version=7.4", self.key_id);
160        let token = self.auth_header()?;
161        let body = VerifyRequest {
162            alg: &self.azure_algorithm,
163            digest: BASE64.encode(digest),
164            value: BASE64.encode(signature),
165        };
166        let verified = self.block_on(async {
167            let resp = self
168                .client
169                .post(&url)
170                .header("Authorization", &token)
171                .json(&body)
172                .send()
173                .await
174                .context("azure verify request failed")?;
175            if !resp.status().is_success() {
176                let body = resp.text().await.unwrap_or_default();
177                return Err(anyhow!("azure verify failed: {}", body));
178            }
179            let parsed = resp
180                .json::<VerifyResponse>()
181                .await
182                .context("failed to parse azure verify response")?;
183            Ok(parsed.value)
184        })?;
185        Ok(verified)
186    }
187}
188
189impl Signer for AzureKvSdkSigner {
190    fn algorithm(&self) -> SigningAlgorithm {
191        self.algorithm.clone()
192    }
193
194    fn key_id(&self) -> &str {
195        &self.kid
196    }
197
198    fn public_key_jwk(&self) -> Result<Jwk> {
199        if let Some(jwk) = &self.public_key {
200            return Ok(jwk.clone());
201        }
202        if let Some(jwk) = self.fetched_public_key.get() {
203            return Ok(jwk.clone());
204        }
205        let jwk = self.fetch_public_key()?;
206        let _ = self.fetched_public_key.set(jwk.clone());
207        Ok(jwk)
208    }
209
210    fn sign(&self, payload: &[u8]) -> Result<SignatureEnvelope> {
211        let mut signature = self.azure_sign(payload)?;
212        signature = normalize_ecdsa_signature(&signature, &self.algorithm)?;
213        Ok(SignatureEnvelope {
214            algorithm: self.algorithm(),
215            signature,
216            key_id: self.kid.clone(),
217            public_key: Some(self.public_key_jwk()?),
218        })
219    }
220
221    fn verify(&self, payload: &[u8], signature: &[u8]) -> Result<bool> {
222        self.azure_verify(payload, signature)
223    }
224}