From 3888bab58a2a7ccc34ece073ee7862027bbbcbd6 Mon Sep 17 00:00:00 2001 From: Eugene Yakubovich Date: Thu, 11 Dec 2025 10:27:14 -0800 Subject: [PATCH] feat: add Jwk::from_decoding_key Allow for constructing a Jwk from a decoding key. This allows it to be created from a DER encoded file, for example. This patch renames JwkUtils to KeyUtils. --- src/crypto/aws_lc/mod.rs | 22 ++++-- src/crypto/mod.rs | 51 ++++++++++--- src/crypto/rust_crypto/mod.rs | 26 +++++-- src/jwk.rs | 140 ++++++++++++++++++++++++++++------ 4 files changed, 192 insertions(+), 47 deletions(-) diff --git a/src/crypto/aws_lc/mod.rs b/src/crypto/aws_lc/mod.rs index 8787887f..6691992a 100644 --- a/src/crypto/aws_lc/mod.rs +++ b/src/crypto/aws_lc/mod.rs @@ -8,7 +8,7 @@ use aws_lc_rs::{ use crate::{ Algorithm, DecodingKey, EncodingKey, - crypto::{CryptoProvider, JwkUtils, JwtSigner, JwtVerifier}, + crypto::{CryptoProvider, JwtSigner, JwtVerifier, KeyUtils}, errors::{self, Error, ErrorKind}, jwk::{EllipticCurve, ThumbprintHash}, }; @@ -18,7 +18,7 @@ mod eddsa; mod hmac; mod rsa; -fn extract_rsa_public_key_components(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { +fn rsa_components_from_private_key(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { let key_pair = aws_sig::RsaKeyPair::from_der(key_content) .map_err(|e| ErrorKind::InvalidRsaKey(e.to_string()))?; let public = key_pair.public_key(); @@ -26,7 +26,15 @@ fn extract_rsa_public_key_components(key_content: &[u8]) -> errors::Result<(Vec< Ok((components.n, components.e)) } -fn extract_ec_public_key_coordinates( +fn rsa_components_from_public_key(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { + let public = aws_lc_rs::rsa::PublicKey::from_der(key_content) + .map_err(|e| ErrorKind::InvalidRsaKey(e.to_string()))?; + + let components = aws_sig::RsaPublicKeyComponents::>::from(&public); + Ok((components.n, components.e)) +} + +fn ec_components_from_private_key( key_content: &[u8], alg: Algorithm, ) -> errors::Result<(EllipticCurve, Vec, Vec)> { @@ -102,9 +110,11 @@ fn new_verifier( pub static DEFAULT_PROVIDER: CryptoProvider = CryptoProvider { signer_factory: new_signer, verifier_factory: new_verifier, - jwk_utils: JwkUtils { - extract_rsa_public_key_components, - extract_ec_public_key_coordinates, + key_utils: KeyUtils { + rsa_pub_components_from_private_key: rsa_components_from_private_key, + rsa_pub_components_from_public_key: rsa_components_from_public_key, + ec_pub_components_from_private_key: ec_components_from_private_key, + ec_pub_components_from_public_key: crate::crypto::ec_components_from_public_key, compute_digest, }, }; diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 488f217d..2c88a41a 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -10,7 +10,7 @@ //! [`CryptoProvider`]: crate::crypto::CryptoProvider use crate::algorithms::Algorithm; -use crate::errors::Result; +use crate::errors::{self, ErrorKind, Result}; use crate::jwk::{EllipticCurve, ThumbprintHash}; use crate::{DecodingKey, EncodingKey}; @@ -86,7 +86,7 @@ pub struct CryptoProvider { /// A function that produces a [`JwtVerifier`] for a given [`Algorithm`] pub verifier_factory: fn(&Algorithm, &DecodingKey) -> Result>, /// Struct with utility functions for JWK processing. - pub jwk_utils: JwkUtils, + pub key_utils: KeyUtils, } impl CryptoProvider { @@ -123,7 +123,7 @@ See the documentation of the CryptoProvider type for more information. static INSTANCE: CryptoProvider = CryptoProvider { signer_factory: |_, _| panic!("{}", NOT_INSTALLED_ERROR), verifier_factory: |_, _| panic!("{}", NOT_INSTALLED_ERROR), - jwk_utils: JwkUtils::new_unimplemented(), + key_utils: KeyUtils::new_unimplemented(), }; &INSTANCE @@ -132,22 +132,29 @@ See the documentation of the CryptoProvider type for more information. } /// Holds utility functions required for JWK processing. -/// Use the [`JwkUtils::new_unimplemented`] function to initialize all values to dummies. +/// Use the [`KeyUtils::new_unimplemented`] function to initialize all values to dummies. #[derive(Clone, Debug)] -pub struct JwkUtils { +pub struct KeyUtils { /// Given a DER encoded private key, extract the RSA public key components (n, e) #[allow(clippy::type_complexity)] - pub extract_rsa_public_key_components: fn(&[u8]) -> Result<(Vec, Vec)>, + pub rsa_pub_components_from_private_key: fn(&[u8]) -> Result<(Vec, Vec)>, + /// Given a DER encoded public key, extract the RSA public key components (n, e) + #[allow(clippy::type_complexity)] + pub rsa_pub_components_from_public_key: fn(&[u8]) -> Result<(Vec, Vec)>, /// Given a DER encoded private key and an algorithm, extract the associated curve /// and the EC public key components (x, y) #[allow(clippy::type_complexity)] - pub extract_ec_public_key_coordinates: + pub ec_pub_components_from_private_key: fn(&[u8], Algorithm) -> Result<(EllipticCurve, Vec, Vec)>, + /// Given bitstring from DER encoded private key, extract the associated curve + /// and the EC public key components (x, y) + #[allow(clippy::type_complexity)] + pub ec_pub_components_from_public_key: fn(&[u8]) -> Result<(EllipticCurve, Vec, Vec)>, /// Given some data and a name of a hash function, compute hash_function(data) pub compute_digest: fn(&[u8], ThumbprintHash) -> Result>, } -impl JwkUtils { +impl KeyUtils { /// Initialises all values to dummies. /// Will lead to a panic when JWKs are required, so only use it if you don't want to support JWKs. pub const fn new_unimplemented() -> Self { @@ -157,10 +164,16 @@ Call CryptoProvider::install_default() before this point to select a provider ma See the documentation of the CryptoProvider type for more information. "###; Self { - extract_rsa_public_key_components: |_| { + rsa_pub_components_from_private_key: |_| { panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR) }, - extract_ec_public_key_coordinates: |_, _| { + rsa_pub_components_from_public_key: |_| { + panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR) + }, + ec_pub_components_from_private_key: |_, _| { + panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR) + }, + ec_pub_components_from_public_key: |_| { panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR) }, compute_digest: |_, _| panic!("{}", NOT_INSTALLED_OR_UNIMPLEMENTED_ERROR), @@ -168,6 +181,24 @@ See the documentation of the CryptoProvider type for more information. } } +#[allow(unused)] +fn ec_components_from_public_key( + pub_bytes: &[u8], +) -> errors::Result<(EllipticCurve, Vec, Vec)> { + let (curve, pub_elem_bytes) = match pub_bytes.len() { + 65 => (EllipticCurve::P256, 32), + 97 => (EllipticCurve::P384, 48), + _ => return Err(ErrorKind::InvalidEcdsaKey.into()), + }; + + if pub_bytes[0] != 4 { + return Err(ErrorKind::InvalidEcdsaKey.into()); + } + + let (x, y) = pub_bytes[1..].split_at(pub_elem_bytes); + Ok((curve, x.to_vec(), y.to_vec())) +} + mod static_default { use std::sync::OnceLock; diff --git a/src/crypto/rust_crypto/mod.rs b/src/crypto/rust_crypto/mod.rs index cdb31c60..6dc372c9 100644 --- a/src/crypto/rust_crypto/mod.rs +++ b/src/crypto/rust_crypto/mod.rs @@ -1,11 +1,15 @@ -use ::rsa::{RsaPrivateKey, pkcs1::DecodeRsaPrivateKey, traits::PublicKeyParts}; +use ::rsa::{ + RsaPrivateKey, RsaPublicKey, + pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey}, + traits::PublicKeyParts, +}; use p256::{ecdsa::SigningKey as P256SigningKey, pkcs8::DecodePrivateKey}; use p384::ecdsa::SigningKey as P384SigningKey; use sha2::{Digest, Sha256, Sha384, Sha512}; use crate::{ Algorithm, DecodingKey, EncodingKey, - crypto::{CryptoProvider, JwkUtils, JwtSigner, JwtVerifier}, + crypto::{CryptoProvider, JwtSigner, JwtVerifier, KeyUtils}, errors::{self, Error, ErrorKind}, jwk::{EllipticCurve, ThumbprintHash}, }; @@ -15,14 +19,20 @@ mod eddsa; mod hmac; mod rsa; -fn extract_rsa_public_key_components(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { +fn rsa_components_from_private_key(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { let private_key = RsaPrivateKey::from_pkcs1_der(key_content) .map_err(|e| ErrorKind::InvalidRsaKey(e.to_string()))?; let public_key = private_key.to_public_key(); Ok((public_key.n().to_bytes_be(), public_key.e().to_bytes_be())) } -fn extract_ec_public_key_coordinates( +fn rsa_components_from_public_key(key_content: &[u8]) -> errors::Result<(Vec, Vec)> { + let public_key = RsaPublicKey::from_pkcs1_der(key_content) + .map_err(|e| ErrorKind::InvalidRsaKey(e.to_string()))?; + Ok((public_key.n().to_bytes_be(), public_key.e().to_bytes_be())) +} + +fn ec_components_from_private_key( key_content: &[u8], alg: Algorithm, ) -> errors::Result<(EllipticCurve, Vec, Vec)> { @@ -108,9 +118,11 @@ fn new_verifier( pub static DEFAULT_PROVIDER: CryptoProvider = CryptoProvider { signer_factory: new_signer, verifier_factory: new_verifier, - jwk_utils: JwkUtils { - extract_rsa_public_key_components, - extract_ec_public_key_coordinates, + key_utils: KeyUtils { + rsa_pub_components_from_private_key: rsa_components_from_private_key, + rsa_pub_components_from_public_key: rsa_components_from_public_key, + ec_pub_components_from_private_key: ec_components_from_private_key, + ec_pub_components_from_public_key: crate::crypto::ec_components_from_public_key, compute_digest, }, }; diff --git a/src/jwk.rs b/src/jwk.rs index 615d76bc..3b0b15f6 100644 --- a/src/jwk.rs +++ b/src/jwk.rs @@ -11,7 +11,8 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer, de}; use crate::crypto::CryptoProvider; use crate::serialization::b64_encode; use crate::{ - Algorithm, EncodingKey, + Algorithm, DecodingKey, EncodingKey, + decoding::DecodingKeyKind, errors::{self, Error, ErrorKind}, }; @@ -222,6 +223,25 @@ impl FromStr for KeyAlgorithm { } } +impl From for KeyAlgorithm { + fn from(algorithm: Algorithm) -> Self { + match algorithm { + Algorithm::HS256 => KeyAlgorithm::HS256, + Algorithm::HS384 => KeyAlgorithm::HS384, + Algorithm::HS512 => KeyAlgorithm::HS512, + Algorithm::ES256 => KeyAlgorithm::ES256, + Algorithm::ES384 => KeyAlgorithm::ES384, + Algorithm::RS256 => KeyAlgorithm::RS256, + Algorithm::RS384 => KeyAlgorithm::RS384, + Algorithm::RS512 => KeyAlgorithm::RS512, + Algorithm::PS256 => KeyAlgorithm::PS256, + Algorithm::PS384 => KeyAlgorithm::PS384, + Algorithm::PS512 => KeyAlgorithm::PS512, + Algorithm::EdDSA => KeyAlgorithm::EdDSA, + } + } +} + impl fmt::Display for KeyAlgorithm { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{:?}", self) @@ -437,23 +457,7 @@ impl Jwk { } pub fn from_encoding_key(key: &EncodingKey, alg: Algorithm) -> crate::errors::Result { Ok(Self { - common: CommonParameters { - key_algorithm: Some(match alg { - Algorithm::HS256 => KeyAlgorithm::HS256, - Algorithm::HS384 => KeyAlgorithm::HS384, - Algorithm::HS512 => KeyAlgorithm::HS512, - Algorithm::ES256 => KeyAlgorithm::ES256, - Algorithm::ES384 => KeyAlgorithm::ES384, - Algorithm::RS256 => KeyAlgorithm::RS256, - Algorithm::RS384 => KeyAlgorithm::RS384, - Algorithm::RS512 => KeyAlgorithm::RS512, - Algorithm::PS256 => KeyAlgorithm::PS256, - Algorithm::PS384 => KeyAlgorithm::PS384, - Algorithm::PS512 => KeyAlgorithm::PS512, - Algorithm::EdDSA => KeyAlgorithm::EdDSA, - }), - ..Default::default() - }, + common: CommonParameters { key_algorithm: Some(alg.into()), ..Default::default() }, algorithm: match key.family() { crate::algorithms::AlgorithmFamily::Hmac => { AlgorithmParameters::OctetKey(OctetKeyParameters { @@ -463,8 +467,8 @@ impl Jwk { } crate::algorithms::AlgorithmFamily::Rsa => { let (n, e) = (CryptoProvider::get_default() - .jwk_utils - .extract_rsa_public_key_components)( + .key_utils + .rsa_pub_components_from_private_key)( key.inner() )?; AlgorithmParameters::RSA(RSAKeyParameters { @@ -475,8 +479,8 @@ impl Jwk { } crate::algorithms::AlgorithmFamily::Ec => { let (curve, x, y) = (CryptoProvider::get_default() - .jwk_utils - .extract_ec_public_key_coordinates)( + .key_utils + .ec_pub_components_from_private_key)( key.inner(), alg )?; AlgorithmParameters::EllipticCurve(EllipticCurveKeyParameters { @@ -487,7 +491,66 @@ impl Jwk { }) } crate::algorithms::AlgorithmFamily::Ed => { - unimplemented!(); + unimplemented!("Edwards curve is not supported"); + } + }, + }) + } + + pub fn from_decoding_key( + key: &DecodingKey, + alg: Option, + ) -> crate::errors::Result { + Ok(Self { + common: CommonParameters { key_algorithm: alg.map(|a| a.into()), ..Default::default() }, + algorithm: match key.family() { + crate::algorithms::AlgorithmFamily::Hmac => { + let secret = match &key.kind() { + DecodingKeyKind::SecretOrDer(secret) => secret, + _ => return Err(ErrorKind::InvalidKeyFormat.into()), + }; + + AlgorithmParameters::OctetKey(OctetKeyParameters { + key_type: OctetKeyType::Octet, + value: b64_encode(secret), + }) + } + crate::algorithms::AlgorithmFamily::Rsa => { + let (n, e) = match &key.kind() { + DecodingKeyKind::RsaModulusExponent { n, e } => { + (b64_encode(n), b64_encode(e)) + } + DecodingKeyKind::SecretOrDer(der) => { + let (n, e) = (CryptoProvider::get_default() + .key_utils + .rsa_pub_components_from_public_key)( + der + )?; + (b64_encode(n), b64_encode(e)) + } + }; + + AlgorithmParameters::RSA(RSAKeyParameters { key_type: RSAKeyType::RSA, n, e }) + } + crate::algorithms::AlgorithmFamily::Ec => { + let (curve, x, y) = match &key.kind() { + DecodingKeyKind::SecretOrDer(pub_bytes) => (CryptoProvider::get_default() + .key_utils + .ec_pub_components_from_public_key)( + pub_bytes + )?, + _ => return Err(ErrorKind::InvalidKeyFormat.into()), + }; + + AlgorithmParameters::EllipticCurve(EllipticCurveKeyParameters { + key_type: EllipticCurveKeyType::EC, + curve, + x: b64_encode(x), + y: b64_encode(y), + }) + } + crate::algorithms::AlgorithmFamily::Ed => { + unimplemented!("Edwards curve is not supported"); } }, }) @@ -540,7 +603,7 @@ impl Jwk { }, }; - Ok(b64_encode((CryptoProvider::get_default().jwk_utils.compute_digest)( + Ok(b64_encode((CryptoProvider::get_default().key_utils.compute_digest)( pre.as_bytes(), hash_function, )?)) @@ -573,6 +636,7 @@ mod tests { ThumbprintHash, }; use crate::serialization::b64_encode; + use crate::{DecodingKey, EncodingKey}; #[test] #[wasm_bindgen_test] @@ -629,4 +693,32 @@ mod tests { assert_eq!(tp.as_str(), "NzbLsXh8uDCcd-6MNwXF4W_7noWXFZAfHkxZsRGC9Xs"); } + + #[test] + #[cfg(feature = "use_pem")] + fn check_jwk_from_decoding_key_rsa() { + let enc_key = + EncodingKey::from_rsa_pem(include_bytes!("../tests/rsa/private_rsa_key_pkcs8.pem")) + .unwrap(); + let dec_key = + DecodingKey::from_rsa_pem(include_bytes!("../tests/rsa/public_rsa_key_pkcs8.pem")) + .unwrap(); + let expected_jwk = Jwk::from_encoding_key(&enc_key, Algorithm::RS256).unwrap(); + let jwk = Jwk::from_decoding_key(&dec_key, Some(Algorithm::RS256)).unwrap(); + assert_eq!(jwk, expected_jwk); + } + + #[test] + #[cfg(feature = "use_pem")] + fn check_jwk_from_decoding_key_ec() { + let enc_key = + EncodingKey::from_ec_pem(include_bytes!("../tests/ecdsa/private_ecdsa_key.pem")) + .unwrap(); + let dec_key = + DecodingKey::from_ec_pem(include_bytes!("../tests/ecdsa/public_ecdsa_key.pem")) + .unwrap(); + let expected_jwk = Jwk::from_encoding_key(&enc_key, Algorithm::ES256).unwrap(); + let jwk = Jwk::from_decoding_key(&dec_key, Some(Algorithm::ES256)).unwrap(); + assert_eq!(jwk, expected_jwk); + } }