diff --git a/src/rust/cryptography-x509/src/certificate.rs b/src/rust/cryptography-x509/src/certificate.rs index 59960242b..2a5616e93 100644 --- a/src/rust/cryptography-x509/src/certificate.rs +++ b/src/rust/cryptography-x509/src/certificate.rs @@ -36,7 +36,7 @@ pub struct TbsCertificate<'a> { } impl<'a> TbsCertificate<'a> { - pub fn extensions(&'a self) -> Result>, asn1::ObjectIdentifier> { + pub fn extensions(&'a self) -> Result, asn1::ObjectIdentifier> { Extensions::from_raw_extensions(self.raw_extensions.as_ref()) } } diff --git a/src/rust/cryptography-x509/src/extensions.rs b/src/rust/cryptography-x509/src/extensions.rs index b1138fec2..51c283af3 100644 --- a/src/rust/cryptography-x509/src/extensions.rs +++ b/src/rust/cryptography-x509/src/extensions.rs @@ -18,7 +18,7 @@ pub type RawExtensions<'a> = common::Asn1ReadableOrWritable< /// /// In particular, an `Extensions` cannot be constructed from a `RawExtensions` /// that contains duplicated extensions (by OID). -pub struct Extensions<'a>(RawExtensions<'a>); +pub struct Extensions<'a>(Option>); impl<'a> Extensions<'a> { /// Create an `Extensions` from the given `RawExtensions`. @@ -27,7 +27,7 @@ impl<'a> Extensions<'a> { /// OID, if there are any duplicates. pub fn from_raw_extensions( raw: Option<&RawExtensions<'a>>, - ) -> Result, asn1::ObjectIdentifier> { + ) -> Result { match raw { Some(raw_exts) => { let mut seen_oids = HashSet::new(); @@ -38,22 +38,22 @@ impl<'a> Extensions<'a> { } } - Ok(Some(Self(raw_exts.clone()))) + Ok(Self(Some(raw_exts.clone()))) } - None => Ok(None), + None => Ok(Self(None)), } } /// Retrieves the extension identified by the given OID, /// or None if the extension is not present (or no extensions are present). pub fn get_extension(&self, oid: &asn1::ObjectIdentifier) -> Option { - let mut extensions = self.0.unwrap_read().clone(); - - extensions.find(|ext| &ext.extn_id == oid) + self.0 + .as_ref() + .and_then(|exts| exts.unwrap_read().clone().find(|ext| &ext.extn_id == oid)) } /// Returns a reference to the underlying extensions. - pub fn as_raw(&self) -> &RawExtensions<'_> { + pub fn as_raw(&self) -> &Option> { &self.0 } } @@ -245,9 +245,7 @@ mod tests { let der = asn1::write_single(&extensions).unwrap(); let extensions: Extensions = - Extensions::from_raw_extensions(Some(&asn1::parse_single(&der).unwrap())) - .unwrap() - .unwrap(); + Extensions::from_raw_extensions(Some(&asn1::parse_single(&der).unwrap())).unwrap(); assert!(&extensions.get_extension(&BASIC_CONSTRAINTS_OID).is_some()); assert!(&extensions diff --git a/src/rust/src/x509/certificate.rs b/src/rust/src/x509/certificate.rs index 3784b1c9a..dbe761fb9 100644 --- a/src/rust/src/x509/certificate.rs +++ b/src/rust/src/x509/certificate.rs @@ -194,8 +194,17 @@ impl Certificate { let mut tbs_precert = val.tbs_cert.clone(); // Remove the SCT list extension match val.tbs_cert.extensions() { - Ok(Some(extensions)) => { - let readable_extensions = extensions.as_raw().unwrap_read().clone(); + Ok(extensions) => { + let readable_extensions = match extensions.as_raw() { + Some(raw_exts) => raw_exts.unwrap_read().clone(), + None => { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "Could not find any extensions in TBS certificate", + ), + )) + } + }; let ext_count = readable_extensions.len(); let filtered_extensions: Vec> = readable_extensions .filter(|x| x.extn_id != oid::PRECERT_SIGNED_CERTIFICATE_TIMESTAMPS_OID) @@ -210,15 +219,11 @@ impl Certificate { let filtered_extensions: RawExtensions<'_> = Asn1ReadableOrWritable::new_write( asn1::SequenceOfWriter::new(filtered_extensions), ); + tbs_precert.raw_extensions = Some(filtered_extensions); let result = asn1::write_single(&tbs_precert)?; Ok(pyo3::types::PyBytes::new(py, &result)) } - Ok(None) => Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err( - "Could not find any extensions in TBS certificate", - ), - )), Err(oid) => { let oid_obj = oid_to_py_oid(py, &oid)?; Err(exceptions::DuplicateExtension::new_err(( diff --git a/src/rust/src/x509/common.rs b/src/rust/src/x509/common.rs index 94ae58d38..3c42f0c5d 100644 --- a/src/rust/src/x509/common.rs +++ b/src/rust/src/x509/common.rs @@ -410,8 +410,8 @@ pub(crate) fn parse_and_cache_extensions< let x509_module = py.import(pyo3::intern!(py, "cryptography.x509"))?; let exts = pyo3::types::PyList::empty(py); - if let Some(extensions) = extensions { - for raw_ext in extensions.as_raw().unwrap_read().clone() { + if let Some(extensions) = extensions.as_raw() { + for raw_ext in extensions.unwrap_read().clone() { let oid_obj = oid_to_py_oid(py, &raw_ext.extn_id)?; let extn_value = match parse_ext(&raw_ext.extn_id, raw_ext.extn_value)? {