Make Extensions contain an optional RawExtensions (#8900)

This matter models how x.509 represents these things, and will make it easier to make Extensions an iterator in the future
This commit is contained in:
Alex Gaynor 2023-05-10 15:20:23 -04:00 committed by GitHub
parent 1ff6208ec7
commit a8aaf19c3e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 21 deletions

View file

@ -36,7 +36,7 @@ pub struct TbsCertificate<'a> {
}
impl<'a> TbsCertificate<'a> {
pub fn extensions(&'a self) -> Result<Option<Extensions<'a>>, asn1::ObjectIdentifier> {
pub fn extensions(&'a self) -> Result<Extensions<'a>, asn1::ObjectIdentifier> {
Extensions::from_raw_extensions(self.raw_extensions.as_ref())
}
}

View file

@ -18,7 +18,7 @@ pub type RawExtensions<'a> = common::Asn1ReadableOrWritable<
///
/// In particular, an `Extensions` cannot be constructed from a `RawExtensions`
/// that contains duplicated extensions (by OID).
pub struct Extensions<'a>(RawExtensions<'a>);
pub struct Extensions<'a>(Option<RawExtensions<'a>>);
impl<'a> Extensions<'a> {
/// Create an `Extensions` from the given `RawExtensions`.
@ -27,7 +27,7 @@ impl<'a> Extensions<'a> {
/// OID, if there are any duplicates.
pub fn from_raw_extensions(
raw: Option<&RawExtensions<'a>>,
) -> Result<Option<Self>, asn1::ObjectIdentifier> {
) -> Result<Self, asn1::ObjectIdentifier> {
match raw {
Some(raw_exts) => {
let mut seen_oids = HashSet::new();
@ -38,22 +38,22 @@ impl<'a> Extensions<'a> {
}
}
Ok(Some(Self(raw_exts.clone())))
Ok(Self(Some(raw_exts.clone())))
}
None => Ok(None),
None => Ok(Self(None)),
}
}
/// Retrieves the extension identified by the given OID,
/// or None if the extension is not present (or no extensions are present).
pub fn get_extension(&self, oid: &asn1::ObjectIdentifier) -> Option<Extension> {
let mut extensions = self.0.unwrap_read().clone();
extensions.find(|ext| &ext.extn_id == oid)
self.0
.as_ref()
.and_then(|exts| exts.unwrap_read().clone().find(|ext| &ext.extn_id == oid))
}
/// Returns a reference to the underlying extensions.
pub fn as_raw(&self) -> &RawExtensions<'_> {
pub fn as_raw(&self) -> &Option<RawExtensions<'_>> {
&self.0
}
}
@ -245,9 +245,7 @@ mod tests {
let der = asn1::write_single(&extensions).unwrap();
let extensions: Extensions =
Extensions::from_raw_extensions(Some(&asn1::parse_single(&der).unwrap()))
.unwrap()
.unwrap();
Extensions::from_raw_extensions(Some(&asn1::parse_single(&der).unwrap())).unwrap();
assert!(&extensions.get_extension(&BASIC_CONSTRAINTS_OID).is_some());
assert!(&extensions

View file

@ -194,8 +194,17 @@ impl Certificate {
let mut tbs_precert = val.tbs_cert.clone();
// Remove the SCT list extension
match val.tbs_cert.extensions() {
Ok(Some(extensions)) => {
let readable_extensions = extensions.as_raw().unwrap_read().clone();
Ok(extensions) => {
let readable_extensions = match extensions.as_raw() {
Some(raw_exts) => raw_exts.unwrap_read().clone(),
None => {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err(
"Could not find any extensions in TBS certificate",
),
))
}
};
let ext_count = readable_extensions.len();
let filtered_extensions: Vec<Extension<'_>> = readable_extensions
.filter(|x| x.extn_id != oid::PRECERT_SIGNED_CERTIFICATE_TIMESTAMPS_OID)
@ -210,15 +219,11 @@ impl Certificate {
let filtered_extensions: RawExtensions<'_> = Asn1ReadableOrWritable::new_write(
asn1::SequenceOfWriter::new(filtered_extensions),
);
tbs_precert.raw_extensions = Some(filtered_extensions);
let result = asn1::write_single(&tbs_precert)?;
Ok(pyo3::types::PyBytes::new(py, &result))
}
Ok(None) => Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err(
"Could not find any extensions in TBS certificate",
),
)),
Err(oid) => {
let oid_obj = oid_to_py_oid(py, &oid)?;
Err(exceptions::DuplicateExtension::new_err((

View file

@ -410,8 +410,8 @@ pub(crate) fn parse_and_cache_extensions<
let x509_module = py.import(pyo3::intern!(py, "cryptography.x509"))?;
let exts = pyo3::types::PyList::empty(py);
if let Some(extensions) = extensions {
for raw_ext in extensions.as_raw().unwrap_read().clone() {
if let Some(extensions) = extensions.as_raw() {
for raw_ext in extensions.unwrap_read().clone() {
let oid_obj = oid_to_py_oid(py, &raw_ext.extn_id)?;
let extn_value = match parse_ext(&raw_ext.extn_id, raw_ext.extn_value)? {