diff --git a/src/rust/src/backend/aead.rs b/src/rust/src/backend/aead.rs index 2a6641afa..94a9e949a 100644 --- a/src/rust/src/backend/aead.rs +++ b/src/rust/src/backend/aead.rs @@ -19,13 +19,36 @@ fn check_length(data: &[u8]) -> CryptographyResult<()> { Ok(()) } +enum Aad<'a> { + List(&'a pyo3::types::PyList), +} + +fn process_aad( + ctx: &mut openssl::cipher_ctx::CipherCtx, + aad: Option>, +) -> CryptographyResult<()> { + if let Some(Aad::List(ads)) = aad { + for ad in ads.iter() { + let ad = ad.extract::>()?; + check_length(ad.as_bytes())?; + ctx.cipher_update(ad.as_bytes(), None)?; + } + } + + Ok(()) +} + fn encrypt_value<'p>( py: pyo3::Python<'p>, mut ctx: openssl::cipher_ctx::CipherCtx, plaintext: &[u8], + aad: Option>, tag_len: usize, tag_first: bool, ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + check_length(plaintext)?; + process_aad(&mut ctx, aad)?; + Ok(pyo3::types::PyBytes::new_with( py, plaintext.len() + tag_len, @@ -58,7 +81,10 @@ fn decrypt_value<'p>( py: pyo3::Python<'p>, mut ctx: openssl::cipher_ctx::CipherCtx, ciphertext: &[u8], + aad: Option>, ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + process_aad(&mut ctx, aad)?; + Ok(pyo3::types::PyBytes::new_with(py, ciphertext.len(), |b| { // AES SIV can error here if the data is invalid on decrypt let n = ctx @@ -150,26 +176,17 @@ impl AesSiv { ) -> CryptographyResult<&'p pyo3::types::PyBytes> { let key_buf = self.key.extract::>(py)?; let data_bytes = data.as_bytes(); + let aad = associated_data.map(Aad::List); if data_bytes.is_empty() { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err("data must not be zero length"), )); }; - check_length(data_bytes)?; - let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; ctx.encrypt_init(Some(&self.cipher), Some(key_buf.as_bytes()), None)?; - if let Some(ads) = associated_data { - for ad in ads.iter() { - let ad = ad.extract::>()?; - check_length(ad.as_bytes())?; - ctx.cipher_update(ad.as_bytes(), None)?; - } - } - - encrypt_value(py, ctx, data_bytes, 16, true) + encrypt_value(py, ctx, data_bytes, aad, 16, true) } fn decrypt<'p>( @@ -180,12 +197,7 @@ impl AesSiv { ) -> CryptographyResult<&'p pyo3::types::PyBytes> { let key_buf = self.key.extract::>(py)?; let data_bytes = data.as_bytes(); - - if data_bytes.is_empty() { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err("data must not be zero length"), - )); - } + let aad = associated_data.map(Aad::List); let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; ctx.decrypt_init(Some(&self.cipher), Some(key_buf.as_bytes()), None)?; @@ -199,16 +211,7 @@ impl AesSiv { let (tag, ciphertext) = data_bytes.split_at(16); ctx.set_tag(tag)?; - if let Some(ads) = associated_data { - for ad in ads.iter() { - let ad = ad.extract::>()?; - check_length(ad.as_bytes())?; - - ctx.cipher_update(ad.as_bytes(), None)?; - } - } - - decrypt_value(py, ctx, ciphertext) + decrypt_value(py, ctx, ciphertext, aad) } } diff --git a/tests/hazmat/primitives/test_aead.py b/tests/hazmat/primitives/test_aead.py index 7db9607af..ce90f6892 100644 --- a/tests/hazmat/primitives/test_aead.py +++ b/tests/hazmat/primitives/test_aead.py @@ -681,7 +681,7 @@ class TestAESSIV: with pytest.raises(ValueError): aessiv.encrypt(b"", None) - with pytest.raises(ValueError): + with pytest.raises(InvalidTag): aessiv.decrypt(b"", None) def test_vectors(self, backend, subtests):