fleetforge_signer_aws_kms/
lib.rs

1use anyhow::{anyhow, Context, Result};
2use aws_config::meta::region::RegionProviderChain;
3use aws_sdk_kms::{config::Region, primitives::Blob, types::SigningAlgorithmSpec, Client};
4use fleetforge_trust::{
5    jwk_from_aws_public_key, normalize_ecdsa_signature, Jwk, SignatureEnvelope, Signer,
6    SigningAlgorithm,
7};
8use once_cell::sync::OnceCell;
9use std::sync::{Arc, Mutex};
10use tokio::runtime::Runtime;
11
12pub struct AwsKmsSdkSignerConfig {
13    pub key_id: String,
14    pub algorithm: SigningAlgorithm,
15    pub region: Option<String>,
16    pub profile: Option<String>,
17    pub public_key: Option<Jwk>,
18}
19
20impl AwsKmsSdkSignerConfig {
21    pub fn kid(&self) -> String {
22        format!("aws-kms:{}", self.key_id)
23    }
24}
25
26pub struct AwsKmsSdkSigner {
27    client: Client,
28    runtime: Arc<Mutex<Runtime>>,
29    key_id: String,
30    kid: String,
31    algorithm: SigningAlgorithm,
32    aws_algorithm: SigningAlgorithmSpec,
33    public_key: Option<Jwk>,
34    fetched_public_key: OnceCell<Jwk>,
35}
36
37impl AwsKmsSdkSigner {
38    pub fn new(config: AwsKmsSdkSignerConfig) -> Result<Self> {
39        let runtime = Arc::new(Mutex::new(
40            Runtime::new().context("failed to create tokio runtime")?,
41        ));
42        let aws_config = {
43            let mut loader = aws_config::from_env();
44            if let Some(profile) = config.profile.clone() {
45                loader = loader.profile_name(profile);
46            }
47            if let Some(region) = config.region.clone() {
48                loader = loader.region(Region::new(region));
49            } else {
50                let provider = RegionProviderChain::default_provider();
51                loader = loader.region(provider);
52            }
53            runtime
54                .lock()
55                .expect("runtime poisoned")
56                .block_on(loader.load())
57        };
58        let client = Client::new(&aws_config);
59        let aws_algorithm = map_aws_algorithm(&config.algorithm)?;
60        Ok(Self {
61            client,
62            runtime,
63            key_id: config.key_id.clone(),
64            kid: config.kid(),
65            algorithm: config.algorithm,
66            aws_algorithm,
67            public_key: config.public_key,
68            fetched_public_key: OnceCell::new(),
69        })
70    }
71
72    fn block_on<F, T>(&self, fut: F) -> T
73    where
74        F: std::future::Future<Output = T>,
75    {
76        self.runtime.lock().expect("runtime poisoned").block_on(fut)
77    }
78
79    fn fetch_public_key(&self) -> Result<Jwk> {
80        let resp = self
81            .block_on(self.client.get_public_key().key_id(&self.key_id).send())
82            .context("aws kms get_public_key failed")?;
83        let der = resp
84            .public_key()
85            .ok_or_else(|| anyhow!("AWS KMS did not return a public key"))?
86            .as_ref()
87            .to_vec();
88        let key_spec = resp
89            .key_spec()
90            .map(|spec| spec.as_str().to_string())
91            .unwrap_or_else(|| "".into());
92        jwk_from_aws_public_key(&der, &key_spec, &self.kid)
93    }
94}
95
96impl Signer for AwsKmsSdkSigner {
97    fn algorithm(&self) -> SigningAlgorithm {
98        self.algorithm.clone()
99    }
100
101    fn key_id(&self) -> &str {
102        &self.kid
103    }
104
105    fn public_key_jwk(&self) -> Result<Jwk> {
106        if let Some(jwk) = &self.public_key {
107            return Ok(jwk.clone());
108        }
109        if let Some(jwk) = self.fetched_public_key.get() {
110            return Ok(jwk.clone());
111        }
112        let jwk = self.fetch_public_key()?;
113        let _ = self.fetched_public_key.set(jwk.clone());
114        Ok(jwk)
115    }
116
117    fn sign(&self, payload: &[u8]) -> Result<SignatureEnvelope> {
118        let resp = self
119            .block_on(
120                self.client
121                    .sign()
122                    .key_id(&self.key_id)
123                    .message(Blob::new(payload.to_vec()))
124                    .message_type(aws_sdk_kms::types::MessageType::Raw)
125                    .signing_algorithm(self.aws_algorithm.clone())
126                    .send(),
127            )
128            .context("aws kms sign failed")?;
129        let mut signature = resp
130            .signature()
131            .map(|blob| blob.as_ref().to_vec())
132            .ok_or_else(|| anyhow!("aws kms sign did not return a signature"))?;
133        signature = normalize_ecdsa_signature(&signature, &self.algorithm)?;
134        Ok(SignatureEnvelope {
135            algorithm: self.algorithm(),
136            signature,
137            key_id: format!("aws-kms:{}", resp.key_id().unwrap_or(&self.key_id)),
138            public_key: Some(self.public_key_jwk()?),
139        })
140    }
141
142    fn verify(&self, payload: &[u8], signature: &[u8]) -> Result<bool> {
143        let resp = self
144            .block_on(
145                self.client
146                    .verify()
147                    .key_id(&self.key_id)
148                    .message(Blob::new(payload.to_vec()))
149                    .message_type(aws_sdk_kms::types::MessageType::Raw)
150                    .signature(Blob::new(signature.to_vec()))
151                    .signing_algorithm(self.aws_algorithm.clone())
152                    .send(),
153            )
154            .context("aws kms verify failed")?;
155        Ok(resp.signature_valid())
156    }
157}
158
159fn map_aws_algorithm(alg: &SigningAlgorithm) -> Result<SigningAlgorithmSpec> {
160    match alg.as_str() {
161        "ES256" => Ok(SigningAlgorithmSpec::EcdsaSha256),
162        "ES384" => Ok(SigningAlgorithmSpec::EcdsaSha384),
163        "PS256" => Ok(SigningAlgorithmSpec::RsassaPssSha256),
164        "PS384" => Ok(SigningAlgorithmSpec::RsassaPssSha384),
165        "RS256" => Ok(SigningAlgorithmSpec::RsassaPkcs1V15Sha256),
166        other => Err(anyhow!(
167            "unsupported AWS signing algorithm '{}' (supported: ES256, ES384, PS256, PS384, RS256)",
168            other
169        )),
170    }
171}