Added typing for a bunch of random stuff (#5743)

This commit is contained in:
Alex Gaynor 2021-02-04 18:43:41 -05:00 committed by GitHub
parent fb3c73a0b1
commit 0b41cb2b61
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 65 additions and 34 deletions

View file

@ -14,7 +14,7 @@ from cryptography.hazmat.primitives import serialization
_MIN_MODULUS_SIZE = 512
def generate_parameters(generator, key_size, backend=None):
def generate_parameters(generator, key_size, backend=None) -> "DHParameters":
backend = _get_backend(backend)
return backend.generate_dh_parameters(generator, key_size)
@ -83,7 +83,7 @@ class DHPublicNumbers(object):
def __ne__(self, other):
return not self == other
def public_key(self, backend=None):
def public_key(self, backend=None) -> "DHPublicKey":
backend = _get_backend(backend)
return backend.load_dh_public_numbers(self)
@ -136,7 +136,7 @@ class DHParameters(metaclass=abc.ABCMeta):
self,
encoding: "serialization.Encoding",
format: "serialization.ParameterFormat",
):
) -> bytes:
"""
Returns the parameters serialized as bytes.
"""
@ -222,7 +222,7 @@ class DHPrivateKey(metaclass=abc.ABCMeta):
encoding: "serialization.Encoding",
format: "serialization.PrivateFormat",
encryption_algorithm: "serialization.KeySerializationEncryption",
):
) -> bytes:
"""
Returns the key serialized as bytes.
"""

View file

@ -509,7 +509,7 @@ _OID_TO_CURVE = {
}
def get_curve_for_oid(oid):
def get_curve_for_oid(oid: ObjectIdentifier) -> typing.Type[EllipticCurve]:
try:
return _OID_TO_CURVE[oid]
except KeyError:

View file

