Skip to content

Commit

Permalink
Merge pull request #95 from maykinmedia/feature/94-support-dots-in-cl…
Browse files Browse the repository at this point in the history
…aim-names

Support dots in claim names
  • Loading branch information
sergei-maertens authored May 2, 2024
2 parents da943ce + d25ea89 commit 907ad31
Show file tree
Hide file tree
Showing 10 changed files with 623 additions and 136 deletions.
18 changes: 13 additions & 5 deletions docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ The name of the claim that is used for the ``User.username`` property
can be configured via the admin (**Username claim**). By default, the username is derived from the ``sub`` claim that
is returned by the OIDC provider.

If the desired claim is nested in one or more objects, its path can be specified with dots, e.g.:
If the desired claim is nested in one or more objects, you can specify the segments
of the path:

.. code-block:: json
Expand All @@ -175,17 +176,24 @@ If the desired claim is nested in one or more objects, its path can be specified
}
}
Can be retrieved by setting the username claim to ``some.nested.claim``
Can be retrieved by setting the username claim (array field) to:

.. note::
The username claim does not support claims that have dots in their name, it cannot be configured to retrieve the following claim for instance:
- some
- nested
- claim

If the claim has dots in it, you can specify those in a segment:

.. code-block:: json
{
"some.dotted.claim": "foo"
}
can be retrieved with:

- some.dotted.claim

User profile
------------

Expand Down Expand Up @@ -254,4 +262,4 @@ and ``OIDCAuthenticationBackend.config_class`` to be this new class.

.. _mozilla-django-oidc settings documentation: https://mozilla-django-oidc.readthedocs.io/en/stable/settings.html

.. _OIDC spec: https://openid.net/specs/openid-connect-discovery-1_0.html#WellKnownRegistry
.. _OIDC spec: https://openid.net/specs/openid-connect-discovery-1_0.html#WellKnownRegistry
147 changes: 79 additions & 68 deletions mozilla_django_oidc_db/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.core.exceptions import ObjectDoesNotExist

import requests
from glom import glom
from glom import Path, glom
from mozilla_django_oidc.auth import (
OIDCAuthenticationBackend as _OIDCAuthenticationBackend,
)
Expand All @@ -22,6 +22,12 @@
T = TypeVar("T", bound=OpenIDConnectConfig)


class MissingIdentifierClaim(Exception):
def __init__(self, claim_bits: list[str], *args, **kwargs):
self.claim_bits = claim_bits
super().__init__(*args, **kwargs)


