From 87733e0cf0c0d161c6d07803c4b462f474a994f3 Mon Sep 17 00:00:00 2001 From: Mathias Ertl Date: Sat, 28 Dec 2024 20:28:30 +0100 Subject: [PATCH] fix typents and missing docstrings --- ca/ca/settings.py | 2 +- ca/django_ca/api/auth.py | 3 ++- ca/django_ca/api/endpoints.py | 5 ++--- ca/django_ca/managers.py | 4 +++- ca/django_ca/models.py | 5 +++-- ca/django_ca/querysets.py | 9 ++++++++- ca/django_ca/tests/test_utils.py | 6 +++--- ca/django_ca/views.py | 29 ++++++++++++++--------------- 8 files changed, 36 insertions(+), 27 deletions(-) diff --git a/ca/ca/settings.py b/ca/ca/settings.py index c69525046..4066912eb 100644 --- a/ca/ca/settings.py +++ b/ca/ca/settings.py @@ -300,7 +300,7 @@ }, "django.request": { "handlers": ["mail_admins"], - "level": "ERROR", + "level": "DEBUG", "propagate": True, }, }, diff --git a/ca/django_ca/api/auth.py b/ca/django_ca/api/auth.py index e6ff5a809..33b399219 100644 --- a/ca/django_ca/api/auth.py +++ b/ca/django_ca/api/auth.py @@ -44,7 +44,8 @@ def __init__(self, permission: str) -> None: # TODO: async implement call against warnings? - async def authenticate( + # PYLINT NOTE: documented in django-ninja docs that this can be async + async def authenticate( # pylint: disable=invalid-overridden-method self, request: HttpRequest, username: str, password: str ) -> Union[Literal[False], AbstractUser]: user = await User.objects.aget(username=username) diff --git a/ca/django_ca/api/endpoints.py b/ca/django_ca/api/endpoints.py index cfd5d7ea4..a4e34bf49 100644 --- a/ca/django_ca/api/endpoints.py +++ b/ca/django_ca/api/endpoints.py @@ -41,7 +41,6 @@ from django_ca.api.utils import get_certificate_authority from django_ca.models import Certificate, CertificateAuthority, CertificateOrder from django_ca.pydantic.messages import SignCertificateMessage -from django_ca.querysets import CertificateAuthorityQuerySet, CertificateQuerySet from django_ca.tasks import api_sign_certificate as sign_certificate_task, run_task api = NinjaAPI(title="django-ca API", version=__version__, urls_namespace="django_ca:api") @@ -63,7 +62,7 @@ def forbidden(request: WSGIRequest, exc: Exception) -> HttpResponse: # pylint: async def list_certificate_authorities( request: WSGIRequest, filters: CertificateAuthorityFilterSchema = Query(...), # type: ignore[type-arg] # noqa: B008 -) -> CertificateAuthorityQuerySet: +) -> list[CertificateAuthority]: """Retrieve a list of currently usable certificate authorities.""" qs = CertificateAuthority.objects.enabled().exclude(api_enabled=False) if filters.expired is False: @@ -187,7 +186,7 @@ async def list_certificates( request: WSGIRequest, serial: str, filters: CertificateFilterSchema = Query(...), # type: ignore[type-arg] # noqa: B008 -) -> CertificateQuerySet: +) -> list[Certificate]: """Retrieve certificates signed by the certificate authority named by `serial`.""" ca = await get_certificate_authority(serial, expired=True) # You can list certificates of expired CAs qs = Certificate.objects.filter(ca=ca) diff --git a/ca/django_ca/managers.py b/ca/django_ca/managers.py index 017e9837b..2be4e8db3 100644 --- a/ca/django_ca/managers.py +++ b/ca/django_ca/managers.py @@ -121,6 +121,8 @@ def for_certificate_revocation_list( def get_by_serial_or_cn(self, identifier: str) -> X509CertMixinTypeVar: ... + async def aget_by_serial_or_cn(self, identifier: str) -> X509CertMixinTypeVar: ... + def valid(self) -> QuerySetTypeVar: ... @@ -882,7 +884,7 @@ def create_certificate_revocation_list( ) # Create database object (as late as possible so any exception above would not hit the database) - obj: CertificateAuthority = self.create( + obj: CertificateRevocationList = self.create( ca=ca, number=Coalesce(models.Subquery(number_subquery, default=1), 0), only_contains_ca_certs=only_contains_ca_certs, diff --git a/ca/django_ca/models.py b/ca/django_ca/models.py index 246473c73..1b3d1e749 100644 --- a/ca/django_ca/models.py +++ b/ca/django_ca/models.py @@ -454,7 +454,7 @@ def get_revocation(self) -> x509.RevokedCertificate: @contextmanager def _revoke( self, reason: ReasonFlags = ReasonFlags.unspecified, compromised: Optional[datetime] = None - ) -> None: + ) -> Iterator[None]: pre_revoke_cert.send(sender=self.__class__, cert=self, reason=reason) self.revoked = True @@ -485,6 +485,7 @@ def revoke( async def arevoke( self, reason: ReasonFlags = ReasonFlags.unspecified, compromised: Optional[datetime] = None ) -> None: + """Asynchronous version of ``revoke()``.""" with self._revoke(reason, compromised): await self.asave() @@ -1257,7 +1258,7 @@ def _cache_data(self) -> Iterator[tuple[str, bytes, int]]: now = datetime.now(tz=tz.utc) if self.loaded.next_update_utc is not None: - expires_seconds = (self.loaded.next_update_utc - now).total_seconds() + expires_seconds = int((self.loaded.next_update_utc - now).total_seconds()) else: # pragma: no cover # we never generate CRLs without a next_update flag. expires_seconds = 86400 diff --git a/ca/django_ca/querysets.py b/ca/django_ca/querysets.py index 90d7516b2..e7a4ab8b7 100644 --- a/ca/django_ca/querysets.py +++ b/ca/django_ca/querysets.py @@ -85,17 +85,23 @@ def filter(self: X509CertMixinQuerySetProtocol) -> X509CertMixinQuerySetProtocol model: X509CertMixinTypeVar + async def aget(self, *args: Any, **kwargs: Any) -> X509CertMixinTypeVar: ... + def filter(self, *args: Any, **kwargs: Any) -> "Self": ... def get(self, *args: Any, **kwargs: Any) -> X509CertMixinTypeVar: ... + def _serial_or_cn_query(self, identifier: str) -> tuple[Q, Q]: ... + def revoked(self) -> "Self": ... class DjangoCAMixin(Generic[X509CertMixinTypeVar], metaclass=abc.ABCMeta): """Mixin with common methods for CertificateAuthority and Certificate models.""" - def _serial_or_cn_query(self: X509CertMixinQuerySetProtocol[X509CertMixinTypeVar], identifier: str): + def _serial_or_cn_query( + self: X509CertMixinQuerySetProtocol[X509CertMixinTypeVar], identifier: str + ) -> tuple[Q, Q]: identifier = identifier.strip() exact_query = startswith_query = Q(cn=identifier) @@ -130,6 +136,7 @@ def get_by_serial_or_cn( async def aget_by_serial_or_cn( self: X509CertMixinQuerySetProtocol[X509CertMixinTypeVar], identifier: str ) -> X509CertMixinTypeVar: + """Asynchronous version of ``get_by_serial_or_cn()``""" exact_query, startswith_query = self._serial_or_cn_query(identifier) try: diff --git a/ca/django_ca/tests/test_utils.py b/ca/django_ca/tests/test_utils.py index c4c104409..4a34f09ef 100644 --- a/ca/django_ca/tests/test_utils.py +++ b/ca/django_ca/tests/test_utils.py @@ -558,11 +558,11 @@ def test_wrong_values(self) -> None: for key_type in ("Ed448", "Ed25519"): with pytest.raises(ValueError, match=rf"^Key size is not supported for {key_type} keys\.$"): - validate_private_key_parameters(key_type, key_size, None) + validate_private_key_parameters(key_type, key_size, None) # type: ignore[call-overload] with pytest.raises( ValueError, match=rf"^Elliptic curves are not supported for {key_type} keys\.$" ): - validate_private_key_parameters(key_type, None, elliptic_curve) + validate_private_key_parameters(key_type, None, elliptic_curve) # type: ignore[call-overload] class ValidatePublicKeyParametersTest(TestCase): @@ -572,7 +572,7 @@ def test_valid_parameters(self) -> None: """Test valid parameters.""" for key_type in ("RSA", "DSA", "EC"): for algorithm in (hashes.SHA256(), hashes.SHA512()): - validate_public_key_parameters(key_type, algorithm) + validate_public_key_parameters(key_type, algorithm) # type: ignore[arg-type] for key_type in ("Ed448", "Ed25519"): validate_public_key_parameters(key_type, None) # type: ignore[arg-type] diff --git a/ca/django_ca/views.py b/ca/django_ca/views.py index d1c0b2e34..c2602fa39 100644 --- a/ca/django_ca/views.py +++ b/ca/django_ca/views.py @@ -176,7 +176,7 @@ async def fetch_crl( # CRL is not cached, try to retrieve it from the database. if encoded_crl is None: - crl_qs: Optional[CertificateRevocationList] = ( + crl_qs = ( CertificateRevocationList.objects.scope( ca=ca, only_contains_ca_certs=only_contains_ca_certs, @@ -187,7 +187,7 @@ async def fetch_crl( .filter(data__isnull=False) # only objects that have CRL data associated with it .select_related("ca") ) - crl_obj = await crl_qs.anewest() + crl_obj: Optional[CertificateRevocationList] = await crl_qs.anewest() # CRL was not found in the database either, so we try to regenerate it. if crl_obj is None: @@ -214,7 +214,7 @@ async def fetch_crl( return encoded_crl - async def get(self, request: HttpRequest, serial: str) -> HttpResponse: # pylint: disable=unused-argument + async def get(self, request: HttpRequest, serial: str) -> HttpResponse: # pylint: disable=missing-function-docstring; standard Django view function if get_encoding := request.GET.get("encoding"): if get_encoding not in CERTIFICATE_REVOCATION_LIST_ENCODING_TYPES: @@ -275,6 +275,8 @@ class OCSPView(View): ca_ocsp = False """If set to ``True``, validate child CAs instead.""" + loaded_ca: CertificateAuthority + async def get(self, request: HttpRequest, data: str) -> HttpResponse: # pylint: disable=missing-function-docstring; standard Django view function try: @@ -405,7 +407,7 @@ async def process_ocsp_request(self, data: bytes) -> HttpResponse: # Get CA and certificate try: - ca = await self.get_ca() + ca = self.loaded_ca = await self.get_ca() except CertificateAuthority.DoesNotExist: log.error("%s: Certificate Authority could not be found.", self.ca) return self.fail() @@ -474,34 +476,31 @@ class GenericOCSPView(OCSPView): argument must be the serial for this CA. """ - auto_ca: CertificateAuthority - # NOINSPECTION NOTE: It's okay to be more specific here # noinspection PyMethodOverriding - async def dispatch(self, request: HttpRequest, serial: str, **kwargs: Any) -> "HttpResponseBase": + def dispatch(self, request: HttpRequest, serial: str, **kwargs: Any) -> "HttpResponseBase": if request.method == "GET" and "data" not in kwargs: - return await self.http_method_not_allowed(request, serial, **kwargs) + return self.http_method_not_allowed(request, serial, **kwargs) if request.method == "POST" and "data" in kwargs: - return await self.http_method_not_allowed(request, serial, **kwargs) + return self.http_method_not_allowed(request, serial, **kwargs) # COVERAGE NOTE: Checking just for safety here. if not isinstance(serial, str): # pragma: no cover raise ImproperlyConfigured("View expects a str for a serial") - self.auto_ca = await CertificateAuthority.objects.aget(serial=serial) - return await super().dispatch(request, **kwargs) + return super().dispatch(request, **kwargs) async def get_ca(self) -> CertificateAuthority: - return self.auto_ca + return await CertificateAuthority.objects.aget(serial=self.kwargs["serial"]) def get_expires(self, now: datetime) -> datetime: - return now + timedelta(seconds=self.auto_ca.ocsp_response_validity) + return now + timedelta(seconds=self.loaded_ca.ocsp_response_validity) async def get_ocsp_response(self, builder: OCSPResponseBuilder) -> Union[HttpResponse, OCSPResponse]: """Sign the OCSP request using cryptography keys.""" # Load public key try: - responder_pem = self.auto_ca.ocsp_key_backend_options["certificate"]["pem"] + responder_pem = self.loaded_ca.ocsp_key_backend_options["certificate"]["pem"] except KeyError: # The OCSP responder certificate has never been created. `manage.py init_ca` usually creates them, # so this can only happen if the system is misconfigured (e.g. Celery task is never acted upon), @@ -525,7 +524,7 @@ async def get_ocsp_response(self, builder: OCSPResponseBuilder) -> Union[HttpRes # TYPEHINT NOTE: Certificates are always generated with a supported algorithm, so we do not check. algorithm = cast(Optional[AllowedHashTypes], responder_certificate.signature_hash_algorithm) - return self.auto_ca.ocsp_key_backend.sign_ocsp_response(self.auto_ca, builder, algorithm) + return self.loaded_ca.ocsp_key_backend.sign_ocsp_response(self.loaded_ca, builder, algorithm) class GenericCAIssuersView(View):