Convert src/backend/ciphers.rs to new pyo3 APIs (#10703)

This commit is contained in:
Alex Gaynor 2024-04-04 09:10:49 -04:00 committed by GitHub
parent 6813602069
commit 632389f2fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 45 additions and 36 deletions

View file

@ -7,7 +7,7 @@ use crate::buf::{CffiBuf, CffiMutBuf};
use crate::error::{CryptographyError, CryptographyResult};
use crate::exceptions;
use crate::types;
use pyo3::prelude::PyAnyMethods;
use pyo3::prelude::{PyAnyMethods, PyModuleMethods};
use pyo3::IntoPy;
struct CipherContext {
@ -121,10 +121,10 @@ impl CipherContext {
&mut self,
py: pyo3::Python<'p>,
buf: &[u8],
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let mut out_buf = vec![0; buf.len() + self.ctx.block_size()];
let n = self.update_into(py, buf, &mut out_buf)?;
Ok(pyo3::types::PyBytes::new(py, &out_buf[..n]))
Ok(pyo3::types::PyBytes::new_bound(py, &out_buf[..n]))
}
fn update_into(
@ -146,7 +146,11 @@ impl CipherContext {
for chunk in buf.chunks(1 << 29) {
// SAFETY: We ensure that outbuf is sufficiently large above.
unsafe {
let n = if self.py_mode.as_ref(py).is_instance(types::XTS.get(py)?)? {
let n = if self
.py_mode
.bind(py)
.is_instance(&types::XTS.get_bound(py)?)?
{
self.ctx.cipher_update_unchecked(chunk, Some(&mut out_buf[total_written..])).map_err(|_| {
pyo3::exceptions::PyValueError::new_err(
"In XTS mode you must supply at least a full block in the first update call. For AES this is 16 bytes."
@ -171,14 +175,14 @@ impl CipherContext {
fn finalize<'p>(
&mut self,
py: pyo3::Python<'p>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let mut out_buf = vec![0; self.ctx.block_size()];
let n = self.ctx.cipher_final(&mut out_buf).or_else(|e| {
if e.errors().is_empty()
&& self
.py_mode
.as_ref(py)
.is_instance(types::MODE_WITH_AUTHENTICATION_TAG.get(py)?)?
.bind(py)
.is_instance(&types::MODE_WITH_AUTHENTICATION_TAG.get_bound(py)?)?
{
return Err(CryptographyError::from(exceptions::InvalidTag::new_err(())));
}
@ -188,7 +192,7 @@ impl CipherContext {
),
))
})?;
Ok(pyo3::types::PyBytes::new(py, &out_buf[..n]))
Ok(pyo3::types::PyBytes::new_bound(py, &out_buf[..n]))
}
}
@ -233,7 +237,7 @@ impl PyCipherContext {
&mut self,
py: pyo3::Python<'p>,
buf: CffiBuf<'_>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
get_mut_ctx(self.ctx.as_mut())?.update(py, buf.as_bytes())
}
@ -249,7 +253,7 @@ impl PyCipherContext {
fn finalize<'p>(
&mut self,
py: pyo3::Python<'p>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let result = get_mut_ctx(self.ctx.as_mut())?.finalize(py)?;
self.ctx = None;
Ok(result)
@ -262,7 +266,7 @@ impl PyAEADEncryptionContext {
&mut self,
py: pyo3::Python<'p>,
buf: CffiBuf<'_>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let data = buf.as_bytes();
self.updated = true;
@ -314,16 +318,16 @@ impl PyAEADEncryptionContext {
fn finalize<'p>(
&mut self,
py: pyo3::Python<'p>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let ctx = get_mut_ctx(self.ctx.as_mut())?;
let result = ctx.finalize(py)?;
// XXX: do not hard code 16
let tag = pyo3::types::PyBytes::new_with(py, 16, |t| {
let tag = pyo3::types::PyBytes::new_bound_with(py, 16, |t| {
ctx.ctx.tag(t).map_err(CryptographyError::from)?;
Ok(())
})?;
self.tag = Some(tag.into_py(py));
self.tag = Some(tag.unbind());
self.ctx = None;
Ok(result)
@ -349,7 +353,7 @@ impl PyAEADDecryptionContext {
&mut self,
py: pyo3::Python<'p>,
buf: CffiBuf<'_>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let data = buf.as_bytes();
self.updated = true;
@ -401,12 +405,12 @@ impl PyAEADDecryptionContext {
fn finalize<'p>(
&mut self,
py: pyo3::Python<'p>,
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let ctx = get_mut_ctx(self.ctx.as_mut())?;
if ctx
.py_mode
.as_ref(py)
.bind(py)
.getattr(pyo3::intern!(py, "tag"))?
.is_none()
{
@ -426,12 +430,12 @@ impl PyAEADDecryptionContext {
&mut self,
py: pyo3::Python<'p>,
tag: &[u8],
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let ctx = get_mut_ctx(self.ctx.as_mut())?;
if !ctx
.py_mode
.as_ref(py)
.bind(py)
.getattr(pyo3::intern!(py, "tag"))?
.is_none()
{
@ -444,7 +448,7 @@ impl PyAEADDecryptionContext {
let min_tag_length = ctx
.py_mode
.as_ref(py)
.bind(py)
.getattr(pyo3::intern!(py, "_min_tag_length"))?
.extract()?;
// XXX: Do not hard code 16
@ -506,8 +510,11 @@ fn create_decryption_ctx(
let mut ctx = CipherContext::new(py, algorithm, mode.clone(), openssl::symm::Mode::Decrypt)?;
if mode.is_instance(&types::MODE_WITH_AUTHENTICATION_TAG.get_bound(py)?)? {
if let Some(tag) = mode.getattr(pyo3::intern!(py, "tag"))?.extract()? {
ctx.ctx.set_tag(tag)?;
if let Some(tag) = mode
.getattr(pyo3::intern!(py, "tag"))?
.extract::<Option<pyo3::pybacked::PyBackedBytes>>()?
{
ctx.ctx.set_tag(&tag)?;
}
Ok(PyAEADDecryptionContext {
@ -536,31 +543,33 @@ fn cipher_supported(
}
#[pyo3::prelude::pyfunction]
fn _advance(ctx: &pyo3::PyAny, n: u64) {
if let Ok(c) = ctx.downcast::<pyo3::PyCell<PyAEADEncryptionContext>>() {
fn _advance(ctx: pyo3::Bound<'_, pyo3::PyAny>, n: u64) {
if let Ok(c) = ctx.downcast::<PyAEADEncryptionContext>() {
c.borrow_mut().bytes_remaining -= n;
} else if let Ok(c) = ctx.downcast::<pyo3::PyCell<PyAEADDecryptionContext>>() {
} else if let Ok(c) = ctx.downcast::<PyAEADDecryptionContext>() {
c.borrow_mut().bytes_remaining -= n;
}
}
#[pyo3::prelude::pyfunction]
fn _advance_aad(ctx: &pyo3::PyAny, n: u64) {
if let Ok(c) = ctx.downcast::<pyo3::PyCell<PyAEADEncryptionContext>>() {
fn _advance_aad(ctx: pyo3::Bound<'_, pyo3::PyAny>, n: u64) {
if let Ok(c) = ctx.downcast::<PyAEADEncryptionContext>() {
c.borrow_mut().aad_bytes_remaining -= n;
} else if let Ok(c) = ctx.downcast::<pyo3::PyCell<PyAEADDecryptionContext>>() {
} else if let Ok(c) = ctx.downcast::<PyAEADDecryptionContext>() {
c.borrow_mut().aad_bytes_remaining -= n;
}
}
pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {
let m = pyo3::prelude::PyModule::new(py, "ciphers")?;
m.add_function(pyo3::wrap_pyfunction!(create_encryption_ctx, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(create_decryption_ctx, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(cipher_supported, m)?)?;
pub(crate) fn create_module(
py: pyo3::Python<'_>,
) -> pyo3::PyResult<pyo3::Bound<'_, pyo3::prelude::PyModule>> {
let m = pyo3::prelude::PyModule::new_bound(py, "ciphers")?;
m.add_function(pyo3::wrap_pyfunction!(create_encryption_ctx, &m)?)?;
m.add_function(pyo3::wrap_pyfunction!(create_decryption_ctx, &m)?)?;
m.add_function(pyo3::wrap_pyfunction!(cipher_supported, &m)?)?;
m.add_function(pyo3::wrap_pyfunction!(_advance, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(_advance_aad, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(_advance, &m)?)?;
m.add_function(pyo3::wrap_pyfunction!(_advance_aad, &m)?)?;
m.add_class::<PyCipherContext>()?;
m.add_class::<PyAEADEncryptionContext>()?;

View file

@ -25,7 +25,7 @@ pub(crate) mod x448;
pub(crate) fn add_to_module(module: &pyo3::prelude::PyModule) -> pyo3::PyResult<()> {
module.add_submodule(aead::create_module(module.py())?.into_gil_ref())?;
module.add_submodule(ciphers::create_module(module.py())?)?;
module.add_submodule(ciphers::create_module(module.py())?.into_gil_ref())?;
module.add_submodule(cmac::create_module(module.py())?)?;
module.add_submodule(dh::create_module(module.py())?)?;
module.add_submodule(dsa::create_module(module.py())?)?;