Skip to content

Commit

Permalink
Merge pull request #110 from maykinmedia/feature/claim-field-default
Browse files Browse the repository at this point in the history
Add and use convenience ClaimFieldDefault
  • Loading branch information
sergei-maertens authored Jun 12, 2024
2 parents a568ee1 + 6f82b3b commit 982ee98
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 14 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
Changelog
=========

0.18.0 (2024-06-1?)
===================

Small feature release

* Added ``mozilla_django_oidc_db.fields.ClaimFieldDefault`` to specify default values
for ``ClaimField`` in a less verbose way.

0.17.0 (2024-05-28)
===================

Expand Down
3 changes: 3 additions & 0 deletions docs/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ Models
.. automodule:: mozilla_django_oidc_db.models
:members:

.. automodule:: mozilla_django_oidc_db.fields
:members:

Utils
=====

Expand Down
28 changes: 28 additions & 0 deletions mozilla_django_oidc_db/fields.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,37 @@
from django.db import models
from django.utils.deconstruct import deconstructible
from django.utils.translation import gettext_lazy as _

from django_jsonform.models.fields import ArrayField


@deconstructible
class ClaimFieldDefault:
"""
Callable default for ClaimField.
Django's ArrayField requires a callable to be passed for the ``default`` kwarg, to
avoid sharing a mutable value shared by all instances. This custom class provides
a straight-forward interface so that defaults can be provided inline rather than
requiring a function to be defined at the module level, since lambda's cannot be
serialized for migrations.
Usage:
>>> field = ClaimField(default=ClaimFieldDefault("foo", "bar"))
>>> field.get_default() # ["foo", "bar"]
"""

def __init__(self, *bits: str):
self.bits = list(bits)

def __eq__(self, other) -> bool:
return self.bits == other.bits

def __call__(self) -> list[str]:
return self.bits


class ClaimField(ArrayField):
"""
A field to store a path to claims holding the desired value(s).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class Migration(migrations.Migration):
max_length=50, verbose_name="claim path segment"
),
blank=True,
default=mozilla_django_oidc_db.models.get_default_groups_claim,
default=mozilla_django_oidc_db.fields.ClaimFieldDefault("roles"),
help_text="The name of the OIDC claim that holds the values to map to local user groups.",
size=None,
verbose_name="groups claim",
Expand All @@ -82,7 +82,7 @@ class Migration(migrations.Migration):
base_field=models.CharField(
max_length=50, verbose_name="claim path segment"
),
default=mozilla_django_oidc_db.models.get_default_username_claim,
default=mozilla_django_oidc_db.fields.ClaimFieldDefault("sub"),
help_text="The name of the OIDC claim that is used as the username",
size=None,
verbose_name="username claim",
Expand Down
16 changes: 4 additions & 12 deletions mozilla_django_oidc_db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from solo import settings as solo_settings
from solo.models import SingletonModel

from .fields import ClaimField
from .fields import ClaimField, ClaimFieldDefault
from .typing import ClaimPath, DjangoView


Expand All @@ -41,14 +41,6 @@ def get_claim_mapping() -> dict[str, list[str]]:
}


def get_default_username_claim() -> list[str]:
return ["sub"]


def get_default_groups_claim() -> list[str]:
return ["roles"]


class OpenIDConnectConfigBase(SingletonModel):
"""
Defines the required fields for a config to establish an OIDC connection
Expand Down Expand Up @@ -240,7 +232,7 @@ def oidcdb_username_claim(self) -> ClaimPath:
"""
The claim to read to extract the value for the username field.
"""
return get_default_username_claim()
return ["sub"]

@property
def oidcdb_userinfo_claims_source(self) -> UserInformationClaimsSources:
Expand All @@ -264,7 +256,7 @@ class OpenIDConnectConfig(OpenIDConnectConfigBase):

username_claim = ClaimField(
verbose_name=_("username claim"),
default=get_default_username_claim,
default=ClaimFieldDefault("sub"),
help_text=_("The name of the OIDC claim that is used as the username"),
)

Expand All @@ -275,7 +267,7 @@ class OpenIDConnectConfig(OpenIDConnectConfigBase):
)
groups_claim = ClaimField(
verbose_name=_("groups claim"),
default=get_default_groups_claim,
default=ClaimFieldDefault("roles"),
help_text=_(
"The name of the OIDC claim that holds the values to map to local user groups."
),
Expand Down
6 changes: 6 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from mozilla_django_oidc_db.fields import ClaimFieldDefault
from mozilla_django_oidc_db.models import OpenIDConnectConfig


Expand Down Expand Up @@ -58,3 +59,8 @@ def test_validate_username_field_not_in_claim_mapping():
assert "claim_mapping" in err_dict
error = _("The username field may not be in the claim mapping")
assert error in err_dict["claim_mapping"]


def test_claim_field_default_equality():
assert ClaimFieldDefault("foo", "bar") == ClaimFieldDefault("foo", "bar")
assert ClaimFieldDefault("foo", "bar") != ClaimFieldDefault("bar", "foo")

0 comments on commit 982ee98

Please sign in to comment.