Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

part 2 of typing x509 extensions #5815

Merged
merged 1 commit into from
Feb 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/cryptography/hazmat/backends/openssl/decode_asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ def _asn1_string_to_ascii(backend, asn1_string):
return _asn1_string_to_bytes(backend, asn1_string).decode("ascii")


def _asn1_string_to_utf8(backend, asn1_string):
def _asn1_string_to_utf8(backend, asn1_string) -> str:
buf = backend._ffi.new("unsigned char **")
res = backend._lib.ASN1_STRING_to_UTF8(buf, asn1_string)
if res == -1:
Expand Down
97 changes: 67 additions & 30 deletions src/cryptography/x509/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
OBJECT_IDENTIFIER,
SEQUENCE,
)
from cryptography.hazmat._types import _PUBLIC_KEY_TYPES
from cryptography.hazmat.primitives import constant_time, serialization
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
Expand All @@ -33,7 +34,7 @@
)


def _key_identifier_from_public_key(public_key):
def _key_identifier_from_public_key(public_key: _PUBLIC_KEY_TYPES) -> bytes:
if isinstance(public_key, RSAPublicKey):
data = public_key.public_bytes(
serialization.Encoding.DER,
Expand All @@ -54,7 +55,7 @@ def _key_identifier_from_public_key(public_key):
reader = DERReader(serialized)
with reader.read_single_element(SEQUENCE) as public_key_info:
algorithm = public_key_info.read_element(SEQUENCE)
public_key = public_key_info.read_element(BIT_STRING)
public_key_data = public_key_info.read_element(BIT_STRING)

# Double-check the algorithm structure.
with algorithm:
Expand All @@ -65,10 +66,10 @@ def _key_identifier_from_public_key(public_key):

# BIT STRING contents begin with the number of padding bytes added. It
# must be zero for SubjectPublicKeyInfo structures.
if public_key.read_byte() != 0:
if public_key_data.read_byte() != 0:
raise ValueError("Invalid public key encoding")

data = public_key.data
data = public_key_data.data

return hashlib.sha1(data).digest()

Expand Down Expand Up @@ -110,14 +111,14 @@ class Extensions(object):
def __init__(self, extensions: typing.List["Extension"]):
self._extensions = extensions

def get_extension_for_oid(self, oid):
def get_extension_for_oid(self, oid: ObjectIdentifier) -> "Extension":
for ext in self:
if ext.oid == oid:
return ext

raise ExtensionNotFound("No {} extension was found".format(oid), oid)

def get_extension_for_class(self, extclass):
def get_extension_for_class(self, extclass) -> "Extension":
if extclass is UnrecognizedExtension:
raise TypeError(
"UnrecognizedExtension can't be used with "
Expand All @@ -142,7 +143,7 @@ def __repr__(self):
class CRLNumber(ExtensionType):
oid = ExtensionOID.CRL_NUMBER

def __init__(self, crl_number):
def __init__(self, crl_number: int):
if not isinstance(crl_number, int):
raise TypeError("crl_number must be an integer")

Expand Down Expand Up @@ -171,9 +172,9 @@ class AuthorityKeyIdentifier(ExtensionType):

def __init__(
self,
key_identifier,
authority_cert_issuer,
authority_cert_serial_number,
key_identifier: typing.Optional[bytes],
authority_cert_issuer: typing.Optional[typing.Iterable[GeneralName]],
authority_cert_serial_number: typing.Optional[int],
):
if (authority_cert_issuer is None) != (
authority_cert_serial_number is None
Expand Down Expand Up @@ -203,7 +204,9 @@ def __init__(
self._authority_cert_serial_number = authority_cert_serial_number

@classmethod
def from_issuer_public_key(cls, public_key):
def from_issuer_public_key(
cls, public_key: _PUBLIC_KEY_TYPES
) -> "AuthorityKeyIdentifier":
digest = _key_identifier_from_public_key(public_key)
return cls(
key_identifier=digest,
Expand All @@ -212,7 +215,9 @@ def from_issuer_public_key(cls, public_key):
)

@classmethod
def from_issuer_subject_key_identifier(cls, ski):
def from_issuer_subject_key_identifier(
cls, ski: "SubjectKeyIdentifier"
) -> "AuthorityKeyIdentifier":
return cls(
key_identifier=ski.digest,
authority_cert_issuer=None,
Expand Down Expand Up @@ -260,11 +265,13 @@ def __hash__(self):
class SubjectKeyIdentifier(ExtensionType):
oid = ExtensionOID.SUBJECT_KEY_IDENTIFIER

def __init__(self, digest):
def __init__(self, digest: bytes):
self._digest = digest

@classmethod
def from_public_key(cls, public_key):
def from_public_key(
cls, public_key: _PUBLIC_KEY_TYPES
) -> "SubjectKeyIdentifier":
return cls(_key_identifier_from_public_key(public_key))

digest = utils.read_only_property("_digest")
Expand All @@ -288,7 +295,7 @@ def __hash__(self):
class AuthorityInformationAccess(ExtensionType):
oid = ExtensionOID.AUTHORITY_INFORMATION_ACCESS

def __init__(self, descriptions):
def __init__(self, descriptions: typing.Iterable["AccessDescription"]):
descriptions = list(descriptions)
if not all(isinstance(x, AccessDescription) for x in descriptions):
raise TypeError(
Expand Down Expand Up @@ -319,7 +326,7 @@ def __hash__(self):
class SubjectInformationAccess(ExtensionType):
oid = ExtensionOID.SUBJECT_INFORMATION_ACCESS

def __init__(self, descriptions):
def __init__(self, descriptions: typing.Iterable["AccessDescription"]):
descriptions = list(descriptions)
if not all(isinstance(x, AccessDescription) for x in descriptions):
raise TypeError(
Expand Down Expand Up @@ -348,7 +355,9 @@ def __hash__(self):


class AccessDescription(object):
def __init__(self, access_method, access_location):
def __init__(
self, access_method: ObjectIdentifier, access_location: GeneralName
):
if not isinstance(access_method, ObjectIdentifier):
raise TypeError("access_method must be an ObjectIdentifier")

Expand Down Expand Up @@ -386,7 +395,7 @@ def __hash__(self):
class BasicConstraints(ExtensionType):
oid = ExtensionOID.BASIC_CONSTRAINTS

def __init__(self, ca, path_length):
def __init__(self, ca: bool, path_length: typing.Optional[int]):
if not isinstance(ca, bool):
raise TypeError("ca must be a boolean value")

Expand Down Expand Up @@ -427,7 +436,7 @@ def __hash__(self):
class DeltaCRLIndicator(ExtensionType):
oid = ExtensionOID.DELTA_CRL_INDICATOR

def __init__(self, crl_number):
def __init__(self, crl_number: int):
if not isinstance(crl_number, int):
raise TypeError("crl_number must be an integer")

Expand All @@ -454,7 +463,9 @@ def __repr__(self):
class CRLDistributionPoints(ExtensionType):
oid = ExtensionOID.CRL_DISTRIBUTION_POINTS

def __init__(self, distribution_points):
def __init__(
self, distribution_points: typing.Iterable["DistributionPoint"]
):
distribution_points = list(distribution_points)
if not all(
isinstance(x, DistributionPoint) for x in distribution_points
Expand Down Expand Up @@ -489,7 +500,9 @@ def __hash__(self):
class FreshestCRL(ExtensionType):
oid = ExtensionOID.FRESHEST_CRL

def __init__(self, distribution_points):
def __init__(
self, distribution_points: typing.Iterable["DistributionPoint"]
):
distribution_points = list(distribution_points)
if not all(
isinstance(x, DistributionPoint) for x in distribution_points
Expand Down Expand Up @@ -522,7 +535,13 @@ def __hash__(self):


class DistributionPoint(object):
def __init__(self, full_name, relative_name, reasons, crl_issuer):
def __init__(
self,
full_name: typing.Optional[typing.Iterable[GeneralName]],
relative_name: typing.Optional[RelativeDistinguishedName],
reasons: typing.Optional[typing.FrozenSet["ReasonFlags"]],
crl_issuer: typing.Optional[typing.Iterable[GeneralName]],
):
if full_name and relative_name:
raise ValueError(
"You cannot provide both full_name and relative_name, at "
Expand Down Expand Up @@ -631,7 +650,11 @@ class ReasonFlags(Enum):
class PolicyConstraints(ExtensionType):
oid = ExtensionOID.POLICY_CONSTRAINTS

def __init__(self, require_explicit_policy, inhibit_policy_mapping):
def __init__(
self,
require_explicit_policy: typing.Optional[int],
inhibit_policy_mapping: typing.Optional[int],
):
if require_explicit_policy is not None and not isinstance(
require_explicit_policy, int
):
Expand Down Expand Up @@ -691,7 +714,7 @@ def __hash__(self):
class CertificatePolicies(ExtensionType):
oid = ExtensionOID.CERTIFICATE_POLICIES

def __init__(self, policies):
def __init__(self, policies: typing.Iterable["PolicyInformation"]):
policies = list(policies)
if not all(isinstance(x, PolicyInformation) for x in policies):
raise TypeError(
Expand Down Expand Up @@ -720,7 +743,13 @@ def __hash__(self):


class PolicyInformation(object):
def __init__(self, policy_identifier, policy_qualifiers):
def __init__(
self,
policy_identifier: ObjectIdentifier,
policy_qualifiers: typing.Optional[
typing.Iterable[typing.Union[str, "UserNotice"]]
],
):
if not isinstance(policy_identifier, ObjectIdentifier):
raise TypeError("policy_identifier must be an ObjectIdentifier")

Expand Down Expand Up @@ -769,7 +798,11 @@ def __hash__(self):


class UserNotice(object):
def __init__(self, notice_reference, explicit_text):
def __init__(
self,
notice_reference: typing.Optional["NoticeReference"],
explicit_text: typing.Optional[str],
):
if notice_reference and not isinstance(
notice_reference, NoticeReference
):
Expand Down Expand Up @@ -806,7 +839,11 @@ def __hash__(self):


class NoticeReference(object):
def __init__(self, organization, notice_numbers):
def __init__(
self,
organization: typing.Optional[str],
notice_numbers: typing.Iterable[int],
):
self._organization = organization
notice_numbers = list(notice_numbers)
if not all(isinstance(x, int) for x in notice_numbers):
Expand Down Expand Up @@ -842,7 +879,7 @@ def __hash__(self):
class ExtendedKeyUsage(ExtensionType):
oid = ExtensionOID.EXTENDED_KEY_USAGE

def __init__(self, usages):
def __init__(self, usages: typing.Iterable[ObjectIdentifier]):
usages = list(usages)
if not all(isinstance(x, ObjectIdentifier) for x in usages):
raise TypeError(
Expand Down Expand Up @@ -910,7 +947,7 @@ def __repr__(self):
class TLSFeature(ExtensionType):
oid = ExtensionOID.TLS_FEATURE

def __init__(self, features):
def __init__(self, features: typing.Iterable["TLSFeatureType"]):
features = list(features)
if (
not all(isinstance(x, TLSFeatureType) for x in features)
Expand Down Expand Up @@ -958,7 +995,7 @@ class TLSFeatureType(Enum):
class InhibitAnyPolicy(ExtensionType):
oid = ExtensionOID.INHIBIT_ANY_POLICY

def __init__(self, skip_certs):
def __init__(self, skip_certs: int):
if not isinstance(skip_certs, int):
raise TypeError("skip_certs must be an integer")

Expand Down
Loading