From e34fa0a72e313c12c3fe0c15cffb6a1ad9d18cda Mon Sep 17 00:00:00 2001 From: Mathias Ertl Date: Wed, 1 Jan 2025 14:07:25 +0100 Subject: [PATCH] optimize CertificateRevocationListView for database queries --- ca/django_ca/acme/views.py | 5 +- ca/django_ca/managers.py | 18 ++++--- ca/django_ca/models.py | 29 ++++++++---- ca/django_ca/querysets.py | 47 ++++--------------- .../test_certificate_revocation_list_view.py | 42 +++++++++++++---- ca/django_ca/views.py | 22 ++++----- 6 files changed, 86 insertions(+), 77 deletions(-) diff --git a/ca/django_ca/acme/views.py b/ca/django_ca/acme/views.py index dbe6fa275..b51ac9092 100644 --- a/ca/django_ca/acme/views.py +++ b/ca/django_ca/acme/views.py @@ -302,7 +302,7 @@ def set_link_relations(self, response: "HttpResponseBase", **kwargs: str) -> Non # TYPEHINT NOTE: Django does not officially support dispatch() being an async method, but in async views, # it returns a coroutine, just like an async view would. Thus declaring it async should not make much # difference. - async def dispatch( # type: ignore[override] + async def dispatch( # type: ignore[override] # pylint: disable=invalid-overridden-method self, request: HttpRequest, serial: str, slug: Optional[str] = None ) -> "HttpResponseBase": if not model_settings.CA_ENABLE_ACME: @@ -1091,6 +1091,7 @@ def set_link_relations(self, response: "HttpResponseBase", **kwargs: str) -> Non @transaction.atomic def save_challenge(self, challenge: AcmeChallenge) -> None: + """Save challenge and launch Celery task.""" challenge.save() # Actually perform challenge validation asynchronously @@ -1189,7 +1190,7 @@ async def get_certificate(self, serial: str) -> Certificate: # If the account holds authorizations for all the identifiers in the certificate, it can also # be revoked, so get a list of all currently valid authorizations that the account holds authz_qs = AcmeAuthorization.objects.dns().valid().account(account=self.account).names() - authz = set([auth async for auth in authz_qs]) + authz = set(auth async for auth in authz_qs) # Get names from the certificate, first from the CommonName... # NOTE: returns empty list if subject does not have a CommonName. diff --git a/ca/django_ca/managers.py b/ca/django_ca/managers.py index 9849752cd..5a394a707 100644 --- a/ca/django_ca/managers.py +++ b/ca/django_ca/managers.py @@ -142,8 +142,6 @@ class CertificateAuthorityManager( def acme(self) -> "CertificateAuthorityQuerySet": ... - def default(self) -> "CertificateAuthority": ... - def disabled(self) -> "CertificateAuthorityQuerySet": ... def enabled(self) -> "CertificateAuthorityQuerySet": ... @@ -228,6 +226,11 @@ def _handle_crl_distribution_point( ), ) + # PYLINT NOTE: documented in queryset + def default(self) -> "CertificateAuthority": # pylint: disable=missing-function-docstring + # Needs to be here because the async_to_sync version in the queryset does not get mirrored here. + return self.all().default() + @deprecate_argument("expires", RemovedInDjangoCA230Warning, replacement="not_after") def init( # noqa: PLR0912,PLR0913,PLR0915 self, @@ -729,7 +732,7 @@ def reasons( def scope( self, - ca: "CertificateAuthority", + serial: str, only_contains_ca_certs: bool = False, only_contains_user_certs: bool = False, only_contains_attribute_certs: bool = False, @@ -876,7 +879,7 @@ def create_certificate_revocation_list( # Create subquery for the current CRL number with the given scope. number_subquery = ( self.scope( - ca=ca, + serial=ca.serial, only_contains_ca_certs=only_contains_ca_certs, only_contains_user_certs=only_contains_user_certs, only_contains_attribute_certs=only_contains_attribute_certs, @@ -902,7 +905,8 @@ def create_certificate_revocation_list( # https://docs.djangoproject.com/en/5.1/ref/models/expressions/#f-assignments-persist-after-model-save if django.VERSION >= (5, 0): # pragma: django>=5.1 branch # Assure that ``ca`` is loaded already - obj.refresh_from_db(from_queryset=self.model.objects.select_related("ca")) + fields = ("ca", "number") # only fetch required fields to optimize query + obj.refresh_from_db(from_queryset=self.model.objects.select_related("ca"), fields=fields) else: # pragma: django<5.1 branch # The `from_queryset` argument was added in Django 5.0. obj = self.model.objects.select_related("ca").get(pk=obj.pk) @@ -917,7 +921,7 @@ def create_certificate_revocation_list( # Store CRL in the database obj.data = crl.public_bytes(Encoding.DER) - obj.save() + obj.save(update_fields=("data",)) # only update single field to optimize query return obj @@ -981,7 +985,7 @@ class AcmeCertificateManager(AcmeCertificateManagerBase): if typing.TYPE_CHECKING: # See CertificateManagerMixin for description on this branch # - # pylint: disable=missing-function-docstring,unused-argument; just defining stubs here + # pylint: disable=missing-function-docstring; just defining stubs here def account(self) -> "AcmeCertificateQuerySet": ... def url(self) -> "AcmeCertificateQuerySet": ... diff --git a/ca/django_ca/models.py b/ca/django_ca/models.py index d6de8b335..f84eb5ac0 100644 --- a/ca/django_ca/models.py +++ b/ca/django_ca/models.py @@ -1252,7 +1252,7 @@ def pem(self) -> bytes: """The CRL encoded in PEM format.""" return self.loaded.public_bytes(Encoding.PEM) - def _cache_data(self) -> Iterator[tuple[str, bytes, int]]: + def _cache_data(self, serial: Optional[str] = None) -> Iterator[tuple[str, bytes, int]]: if self.data is None: raise ValueError("CRL is not yet generated for this object.") @@ -1262,10 +1262,13 @@ def _cache_data(self) -> Iterator[tuple[str, bytes, int]]: else: # pragma: no cover # we never generate CRLs without a next_update flag. expires_seconds = 86400 + if serial is None: + serial = self.ca.serial + for encoding in [Encoding.PEM, Encoding.DER]: cache_key = get_crl_cache_key( - self.ca.serial, - encoding, + serial=serial, + encoding=encoding, only_contains_ca_certs=self.only_contains_ca_certs, only_contains_user_certs=self.only_contains_user_certs, only_contains_attribute_certs=self.only_contains_attribute_certs, @@ -1284,14 +1287,22 @@ def _cache_data(self) -> Iterator[tuple[str, bytes, int]]: yield cache_key, encoded_crl, expires_seconds - def cache(self) -> None: - """Cache this instance.""" - for cache_key, encoded_crl, expires_seconds in self._cache_data(): + def cache(self, serial: Optional[str] = None) -> None: + """Cache this instance. + + If `serial` is not given, `self.ca` will be accessed (possibly triggering a database query) to + generate the cache keys. + """ + for cache_key, encoded_crl, expires_seconds in self._cache_data(serial): cache.set(cache_key, encoded_crl, expires_seconds) - async def acache(self) -> None: - """Cache this instance.""" - for cache_key, encoded_crl, expires_seconds in self._cache_data(): + async def acache(self, serial: Optional[str] = None) -> None: + """Cache this instance (async version). + + If `serial` is not given, `self.ca` will be accessed (possibly triggering a database query) to + generate the cache keys. + """ + for cache_key, encoded_crl, expires_seconds in self._cache_data(serial): await cache.aset(cache_key, encoded_crl, expires_seconds) diff --git a/ca/django_ca/querysets.py b/ca/django_ca/querysets.py index dfca1483b..0f6282d0e 100644 --- a/ca/django_ca/querysets.py +++ b/ca/django_ca/querysets.py @@ -19,6 +19,8 @@ from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from asgiref.sync import async_to_sync + from cryptography import x509 from django.core.exceptions import ImproperlyConfigured @@ -177,42 +179,6 @@ def acme(self) -> "CertificateAuthorityQuerySet": """Return usable CAs that have support for the ACME protocol enabled.""" return self.filter(acme_enabled=True) - def default(self) -> "CertificateAuthority": - """Return the default CA to use when no CA is selected. - - This function honors the :ref:`CA_DEFAULT_CA `. If no usable CA can be - returned, raises :py:exc:`~django:django.core.exceptions.ImproperlyConfigured`. - - Raises - ------ - :py:exc:`~django:django.core.exceptions.ImproperlyConfigured` - When the CA named by :ref:`CA_DEFAULT_CA ` is either not found, disabled - or not currently valid. Or, if the setting is not set, no CA is currently usable. - """ - if (serial := model_settings.CA_DEFAULT_CA) is not None: - try: - # NOTE: Don't prefilter queryset so that we can provide more specialized error messages below. - ca = self.get(serial=serial) - except self.model.DoesNotExist as ex: - raise ImproperlyConfigured(f"CA_DEFAULT_CA: {serial}: CA not found.") from ex - - if ca.enabled is False: - raise ImproperlyConfigured(f"CA_DEFAULT_CA: {serial} is disabled.") - - now = timezone.now() - if ca.not_after < now: - raise ImproperlyConfigured(f"CA_DEFAULT_CA: {serial} is expired.") - if ca.not_before > now: # OK, how could this ever happen? ;-) - raise ImproperlyConfigured(f"CA_DEFAULT_CA: {serial} is not yet valid.") - return ca - - # NOTE: We add the serial to sorting make *sure* we have deterministic behavior. In many cases, users - # will just create several CAs that all actually expire on the same day. - first_ca = self.usable().order_by("-not_after", "serial").first() # usable == enabled and valid - if first_ca is None: - raise ImproperlyConfigured("No CA is currently usable.") - return first_ca - async def adefault(self) -> "CertificateAuthority": """Return the default CA to use when no CA is selected. @@ -250,6 +216,11 @@ async def adefault(self) -> "CertificateAuthority": raise ImproperlyConfigured("No CA is currently usable.") return first_ca + @async_to_sync + async def default(self) -> "CertificateAuthority": + """Return the default CA to use when no CA is selected (synchronous version).""" + return await self.adefault() + def disabled(self) -> "CertificateAuthorityQuerySet": """Return CAs that are disabled.""" return self.filter(enabled=False) @@ -330,7 +301,7 @@ def reasons( def scope( self, - ca: "CertificateAuthority", + serial: str, only_contains_ca_certs: bool = False, only_contains_user_certs: bool = False, only_contains_attribute_certs: bool = False, @@ -338,7 +309,7 @@ def scope( ) -> "CertificateRevocationListQuerySet": """Return CRLs with the given scope.""" return self.filter( - ca=ca, + ca__serial=serial, only_contains_ca_certs=only_contains_ca_certs, only_contains_user_certs=only_contains_user_certs, only_contains_attribute_certs=only_contains_attribute_certs, diff --git a/ca/django_ca/tests/views/test_certificate_revocation_list_view.py b/ca/django_ca/tests/views/test_certificate_revocation_list_view.py index 497e3dfd7..82c1d4dbc 100644 --- a/ca/django_ca/tests/views/test_certificate_revocation_list_view.py +++ b/ca/django_ca/tests/views/test_certificate_revocation_list_view.py @@ -27,6 +27,7 @@ from django.urls import include, path, re_path, reverse import pytest +from pytest_django import DjangoAssertNumQueries from pytest_django.fixtures import SettingsWrapper from django_ca import constants @@ -108,17 +109,26 @@ def deprecated_scope() -> Iterator[None]: yield -def test_full_crl(client: Client, default_url: str, root_crl: CertificateRevocationList) -> None: +def test_full_crl( + django_assert_num_queries: DjangoAssertNumQueries, + client: Client, + default_url: str, + root_crl: CertificateRevocationList, +) -> None: """Fetch a full CRL (= CA and user certs, all reasons).""" - response = client.get(default_url) + with django_assert_num_queries(0): + response = client.get(default_url) assert response.status_code == HTTPStatus.OK assert response["Content-Type"] == "application/pkix-crl" assert response.content == root_crl.data -def test_ca_crl(client: Client, root_ca_crl: CertificateRevocationList) -> None: +def test_ca_crl( + django_assert_num_queries: DjangoAssertNumQueries, client: Client, root_ca_crl: CertificateRevocationList +) -> None: """Fetch a CA CRL.""" - response = client.get(reverse("ca", kwargs={"serial": root_ca_crl.ca.serial})) + with django_assert_num_queries(0): + response = client.get(reverse("ca", kwargs={"serial": root_ca_crl.ca.serial})) assert response.status_code == HTTPStatus.OK assert response["Content-Type"] == "application/pkix-crl" assert response.content == root_ca_crl.data @@ -132,18 +142,34 @@ def test_user_crl(client: Client, root_user_crl: CertificateRevocationList) -> N assert response.content == root_user_crl.data -def test_with_cache_miss(client: Client, default_url: str, root_crl: CertificateRevocationList) -> None: +def test_with_cache_miss( + django_assert_num_queries: DjangoAssertNumQueries, + client: Client, + default_url: str, + root_crl: CertificateRevocationList, +) -> None: """Fetch a full CRL with a cache miss.""" cache.clear() # clear the cache to generate a cache miss - response = client.get(default_url) + + with django_assert_num_queries(1) as captured: # Only one query for fetching the CRL required + response = client.get(default_url) + assert 'FROM "django_ca_certificaterevocationlist" INNER JOIN' in captured.captured_queries[0]["sql"] + assert response.status_code == HTTPStatus.OK assert response["Content-Type"] == "application/pkix-crl" assert response.content == root_crl.data -def test_regenerate_full_crl(client: Client, usable_root: CertificateAuthority, default_url: str) -> None: +def test_regenerate_full_crl( + django_assert_num_queries: DjangoAssertNumQueries, + client: Client, + usable_root: CertificateAuthority, + default_url: str, +) -> None: """Fetch a full CRL where the CRL has to be regenerated.""" - response = client.get(default_url) + with django_assert_num_queries(9): # loads of queries required to regenerate a CRL + response = client.get(default_url) + assert response.status_code == HTTPStatus.OK assert response["Content-Type"] == "application/pkix-crl" assert_crl(response.content, expected=[], encoding=Encoding.DER, signer=usable_root) diff --git a/ca/django_ca/views.py b/ca/django_ca/views.py index c2602fa39..03c2ee3fc 100644 --- a/ca/django_ca/views.py +++ b/ca/django_ca/views.py @@ -127,9 +127,7 @@ def get_key_backend_options(self, ca: CertificateAuthority) -> BaseModel: """ return ca.key_backend.get_use_private_key_options(ca, {}) - async def fetch_crl( - self, ca: CertificateAuthority, encoding: CertificateRevocationListEncodings - ) -> bytes: + async def fetch_crl(self, serial: str, encoding: CertificateRevocationListEncodings) -> bytes: """Actually fetch the CRL (nested function so that we can easily catch any exception).""" if self.scope is not _NOT_SET: warnings.warn( @@ -164,7 +162,7 @@ async def fetch_crl( ) cache_key = get_crl_cache_key( - ca.serial, + serial, encoding=encoding, only_contains_ca_certs=only_contains_ca_certs, only_contains_user_certs=only_contains_user_certs, @@ -172,25 +170,25 @@ async def fetch_crl( only_some_reasons=self.only_some_reasons, ) - encoded_crl: Optional[bytes] = cache.get(cache_key) + encoded_crl: Optional[bytes] = await cache.aget(cache_key) # CRL is not cached, try to retrieve it from the database. if encoded_crl is None: crl_qs = ( CertificateRevocationList.objects.scope( - ca=ca, + serial=serial, only_contains_ca_certs=only_contains_ca_certs, only_contains_user_certs=only_contains_user_certs, only_contains_attribute_certs=only_contains_attribute_certs, only_some_reasons=self.only_some_reasons, - ) - .filter(data__isnull=False) # only objects that have CRL data associated with it - .select_related("ca") + ).filter(data__isnull=False) # Only objects that have CRL data associated with it ) crl_obj: Optional[CertificateRevocationList] = await crl_qs.anewest() # CRL was not found in the database either, so we try to regenerate it. if crl_obj is None: + ca = await CertificateAuthority.objects.aget(serial=serial) + key_backend_options = self.get_key_backend_options(ca) expires = datetime.now(tz=tz.utc) + timedelta(seconds=self.expires) crl_obj = await CertificateRevocationList.objects.acreate_certificate_revocation_list( @@ -204,7 +202,7 @@ async def fetch_crl( ) # Cache the CRL. - await crl_obj.acache() + await crl_obj.acache(serial) # Get object in the right encoding. if encoding == Encoding.PEM: @@ -226,10 +224,8 @@ async def get(self, request: HttpRequest, serial: str) -> HttpResponse: else: encoding = self.type - ca = await CertificateAuthority.objects.aget(serial=serial) - try: - crl = await self.fetch_crl(ca, encoding) + crl = await self.fetch_crl(serial, encoding) except Exception: # pylint: disable=broad-exception-caught log.exception("Error generating a CRL") return HttpResponseServerError("Error while retrieving the CRL.", content_type="text/plain")