class OIDCAuthenticationBackend(
GetAttributeMixin, SoloConfigMixin[T], _OIDCAuthenticationBackend
):
Expand All @@ -31,7 +37,7 @@ class OIDCAuthenticationBackend(
"""

config_identifier_field = "username_claim"
sensitive_claim_names = []
sensitive_claim_names: list[list[str]] = []

def __init__(self, *args, **kwargs):
# django-stubs returns AbstractBaseUser, but we depend on properties of
Expand All @@ -46,27 +52,26 @@ def __init__(self, *args, **kwargs):
# to avoid a large number of `OpenIDConnectConfig.get_solo` calls when
# `OIDCAuthenticationBackend.__init__` is called for permission checks

def retrieve_identifier_claim(self, claims: dict) -> str:
# NOTE: this does not support the extraction of claims that contain dots "." in
# their name (e.g. {"foo.bar": "baz"})
identifier_claim_name = getattr(self.config, self.config_identifier_field)
unique_id = glom(claims, identifier_claim_name, default="")
def retrieve_identifier_claim(
self, claims: dict, raise_on_empty: bool = False
) -> str:
claim_bits = getattr(self.config, self.config_identifier_field)
unique_id = glom(claims, Path(*claim_bits), default="")
if raise_on_empty and not unique_id:
raise MissingIdentifierClaim(claim_bits=claim_bits)
return unique_id

def get_sensitive_claims_names(self) -> list:
def get_sensitive_claims_names(self) -> list[list[str]]:
"""
Defines the claims that should be obfuscated before logging claims.
Nested claims can be specified by using a dotted path (e.g. "foo.bar.baz")
NOTE: this does not support claim names that have dots in them, so the following
claim cannot be marked as a sensitive claim
{
"foo.bar": "baz"
}
Nested claims are represented with a path of bits (e.g. ["foo", "bar", "baz"]).
Claims with dots in them are supported, e.g. ["foo.bar"].
"""
identifier_claim_name = getattr(self.config, self.config_identifier_field)
return [identifier_claim_name] + self.sensitive_claim_names
identifier_claim_bits: list[str] = getattr(
self.config, self.config_identifier_field
)
return [identifier_claim_bits] + self.sensitive_claim_names

def get_userinfo(self, access_token, id_token, payload):
"""
Expand Down Expand Up @@ -132,8 +137,8 @@ def get_user_instance_values(self, claims) -> dict[str, Any]:
Map the names and values of the claims to the fields of the User model
"""
return {
model_field: glom(claims, claims_field, default="")
for model_field, claims_field in self.config.claim_mapping.items()
model_field: glom(claims, Path(*claim_bits), default="")
for model_field, claim_bits in self.config.claim_mapping.items()
}

def create_user(self, claims):
Expand Down Expand Up @@ -169,11 +174,14 @@ def verify_claims(self, claims) -> bool:

logger.debug("OIDC claims received: %s", obfuscated_claims)

identifier_claim_name = getattr(self.config, self.config_identifier_field)
if not glom(claims, identifier_claim_name, default=""):
# check if we have an identifier
try:
self.retrieve_identifier_claim(claims, raise_on_empty=True)
except MissingIdentifierClaim as exc:
logger.error(
"%s not in OIDC claims, cannot proceed with authentication",
identifier_claim_name,
"'%s' not in OIDC claims, cannot proceed with authentication",
" > ".join(exc.claim_bits),
exc_info=exc,
)
return False
return True
Expand All @@ -199,76 +207,79 @@ def update_user(self, user, claims):

return user

def _retrieve_groups_claim(self, claims: dict[str, Any]) -> list[str]:
groups_claim_bits = self.config.groups_claim
return glom(claims, Path(*groups_claim_bits), default=[])

def update_user_superuser_status(self, user, claims) -> None:
"""
Assigns superuser status to the user if the user is a member of at least one
specific group. Superuser status is explicitly removed if the user is not or
no longer member of at least one of these groups.
"""
groups_claim = self.config.groups_claim
# can't do an isinstance check here
superuser_group_names = cast(list[str], self.config.superuser_group_names)

if not superuser_group_names:
return

claim_groups = glom(claims, groups_claim, default=[])
claim_groups = self._retrieve_groups_claim(claims)
if set(superuser_group_names) & set(claim_groups):
user.is_superuser = True
else:
user.is_superuser = False
user.save()

def update_user_groups(self, user, claims):
def update_user_groups(self, user, claims) -> None:
"""
Updates user group memberships based on the group_claim setting.
Copied and modified from: https://github.com/snok/django-auth-adfs/blob/master/django_auth_adfs/backend.py
"""
groups_claim = self.config.groups_claim

if groups_claim:
# Update the user's group memberships
django_groups = [group.name for group in user.groups.all()]
claim_groups = glom(claims, groups_claim, default=[])
if claim_groups:
if not isinstance(claim_groups, list):
claim_groups = [
group_claim_bits: list[str] = self.config.groups_claim
if not group_claim_bits:
return

claim_groups = self._retrieve_groups_claim(claims)

# Update the user's group memberships
django_groups = [group.name for group in user.groups.all()]
if claim_groups:
if not isinstance(claim_groups, list):
claim_groups = [
claim_groups,
]
else:
logger.debug(
"The configured groups claim '%s' was not found in the access token",
" > ".join(group_claim_bits),
)
claim_groups = []
if sorted(claim_groups) != sorted(django_groups):
existing_groups = list(
Group.objects.filter(name__in=claim_groups).iterator()
)
existing_group_names = frozenset(group.name for group in existing_groups)
new_groups = []
if self.config.sync_groups:
# Only sync groups that match the supplied glob pattern
new_groups = [
Group.objects.get_or_create(name=name)[0]
for name in fnmatch.filter(
claim_groups,
]
self.config.sync_groups_glob_pattern,
)
if name not in existing_group_names
]
else:
logger.debug(
"The configured groups claim '%s' was not found in the access token",
groups_claim,
)
claim_groups = []
if sorted(claim_groups) != sorted(django_groups):
existing_groups = list(
Group.objects.filter(name__in=claim_groups).iterator()
)
existing_group_names = frozenset(
group.name for group in existing_groups
)
new_groups = []
if self.config.sync_groups:
# Only sync groups that match the supplied glob pattern
new_groups = [
Group.objects.get_or_create(name=name)[0]
for name in fnmatch.filter(
claim_groups,
self.config.sync_groups_glob_pattern,
)
if name not in existing_group_names
]
else:
for name in claim_groups:
if name not in existing_group_names:
try:
group = Group.objects.get(name=name)
new_groups.append(group)
except ObjectDoesNotExist:
pass
user.groups.set(existing_groups + new_groups)
for name in claim_groups:
if name not in existing_group_names:
try:
group = Group.objects.get(name=name)
new_groups.append(group)
except ObjectDoesNotExist:
pass
user.groups.set(existing_groups + new_groups)

def update_user_default_groups(self, user):
"""
Expand Down
16 changes: 16 additions & 0 deletions mozilla_django_oidc_db/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.db import models
from django.utils.translation import gettext_lazy as _

from django_jsonform.models.fields import ArrayField


class ClaimField(ArrayField):
"""
A field to store a path to claims holding the desired value(s).
Each item is a segment in the path from the root to leaf for nested claims.
"""

def __init__(self, *args, **kwargs):
kwargs["base_field"] = models.CharField(_("claim path segment"), max_length=50)
super().__init__(*args, **kwargs)
Loading

0 comments on commit 907ad31

Please sign in to comment.