From d49947efb0495a3d5f25314fd8aeb3c1ca356480 Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Wed, 4 Dec 2024 21:24:04 -0800 Subject: [PATCH] handle case where a "valid" pkey does not contain a valid EC key (#12101) * handle case where a "valid" pkey does not contain a valid EC key * add test * skip the test in some scenarios --- src/rust/src/backend/ec.rs | 7 +++++-- tests/hazmat/primitives/test_ec.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/rust/src/backend/ec.rs b/src/rust/src/backend/ec.rs index 37bfc9123..98ee81155 100644 --- a/src/rust/src/backend/ec.rs +++ b/src/rust/src/backend/ec.rs @@ -135,8 +135,11 @@ pub(crate) fn private_key_from_pkey( py: pyo3::Python<'_>, pkey: &openssl::pkey::PKeyRef, ) -> CryptographyResult { - let curve = py_curve_from_curve(py, pkey.ec_key().unwrap().group())?; - check_key_infinity(&pkey.ec_key().unwrap())?; + let ec_key = pkey + .ec_key() + .map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid EC key"))?; + let curve = py_curve_from_curve(py, ec_key.group())?; + check_key_infinity(&ec_key)?; Ok(ECPrivateKey { pkey: pkey.to_owned(), curve: curve.into(), diff --git a/tests/hazmat/primitives/test_ec.py b/tests/hazmat/primitives/test_ec.py index 2a30c6661..9cf7b8290 100644 --- a/tests/hazmat/primitives/test_ec.py +++ b/tests/hazmat/primitives/test_ec.py @@ -466,6 +466,25 @@ class TestECDSAVectors: backend=backend, ) + @pytest.mark.supported( + only_if=( + lambda backend: rust_openssl.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER + or rust_openssl.CRYPTOGRAPHY_IS_BORINGSSL + ), + skip_message="LibreSSL and OpenSSL 1.1.1 handle this differently", + ) + def test_load_invalid_private_scalar_pem(self, backend): + _skip_curve_unsupported(backend, ec.SECP256R1()) + + data = load_vectors_from_file( + os.path.join( + "asymmetric", "PKCS8", "ec-invalid-private-scalar.pem" + ), + lambda pemfile: pemfile.read().encode(), + ) + with pytest.raises(ValueError): + serialization.load_pem_private_key(data, None) + def test_signatures(self, backend, subtests): vectors = itertools.chain( load_vectors_from_file(