From b93e165615217f0359992b333fa33fcf6f5cecf4 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Fri, 5 Apr 2024 23:41:57 -0400 Subject: [PATCH] Convert some types usage to bound (#10750) --- src/rust/src/backend/aead.rs | 39 +++++++++++++++++++++++++----------- src/rust/src/backend/dh.rs | 6 +++--- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/rust/src/backend/aead.rs b/src/rust/src/backend/aead.rs index 55ac8b842..16ea74f20 100644 --- a/src/rust/src/backend/aead.rs +++ b/src/rust/src/backend/aead.rs @@ -532,8 +532,8 @@ impl ChaCha20Poly1305 { } #[staticmethod] - fn generate_key(py: pyo3::Python<'_>) -> CryptographyResult<&pyo3::PyAny> { - Ok(types::OS_URANDOM.get(py)?.call1((32,))?) + fn generate_key(py: pyo3::Python<'_>) -> CryptographyResult> { + Ok(types::OS_URANDOM.get_bound(py)?.call1((32,))?) } fn encrypt<'p>( @@ -638,14 +638,17 @@ impl AesGcm { } #[staticmethod] - fn generate_key(py: pyo3::Python<'_>, bit_length: usize) -> CryptographyResult<&pyo3::PyAny> { + fn generate_key( + py: pyo3::Python<'_>, + bit_length: usize, + ) -> CryptographyResult> { if bit_length != 128 && bit_length != 192 && bit_length != 256 { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err("bit_length must be 128, 192, or 256"), )); } - Ok(types::OS_URANDOM.get(py)?.call1((bit_length / 8,))?) + Ok(types::OS_URANDOM.get_bound(py)?.call1((bit_length / 8,))?) } fn encrypt<'p>( @@ -746,14 +749,17 @@ impl AesCcm { } #[staticmethod] - fn generate_key(py: pyo3::Python<'_>, bit_length: usize) -> CryptographyResult<&pyo3::PyAny> { + fn generate_key( + py: pyo3::Python<'_>, + bit_length: usize, + ) -> CryptographyResult> { if bit_length != 128 && bit_length != 192 && bit_length != 256 { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err("bit_length must be 128, 192, or 256"), )); } - Ok(types::OS_URANDOM.get(py)?.call1((bit_length / 8,))?) + Ok(types::OS_URANDOM.get_bound(py)?.call1((bit_length / 8,))?) } fn encrypt<'p>( @@ -876,14 +882,17 @@ impl AesSiv { } #[staticmethod] - fn generate_key(py: pyo3::Python<'_>, bit_length: usize) -> CryptographyResult<&pyo3::PyAny> { + fn generate_key( + py: pyo3::Python<'_>, + bit_length: usize, + ) -> CryptographyResult> { if bit_length != 256 && bit_length != 384 && bit_length != 512 { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err("bit_length must be 256, 384, or 512"), )); } - Ok(types::OS_URANDOM.get(py)?.call1((bit_length / 8,))?) + Ok(types::OS_URANDOM.get_bound(py)?.call1((bit_length / 8,))?) } #[pyo3(signature = (data, associated_data))] @@ -970,14 +979,17 @@ impl AesOcb3 { } #[staticmethod] - fn generate_key(py: pyo3::Python<'_>, bit_length: usize) -> CryptographyResult<&pyo3::PyAny> { + fn generate_key( + py: pyo3::Python<'_>, + bit_length: usize, + ) -> CryptographyResult> { if bit_length != 128 && bit_length != 192 && bit_length != 256 { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err("bit_length must be 128, 192, or 256"), )); } - Ok(types::OS_URANDOM.get(py)?.call1((bit_length / 8,))?) + Ok(types::OS_URANDOM.get_bound(py)?.call1((bit_length / 8,))?) } #[pyo3(signature = (nonce, data, associated_data))] @@ -1076,14 +1088,17 @@ impl AesGcmSiv { } #[staticmethod] - fn generate_key(py: pyo3::Python<'_>, bit_length: usize) -> CryptographyResult<&pyo3::PyAny> { + fn generate_key( + py: pyo3::Python<'_>, + bit_length: usize, + ) -> CryptographyResult> { if bit_length != 128 && bit_length != 192 && bit_length != 256 { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err("bit_length must be 128, 192, or 256"), )); } - Ok(types::OS_URANDOM.get(py)?.call1((bit_length / 8,))?) + Ok(types::OS_URANDOM.get_bound(py)?.call1((bit_length / 8,))?) } #[pyo3(signature = (nonce, data, associated_data))] diff --git a/src/rust/src/backend/dh.rs b/src/rust/src/backend/dh.rs index 9d597b9ec..70a57d50b 100644 --- a/src/rust/src/backend/dh.rs +++ b/src/rust/src/backend/dh.rs @@ -229,7 +229,7 @@ impl DHPrivateKey { format: &pyo3::Bound<'p, pyo3::PyAny>, encryption_algorithm: &pyo3::Bound<'p, pyo3::PyAny>, ) -> CryptographyResult> { - if !format.is(types::PRIVATE_FORMAT_PKCS8.get(py)?) { + if !format.is(&types::PRIVATE_FORMAT_PKCS8.get_bound(py)?) { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err( "DH private keys support only PKCS8 serialization", @@ -263,7 +263,7 @@ impl DHPublicKey { encoding: &pyo3::Bound<'p, pyo3::PyAny>, format: &pyo3::Bound<'p, pyo3::PyAny>, ) -> CryptographyResult> { - if !format.is(types::PUBLIC_FORMAT_SUBJECT_PUBLIC_KEY_INFO.get(py)?) { + if !format.is(&types::PUBLIC_FORMAT_SUBJECT_PUBLIC_KEY_INFO.get_bound(py)?) { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err( "DH public keys support only SubjectPublicKeyInfo serialization", @@ -345,7 +345,7 @@ impl DHParameters { encoding: pyo3::Bound<'p, pyo3::PyAny>, format: pyo3::Bound<'p, pyo3::PyAny>, ) -> CryptographyResult> { - if !format.is(types::PARAMETER_FORMAT_PKCS3.get(py)?) { + if !format.is(&types::PARAMETER_FORMAT_PKCS3.get_bound(py)?) { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err("Only PKCS3 serialization is supported"), ));