@ -37,7 +37,7 @@ class Ed25519PublicKey(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def verify(self, signature: bytes, data: bytes):
def verify(self, signature: bytes, data: bytes) -> None:
"""
Verify the signature.
"""

View file

@ -4,6 +4,7 @@
import abc
import typing
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
@ -11,7 +12,7 @@ from cryptography.hazmat.primitives.asymmetric import rsa
class AsymmetricPadding(metaclass=abc.ABCMeta):
@abc.abstractproperty
def name(self):
def name(self) -> str:
"""
A string naming this padding (e.g. "PSS", "PKCS1").
"""
@ -43,7 +44,12 @@ class PSS(AsymmetricPadding):
class OAEP(AsymmetricPadding):
name = "EME-OAEP"
def __init__(self, mgf, algorithm, label):
def __init__(
self,
mgf: "MGF1",
algorithm: hashes.HashAlgorithm,
label: typing.Optional[bytes],
):
if not isinstance(algorithm, hashes.HashAlgorithm):
raise TypeError("Expected instance of hashes.HashAlgorithm.")
@ -55,14 +61,17 @@ class OAEP(AsymmetricPadding):
class MGF1(object):
MAX_LENGTH = object()
def __init__(self, algorithm):
def __init__(self, algorithm: hashes.HashAlgorithm):
if not isinstance(algorithm, hashes.HashAlgorithm):
raise TypeError("Expected instance of hashes.HashAlgorithm.")
self._algorithm = algorithm
def calculate_max_pss_salt_length(key, hash_algorithm):
def calculate_max_pss_salt_length(
key: typing.Union["rsa.RSAPrivateKey", "rsa.RSAPublicKey"],
hash_algorithm: hashes.HashAlgorithm,
) -> int:
if not isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
raise TypeError("key must be an RSA public or private key")
# bit length - 1 per RFC 3447

View file

@ -160,7 +160,7 @@ def generate_private_key(
return backend.generate_rsa_private_key(public_exponent, key_size)
def _verify_rsa_parameters(public_exponent: int, key_size: int):
def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None:
if public_exponent not in (3, 65537):
raise ValueError(
"public_exponent must be either 3 (for legacy compatibility) or "
@ -180,7 +180,7 @@ def _check_private_key_components(
iqmp: int,
public_exponent: int,
modulus: int,
):
) -> None:
if modulus < 3:
raise ValueError("modulus must be >= 3.")
@ -218,7 +218,7 @@ def _check_private_key_components(
raise ValueError("p*q must equal modulus.")
def _check_public_key_components(e: int, n: int):
def _check_public_key_components(e: int, n: int) -> None:
if n < 3:
raise ValueError("n must be >= 3.")
@ -229,7 +229,7 @@ def _check_public_key_components(e: int, n: int):
raise ValueError("e must be odd.")
def _modinv(e: int, m: int):
def _modinv(e: int, m: int) -> int:
"""
Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1
"""
@ -242,14 +242,14 @@ def _modinv(e: int, m: int):
return x1 % m
def rsa_crt_iqmp(p: int, q: int):
def rsa_crt_iqmp(p: int, q: int) -> int:
"""
Compute the CRT (q ** -1) % p value from RSA primes p and q.
"""
return _modinv(q, p)
def rsa_crt_dmp1(private_exponent: int, p: int):
def rsa_crt_dmp1(private_exponent: int, p: int) -> int:
"""
Compute the CRT private_exponent % (p - 1) value from the RSA
private_exponent (d) and p.
@ -257,7 +257,7 @@ def rsa_crt_dmp1(private_exponent: int, p: int):
return private_exponent % (p - 1)
def rsa_crt_dmq1(private_exponent: int, q: int):
def rsa_crt_dmq1(private_exponent: int, q: int) -> int:
"""
Compute the CRT private_exponent % (q - 1) value from the RSA
private_exponent (d) and q.
@ -271,7 +271,9 @@ def rsa_crt_dmq1(private_exponent: int, q: int):
_MAX_RECOVERY_ATTEMPTS = 1000
def rsa_recover_prime_factors(n: int, e: int, d: int):
def rsa_recover_prime_factors(
n: int, e: int, d: int
) -> typing.Tuple[int, int]:
"""
Compute factors p and q from the private exponent d. We assume that n has
no more than two factors. This function is adapted from code in PyCrypto.

View file

@ -139,7 +139,7 @@ class ConcatKDFHMAC(KeyDerivationFunction):
def _hmac(self) -> hmac.HMAC:
return hmac.HMAC(self._salt, self._algorithm, self._backend)
def derive(self, key_material: bytes):
def derive(self, key_material: bytes) -> bytes:
if self._used:
raise AlreadyFinalized
self._used = True

View file

@ -10,13 +10,6 @@ from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa
def load_key_and_certificates(
data: bytes, password: typing.Optional[bytes], backend=None
):
backend = _get_backend(backend)
return backend.load_key_and_certificates_from_pkcs12(data, password)
_ALLOWED_PKCS12_TYPES = typing.Union[
rsa.RSAPrivateKey,
dsa.DSAPrivateKey,
@ -24,13 +17,24 @@ _ALLOWED_PKCS12_TYPES = typing.Union[
]
def load_key_and_certificates(
data: bytes, password: typing.Optional[bytes], backend=None
) -> typing.Tuple[
typing.Optional[_ALLOWED_PKCS12_TYPES],
typing.Optional[x509.Certificate],
typing.List[x509.Certificate],
]:
backend = _get_backend(backend)
return backend.load_key_and_certificates_from_pkcs12(data, password)
def serialize_key_and_certificates(
name: typing.Optional[bytes],
key: typing.Optional[_ALLOWED_PKCS12_TYPES],
cert: typing.Optional[x509.Certificate],
cas: typing.Optional[typing.Iterable[x509.Certificate]],
encryption_algorithm: serialization.KeySerializationEncryption,
):
) -> bytes:
if key is not None and not isinstance(
key,
(

View file

@ -407,7 +407,9 @@ class TestOpenSSLRSA(object):
assert (
backend.rsa_padding_supported(
padding.OAEP(
mgf=DummyMGF(), algorithm=hashes.SHA1(), label=None
mgf=DummyMGF(), # type: ignore[arg-type]
algorithm=hashes.SHA1(),
label=None,
),
)
is False

View file

@ -186,6 +186,7 @@ class TestPKCS12Creation(object):
p12, password, backend
)
assert parsed_cert == cert
assert parsed_key is not None
assert parsed_key.private_numbers() == key.private_numbers()
assert parsed_more_certs == []
@ -204,6 +205,7 @@ class TestPKCS12Creation(object):
p12, None, backend
)
assert parsed_cert == cert
assert parsed_key is not None
assert parsed_key.private_numbers() == key.private_numbers()
assert parsed_more_certs == [cert2, cert3]
@ -247,6 +249,7 @@ class TestPKCS12Creation(object):
p12, None, backend
)
assert parsed_cert is None
assert parsed_key is not None
assert parsed_key.private_numbers() == key.private_numbers()
assert parsed_more_certs == []

View file

@ -609,7 +609,8 @@ class TestRSASignature(object):
private_key.sign(
b"msg",
padding.PSS(
mgf=DummyMGF(), salt_length=padding.PSS.MAX_LENGTH
mgf=DummyMGF(), # type: ignore[arg-type]
salt_length=padding.PSS.MAX_LENGTH,
),
hashes.SHA1(),
)
@ -1455,7 +1456,9 @@ class TestRSAPKCS1Verification(object):
class TestPSS(object):
def test_calculate_max_pss_salt_length(self):
with pytest.raises(TypeError):
padding.calculate_max_pss_salt_length(object(), hashes.SHA256())
padding.calculate_max_pss_salt_length(
object(), hashes.SHA256() # type:ignore[arg-type]
)
def test_invalid_salt_length_not_integer(self):
with pytest.raises(TypeError):
@ -1486,7 +1489,7 @@ class TestPSS(object):
class TestMGF1(object):
def test_invalid_hash_algorithm(self):
with pytest.raises(TypeError):
padding.MGF1(b"not_a_hash")
padding.MGF1(b"not_a_hash") # type:ignore[arg-type]
def test_valid_mgf1_parameters(self):
algorithm = hashes.SHA1()
@ -1498,7 +1501,9 @@ class TestOAEP(object):
def test_invalid_algorithm(self):
mgf = padding.MGF1(hashes.SHA1())
with pytest.raises(TypeError):
padding.OAEP(mgf=mgf, algorithm=b"", label=None)
padding.OAEP(
mgf=mgf, algorithm=b"", label=None # type:ignore[arg-type]
)
@pytest.mark.requires_backend_interface(interface=RSABackend)
@ -1738,7 +1743,9 @@ class TestRSADecryption(object):
private_key.decrypt(
b"0" * 64,
padding.OAEP(
mgf=DummyMGF(), algorithm=hashes.SHA1(), label=None
mgf=DummyMGF(), # type: ignore[arg-type]
algorithm=hashes.SHA1(),
label=None,
),
)
@ -1924,7 +1931,9 @@ class TestRSAEncryption(object):
public_key.encrypt(
b"ciphertext",
padding.OAEP(
mgf=DummyMGF(), algorithm=hashes.SHA1(), label=None
mgf=DummyMGF(), # type: ignore[arg-type]
algorithm=hashes.SHA1(),
label=None,
),
)

View file

@ -194,6 +194,8 @@ def test_rsa_oaep_encryption(backend, wycheproof):
assert isinstance(key, rsa.RSAPrivateKey)
digest = _DIGESTS[wycheproof.testgroup["sha"]]
mgf_digest = _DIGESTS[wycheproof.testgroup["mgfSha"]]
assert digest is not None
assert mgf_digest is not None
padding_algo = padding.OAEP(
mgf=padding.MGF1(algorithm=mgf_digest),