diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 864a50b..4e93404 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,14 @@ Changelog ========= +0.18.0 (2024-06-1?) +=================== + +Small feature release + +* Added ``mozilla_django_oidc_db.fields.ClaimFieldDefault`` to specify default values + for ``ClaimField`` in a less verbose way. + 0.17.0 (2024-05-28) =================== diff --git a/docs/reference.rst b/docs/reference.rst index 695f1b5..83bfca8 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -23,6 +23,9 @@ Models .. automodule:: mozilla_django_oidc_db.models :members: +.. automodule:: mozilla_django_oidc_db.fields + :members: + Utils ===== diff --git a/mozilla_django_oidc_db/fields.py b/mozilla_django_oidc_db/fields.py index d4e0e86..b3c5ddb 100644 --- a/mozilla_django_oidc_db/fields.py +++ b/mozilla_django_oidc_db/fields.py @@ -1,9 +1,37 @@ from django.db import models +from django.utils.deconstruct import deconstructible from django.utils.translation import gettext_lazy as _ from django_jsonform.models.fields import ArrayField +@deconstructible +class ClaimFieldDefault: + """ + Callable default for ClaimField. + + Django's ArrayField requires a callable to be passed for the ``default`` kwarg, to + avoid sharing a mutable value shared by all instances. This custom class provides + a straight-forward interface so that defaults can be provided inline rather than + requiring a function to be defined at the module level, since lambda's cannot be + serialized for migrations. + + Usage: + + >>> field = ClaimField(default=ClaimFieldDefault("foo", "bar")) + >>> field.get_default() # ["foo", "bar"] + """ + + def __init__(self, *bits: str): + self.bits = list(bits) + + def __eq__(self, other) -> bool: + return self.bits == other.bits + + def __call__(self) -> list[str]: + return self.bits + + class ClaimField(ArrayField): """ A field to store a path to claims holding the desired value(s). diff --git a/mozilla_django_oidc_db/migrations/0002_migrate_to_claim_field.py b/mozilla_django_oidc_db/migrations/0002_migrate_to_claim_field.py index d1e5101..ebabf03 100644 --- a/mozilla_django_oidc_db/migrations/0002_migrate_to_claim_field.py +++ b/mozilla_django_oidc_db/migrations/0002_migrate_to_claim_field.py @@ -69,7 +69,7 @@ class Migration(migrations.Migration): max_length=50, verbose_name="claim path segment" ), blank=True, - default=mozilla_django_oidc_db.models.get_default_groups_claim, + default=mozilla_django_oidc_db.fields.ClaimFieldDefault("roles"), help_text="The name of the OIDC claim that holds the values to map to local user groups.", size=None, verbose_name="groups claim", @@ -82,7 +82,7 @@ class Migration(migrations.Migration): base_field=models.CharField( max_length=50, verbose_name="claim path segment" ), - default=mozilla_django_oidc_db.models.get_default_username_claim, + default=mozilla_django_oidc_db.fields.ClaimFieldDefault("sub"), help_text="The name of the OIDC claim that is used as the username", size=None, verbose_name="username claim", diff --git a/mozilla_django_oidc_db/models.py b/mozilla_django_oidc_db/models.py index d43beb0..12cb334 100644 --- a/mozilla_django_oidc_db/models.py +++ b/mozilla_django_oidc_db/models.py @@ -15,7 +15,7 @@ from solo import settings as solo_settings from solo.models import SingletonModel -from .fields import ClaimField +from .fields import ClaimField, ClaimFieldDefault from .typing import ClaimPath, DjangoView @@ -41,14 +41,6 @@ def get_claim_mapping() -> dict[str, list[str]]: } -def get_default_username_claim() -> list[str]: - return ["sub"] - - -def get_default_groups_claim() -> list[str]: - return ["roles"] - - class OpenIDConnectConfigBase(SingletonModel): """ Defines the required fields for a config to establish an OIDC connection @@ -240,7 +232,7 @@ def oidcdb_username_claim(self) -> ClaimPath: """ The claim to read to extract the value for the username field. """ - return get_default_username_claim() + return ["sub"] @property def oidcdb_userinfo_claims_source(self) -> UserInformationClaimsSources: @@ -264,7 +256,7 @@ class OpenIDConnectConfig(OpenIDConnectConfigBase): username_claim = ClaimField( verbose_name=_("username claim"), - default=get_default_username_claim, + default=ClaimFieldDefault("sub"), help_text=_("The name of the OIDC claim that is used as the username"), ) @@ -275,7 +267,7 @@ class OpenIDConnectConfig(OpenIDConnectConfigBase): ) groups_claim = ClaimField( verbose_name=_("groups claim"), - default=get_default_groups_claim, + default=ClaimFieldDefault("roles"), help_text=_( "The name of the OIDC claim that holds the values to map to local user groups." ), diff --git a/tests/test_models.py b/tests/test_models.py index 30803a2..e64692f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3,6 +3,7 @@ import pytest +from mozilla_django_oidc_db.fields import ClaimFieldDefault from mozilla_django_oidc_db.models import OpenIDConnectConfig @@ -58,3 +59,8 @@ def test_validate_username_field_not_in_claim_mapping(): assert "claim_mapping" in err_dict error = _("The username field may not be in the claim mapping") assert error in err_dict["claim_mapping"] + + +def test_claim_field_default_equality(): + assert ClaimFieldDefault("foo", "bar") == ClaimFieldDefault("foo", "bar") + assert ClaimFieldDefault("foo", "bar") != ClaimFieldDefault("bar", "foo")