Move more of the Rust AEAD logic into common functions (#9499)

This commit is contained in:
Alex Gaynor 2023-08-26 11:15:43 -04:00 committed by GitHub
parent faf318360e
commit 1031dfecff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 28 deletions

View file

@ -19,13 +19,36 @@ fn check_length(data: &[u8]) -> CryptographyResult<()> {
Ok(())
}
enum Aad<'a> {
List(&'a pyo3::types::PyList),
}
fn process_aad(
ctx: &mut openssl::cipher_ctx::CipherCtx,
aad: Option<Aad<'_>>,
) -> CryptographyResult<()> {
if let Some(Aad::List(ads)) = aad {
for ad in ads.iter() {
let ad = ad.extract::<CffiBuf<'_>>()?;
check_length(ad.as_bytes())?;
ctx.cipher_update(ad.as_bytes(), None)?;
}
}
Ok(())
}
fn encrypt_value<'p>(
py: pyo3::Python<'p>,
mut ctx: openssl::cipher_ctx::CipherCtx,
plaintext: &[u8],
aad: Option<Aad<'_>>,
tag_len: usize,
tag_first: bool,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
check_length(plaintext)?;
process_aad(&mut ctx, aad)?;
Ok(pyo3::types::PyBytes::new_with(
py,
plaintext.len() + tag_len,
@ -58,7 +81,10 @@ fn decrypt_value<'p>(
py: pyo3::Python<'p>,
mut ctx: openssl::cipher_ctx::CipherCtx,
ciphertext: &[u8],
aad: Option<Aad<'_>>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
process_aad(&mut ctx, aad)?;
Ok(pyo3::types::PyBytes::new_with(py, ciphertext.len(), |b| {
// AES SIV can error here if the data is invalid on decrypt
let n = ctx
@ -150,26 +176,17 @@ impl AesSiv {
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
let key_buf = self.key.extract::<CffiBuf<'_>>(py)?;
let data_bytes = data.as_bytes();
let aad = associated_data.map(Aad::List);
if data_bytes.is_empty() {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("data must not be zero length"),
));
};
check_length(data_bytes)?;
let mut ctx = openssl::cipher_ctx::CipherCtx::new()?;
ctx.encrypt_init(Some(&self.cipher), Some(key_buf.as_bytes()), None)?;
if let Some(ads) = associated_data {
for ad in ads.iter() {
let ad = ad.extract::<CffiBuf<'_>>()?;
check_length(ad.as_bytes())?;
ctx.cipher_update(ad.as_bytes(), None)?;
}
}
encrypt_value(py, ctx, data_bytes, 16, true)
encrypt_value(py, ctx, data_bytes, aad, 16, true)
}
fn decrypt<'p>(
@ -180,12 +197,7 @@ impl AesSiv {
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
let key_buf = self.key.extract::<CffiBuf<'_>>(py)?;
let data_bytes = data.as_bytes();
if data_bytes.is_empty() {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("data must not be zero length"),
));
}
let aad = associated_data.map(Aad::List);
let mut ctx = openssl::cipher_ctx::CipherCtx::new()?;
ctx.decrypt_init(Some(&self.cipher), Some(key_buf.as_bytes()), None)?;
@ -199,16 +211,7 @@ impl AesSiv {
let (tag, ciphertext) = data_bytes.split_at(16);
ctx.set_tag(tag)?;
if let Some(ads) = associated_data {
for ad in ads.iter() {
let ad = ad.extract::<CffiBuf<'_>>()?;
check_length(ad.as_bytes())?;
ctx.cipher_update(ad.as_bytes(), None)?;
}
}
decrypt_value(py, ctx, ciphertext)
decrypt_value(py, ctx, ciphertext, aad)
}
}

View file

@ -681,7 +681,7 @@ class TestAESSIV:
with pytest.raises(ValueError):
aessiv.encrypt(b"", None)
with pytest.raises(ValueError):
with pytest.raises(InvalidTag):
aessiv.decrypt(b"", None)
def test_vectors(self, backend, subtests):