diff --git a/src/cryptography/hazmat/backends/openssl/decode_asn1.py b/src/cryptography/hazmat/backends/openssl/decode_asn1.py index 96ba4cdbc42c..167acc078743 100644 --- a/src/cryptography/hazmat/backends/openssl/decode_asn1.py +++ b/src/cryptography/hazmat/backends/openssl/decode_asn1.py @@ -768,7 +768,7 @@ def _asn1_string_to_ascii(backend, asn1_string): return _asn1_string_to_bytes(backend, asn1_string).decode("ascii") -def _asn1_string_to_utf8(backend, asn1_string): +def _asn1_string_to_utf8(backend, asn1_string) -> str: buf = backend._ffi.new("unsigned char **") res = backend._lib.ASN1_STRING_to_UTF8(buf, asn1_string) if res == -1: diff --git a/src/cryptography/x509/extensions.py b/src/cryptography/x509/extensions.py index 2f8612277d8f..6cae016a1c60 100644 --- a/src/cryptography/x509/extensions.py +++ b/src/cryptography/x509/extensions.py @@ -17,6 +17,7 @@ OBJECT_IDENTIFIER, SEQUENCE, ) +from cryptography.hazmat._types import _PUBLIC_KEY_TYPES from cryptography.hazmat.primitives import constant_time, serialization from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey @@ -33,7 +34,7 @@ ) -def _key_identifier_from_public_key(public_key): +def _key_identifier_from_public_key(public_key: _PUBLIC_KEY_TYPES) -> bytes: if isinstance(public_key, RSAPublicKey): data = public_key.public_bytes( serialization.Encoding.DER, @@ -54,7 +55,7 @@ def _key_identifier_from_public_key(public_key): reader = DERReader(serialized) with reader.read_single_element(SEQUENCE) as public_key_info: algorithm = public_key_info.read_element(SEQUENCE) - public_key = public_key_info.read_element(BIT_STRING) + public_key_data = public_key_info.read_element(BIT_STRING) # Double-check the algorithm structure. with algorithm: @@ -65,10 +66,10 @@ def _key_identifier_from_public_key(public_key): # BIT STRING contents begin with the number of padding bytes added. It # must be zero for SubjectPublicKeyInfo structures. - if public_key.read_byte() != 0: + if public_key_data.read_byte() != 0: raise ValueError("Invalid public key encoding") - data = public_key.data + data = public_key_data.data return hashlib.sha1(data).digest() @@ -110,14 +111,14 @@ class Extensions(object): def __init__(self, extensions: typing.List["Extension"]): self._extensions = extensions - def get_extension_for_oid(self, oid): + def get_extension_for_oid(self, oid: ObjectIdentifier) -> "Extension": for ext in self: if ext.oid == oid: return ext raise ExtensionNotFound("No {} extension was found".format(oid), oid) - def get_extension_for_class(self, extclass): + def get_extension_for_class(self, extclass) -> "Extension": if extclass is UnrecognizedExtension: raise TypeError( "UnrecognizedExtension can't be used with " @@ -142,7 +143,7 @@ def __repr__(self): class CRLNumber(ExtensionType): oid = ExtensionOID.CRL_NUMBER - def __init__(self, crl_number): + def __init__(self, crl_number: int): if not isinstance(crl_number, int): raise TypeError("crl_number must be an integer") @@ -171,9 +172,9 @@ class AuthorityKeyIdentifier(ExtensionType): def __init__( self, - key_identifier, - authority_cert_issuer, - authority_cert_serial_number, + key_identifier: typing.Optional[bytes], + authority_cert_issuer: typing.Optional[typing.Iterable[GeneralName]], + authority_cert_serial_number: typing.Optional[int], ): if (authority_cert_issuer is None) != ( authority_cert_serial_number is None @@ -203,7 +204,9 @@ def __init__( self._authority_cert_serial_number = authority_cert_serial_number @classmethod - def from_issuer_public_key(cls, public_key): + def from_issuer_public_key( + cls, public_key: _PUBLIC_KEY_TYPES + ) -> "AuthorityKeyIdentifier": digest = _key_identifier_from_public_key(public_key) return cls( key_identifier=digest, @@ -212,7 +215,9 @@ def from_issuer_public_key(cls, public_key): ) @classmethod - def from_issuer_subject_key_identifier(cls, ski): + def from_issuer_subject_key_identifier( + cls, ski: "SubjectKeyIdentifier" + ) -> "AuthorityKeyIdentifier": return cls( key_identifier=ski.digest, authority_cert_issuer=None, @@ -260,11 +265,13 @@ def __hash__(self): class SubjectKeyIdentifier(ExtensionType): oid = ExtensionOID.SUBJECT_KEY_IDENTIFIER - def __init__(self, digest): + def __init__(self, digest: bytes): self._digest = digest @classmethod - def from_public_key(cls, public_key): + def from_public_key( + cls, public_key: _PUBLIC_KEY_TYPES + ) -> "SubjectKeyIdentifier": return cls(_key_identifier_from_public_key(public_key)) digest = utils.read_only_property("_digest") @@ -288,7 +295,7 @@ def __hash__(self): class AuthorityInformationAccess(ExtensionType): oid = ExtensionOID.AUTHORITY_INFORMATION_ACCESS - def __init__(self, descriptions): + def __init__(self, descriptions: typing.Iterable["AccessDescription"]): descriptions = list(descriptions) if not all(isinstance(x, AccessDescription) for x in descriptions): raise TypeError( @@ -319,7 +326,7 @@ def __hash__(self): class SubjectInformationAccess(ExtensionType): oid = ExtensionOID.SUBJECT_INFORMATION_ACCESS - def __init__(self, descriptions): + def __init__(self, descriptions: typing.Iterable["AccessDescription"]): descriptions = list(descriptions) if not all(isinstance(x, AccessDescription) for x in descriptions): raise TypeError( @@ -348,7 +355,9 @@ def __hash__(self): class AccessDescription(object): - def __init__(self, access_method, access_location): + def __init__( + self, access_method: ObjectIdentifier, access_location: GeneralName + ): if not isinstance(access_method, ObjectIdentifier): raise TypeError("access_method must be an ObjectIdentifier") @@ -386,7 +395,7 @@ def __hash__(self): class BasicConstraints(ExtensionType): oid = ExtensionOID.BASIC_CONSTRAINTS - def __init__(self, ca, path_length): + def __init__(self, ca: bool, path_length: typing.Optional[int]): if not isinstance(ca, bool): raise TypeError("ca must be a boolean value") @@ -427,7 +436,7 @@ def __hash__(self): class DeltaCRLIndicator(ExtensionType): oid = ExtensionOID.DELTA_CRL_INDICATOR - def __init__(self, crl_number): + def __init__(self, crl_number: int): if not isinstance(crl_number, int): raise TypeError("crl_number must be an integer") @@ -454,7 +463,9 @@ def __repr__(self): class CRLDistributionPoints(ExtensionType): oid = ExtensionOID.CRL_DISTRIBUTION_POINTS - def __init__(self, distribution_points): + def __init__( + self, distribution_points: typing.Iterable["DistributionPoint"] + ): distribution_points = list(distribution_points) if not all( isinstance(x, DistributionPoint) for x in distribution_points @@ -489,7 +500,9 @@ def __hash__(self): class FreshestCRL(ExtensionType): oid = ExtensionOID.FRESHEST_CRL - def __init__(self, distribution_points): + def __init__( + self, distribution_points: typing.Iterable["DistributionPoint"] + ): distribution_points = list(distribution_points) if not all( isinstance(x, DistributionPoint) for x in distribution_points @@ -522,7 +535,13 @@ def __hash__(self): class DistributionPoint(object): - def __init__(self, full_name, relative_name, reasons, crl_issuer): + def __init__( + self, + full_name: typing.Optional[typing.Iterable[GeneralName]], + relative_name: typing.Optional[RelativeDistinguishedName], + reasons: typing.Optional[typing.FrozenSet["ReasonFlags"]], + crl_issuer: typing.Optional[typing.Iterable[GeneralName]], + ): if full_name and relative_name: raise ValueError( "You cannot provide both full_name and relative_name, at " @@ -631,7 +650,11 @@ class ReasonFlags(Enum): class PolicyConstraints(ExtensionType): oid = ExtensionOID.POLICY_CONSTRAINTS - def __init__(self, require_explicit_policy, inhibit_policy_mapping): + def __init__( + self, + require_explicit_policy: typing.Optional[int], + inhibit_policy_mapping: typing.Optional[int], + ): if require_explicit_policy is not None and not isinstance( require_explicit_policy, int ): @@ -691,7 +714,7 @@ def __hash__(self): class CertificatePolicies(ExtensionType): oid = ExtensionOID.CERTIFICATE_POLICIES - def __init__(self, policies): + def __init__(self, policies: typing.Iterable["PolicyInformation"]): policies = list(policies) if not all(isinstance(x, PolicyInformation) for x in policies): raise TypeError( @@ -720,7 +743,13 @@ def __hash__(self): class PolicyInformation(object): - def __init__(self, policy_identifier, policy_qualifiers): + def __init__( + self, + policy_identifier: ObjectIdentifier, + policy_qualifiers: typing.Optional[ + typing.Iterable[typing.Union[str, "UserNotice"]] + ], + ): if not isinstance(policy_identifier, ObjectIdentifier): raise TypeError("policy_identifier must be an ObjectIdentifier") @@ -769,7 +798,11 @@ def __hash__(self): class UserNotice(object): - def __init__(self, notice_reference, explicit_text): + def __init__( + self, + notice_reference: typing.Optional["NoticeReference"], + explicit_text: typing.Optional[str], + ): if notice_reference and not isinstance( notice_reference, NoticeReference ): @@ -806,7 +839,11 @@ def __hash__(self): class NoticeReference(object): - def __init__(self, organization, notice_numbers): + def __init__( + self, + organization: typing.Optional[str], + notice_numbers: typing.Iterable[int], + ): self._organization = organization notice_numbers = list(notice_numbers) if not all(isinstance(x, int) for x in notice_numbers): @@ -842,7 +879,7 @@ def __hash__(self): class ExtendedKeyUsage(ExtensionType): oid = ExtensionOID.EXTENDED_KEY_USAGE - def __init__(self, usages): + def __init__(self, usages: typing.Iterable[ObjectIdentifier]): usages = list(usages) if not all(isinstance(x, ObjectIdentifier) for x in usages): raise TypeError( @@ -910,7 +947,7 @@ def __repr__(self): class TLSFeature(ExtensionType): oid = ExtensionOID.TLS_FEATURE - def __init__(self, features): + def __init__(self, features: typing.Iterable["TLSFeatureType"]): features = list(features) if ( not all(isinstance(x, TLSFeatureType) for x in features) @@ -958,7 +995,7 @@ class TLSFeatureType(Enum): class InhibitAnyPolicy(ExtensionType): oid = ExtensionOID.INHIBIT_ANY_POLICY - def __init__(self, skip_certs): + def __init__(self, skip_certs: int): if not isinstance(skip_certs, int): raise TypeError("skip_certs must be an integer") diff --git a/tests/x509/test_x509_ext.py b/tests/x509/test_x509_ext.py index 938357f2d514..b8f226d5f848 100644 --- a/tests/x509/test_x509_ext.py +++ b/tests/x509/test_x509_ext.py @@ -7,6 +7,7 @@ import datetime import ipaddress import os +import typing import pretend @@ -138,7 +139,7 @@ def test_hash(self): class TestTLSFeature(object): def test_not_enum_type(self): with pytest.raises(TypeError): - x509.TLSFeature([3]) + x509.TLSFeature([3]) # type:ignore[list-item] def test_empty_list(self): with pytest.raises(TypeError): @@ -346,7 +347,7 @@ def test_repr(self): class TestDeltaCRLIndicator(object): def test_not_int(self): with pytest.raises(TypeError): - x509.DeltaCRLIndicator("notanint") + x509.DeltaCRLIndicator("notanint") # type:ignore[arg-type] def test_eq(self): delta1 = x509.DeltaCRLIndicator(1) @@ -404,11 +405,13 @@ def test_hash(self): class TestNoticeReference(object): def test_notice_numbers_not_all_int(self): with pytest.raises(TypeError): - x509.NoticeReference("org", [1, 2, "three"]) + x509.NoticeReference( + "org", [1, 2, "three"] # type:ignore[list-item] + ) def test_notice_numbers_none(self): with pytest.raises(TypeError): - x509.NoticeReference("org", None) + x509.NoticeReference("org", None) # type:ignore[arg-type] def test_iter_input(self): numbers = [1, 3, 4] @@ -447,7 +450,7 @@ def test_hash(self): class TestUserNotice(object): def test_notice_reference_invalid(self): with pytest.raises(TypeError): - x509.UserNotice("invalid", None) + x509.UserNotice("invalid", None) # type:ignore[arg-type] def test_notice_reference_none(self): un = x509.UserNotice(None, "text") @@ -491,7 +494,7 @@ def test_hash(self): class TestPolicyInformation(object): def test_invalid_policy_identifier(self): with pytest.raises(TypeError): - x509.PolicyInformation("notanoid", None) + x509.PolicyInformation("notanoid", None) # type:ignore[arg-type] def test_none_policy_qualifiers(self): pi = x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), None) @@ -506,7 +509,10 @@ def test_policy_qualifiers(self): def test_invalid_policy_identifiers(self): with pytest.raises(TypeError): - x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), [1, 2]) + x509.PolicyInformation( + x509.ObjectIdentifier("1.2.3"), + [1, 2], # type:ignore[list-item] + ) def test_iter_input(self): qual = ["foo", "bar"] @@ -514,7 +520,10 @@ def test_iter_input(self): assert list(pi.policy_qualifiers) == qual def test_repr(self): - pq = ["string", x509.UserNotice(None, "hi")] + pq: typing.List[typing.Union[str, x509.UserNotice]] = [ + "string", + x509.UserNotice(None, "hi"), + ] pi = x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), pq) assert repr(pi) == ( "