Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support dots in claim names #95

Merged
merged 7 commits into from
May 2, 2024
18 changes: 13 additions & 5 deletions docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ The name of the claim that is used for the ``User.username`` property
can be configured via the admin (**Username claim**). By default, the username is derived from the ``sub`` claim that
is returned by the OIDC provider.

If the desired claim is nested in one or more objects, its path can be specified with dots, e.g.:
If the desired claim is nested in one or more objects, you can specify the segments
of the path:

.. code-block:: json

Expand All @@ -175,17 +176,24 @@ If the desired claim is nested in one or more objects, its path can be specified
}
}

Can be retrieved by setting the username claim to ``some.nested.claim``
Can be retrieved by setting the username claim (array field) to:

.. note::
The username claim does not support claims that have dots in their name, it cannot be configured to retrieve the following claim for instance:
- some
- nested
- claim

If the claim has dots in it, you can specify those in a segment:

.. code-block:: json

{
"some.dotted.claim": "foo"
}

can be retrieved with:

- some.dotted.claim

User profile
------------

Expand Down Expand Up @@ -254,4 +262,4 @@ and ``OIDCAuthenticationBackend.config_class`` to be this new class.

.. _mozilla-django-oidc settings documentation: https://mozilla-django-oidc.readthedocs.io/en/stable/settings.html

.. _OIDC spec: https://openid.net/specs/openid-connect-discovery-1_0.html#WellKnownRegistry
.. _OIDC spec: https://openid.net/specs/openid-connect-discovery-1_0.html#WellKnownRegistry
147 changes: 79 additions & 68 deletions mozilla_django_oidc_db/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.core.exceptions import ObjectDoesNotExist

import requests
from glom import glom
from glom import Path, glom
from mozilla_django_oidc.auth import (
OIDCAuthenticationBackend as _OIDCAuthenticationBackend,
)
Expand All @@ -22,6 +22,12 @@
T = TypeVar("T", bound=OpenIDConnectConfig)


class MissingIdentifierClaim(Exception):
def __init__(self, claim_bits: list[str], *args, **kwargs):
self.claim_bits = claim_bits
super().__init__(*args, **kwargs)


class OIDCAuthenticationBackend(
GetAttributeMixin, SoloConfigMixin[T], _OIDCAuthenticationBackend
):
Expand All @@ -31,7 +37,7 @@ class OIDCAuthenticationBackend(
"""

config_identifier_field = "username_claim"
sensitive_claim_names = []
sensitive_claim_names: list[list[str]] = []

def __init__(self, *args, **kwargs):
# django-stubs returns AbstractBaseUser, but we depend on properties of
Expand All @@ -46,27 +52,26 @@ def __init__(self, *args, **kwargs):
# to avoid a large number of `OpenIDConnectConfig.get_solo` calls when
# `OIDCAuthenticationBackend.__init__` is called for permission checks

def retrieve_identifier_claim(self, claims: dict) -> str:
# NOTE: this does not support the extraction of claims that contain dots "." in
# their name (e.g. {"foo.bar": "baz"})
identifier_claim_name = getattr(self.config, self.config_identifier_field)
unique_id = glom(claims, identifier_claim_name, default="")
def retrieve_identifier_claim(
self, claims: dict, raise_on_empty: bool = False
) -> str:
claim_bits = getattr(self.config, self.config_identifier_field)
unique_id = glom(claims, Path(*claim_bits), default="")
if raise_on_empty and not unique_id:
raise MissingIdentifierClaim(claim_bits=claim_bits)
return unique_id

def get_sensitive_claims_names(self) -> list:
def get_sensitive_claims_names(self) -> list[list[str]]:
"""
Defines the claims that should be obfuscated before logging claims.
Nested claims can be specified by using a dotted path (e.g. "foo.bar.baz")

NOTE: this does not support claim names that have dots in them, so the following
claim cannot be marked as a sensitive claim

{
"foo.bar": "baz"
}
Nested claims are represented with a path of bits (e.g. ["foo", "bar", "baz"]).
Claims with dots in them are supported, e.g. ["foo.bar"].
"""
identifier_claim_name = getattr(self.config, self.config_identifier_field)
return [identifier_claim_name] + self.sensitive_claim_names
identifier_claim_bits: list[str] = getattr(
self.config, self.config_identifier_field
)
return [identifier_claim_bits] + self.sensitive_claim_names

def get_userinfo(self, access_token, id_token, payload):
"""
Expand Down Expand Up @@ -132,8 +137,8 @@ def get_user_instance_values(self, claims) -> dict[str, Any]:
Map the names and values of the claims to the fields of the User model
"""
return {
model_field: glom(claims, claims_field, default="")
for model_field, claims_field in self.config.claim_mapping.items()
model_field: glom(claims, Path(*claim_bits), default="")
for model_field, claim_bits in self.config.claim_mapping.items()
}

def create_user(self, claims):
Expand Down Expand Up @@ -169,11 +174,14 @@ def verify_claims(self, claims) -> bool:

logger.debug("OIDC claims received: %s", obfuscated_claims)

identifier_claim_name = getattr(self.config, self.config_identifier_field)
if not glom(claims, identifier_claim_name, default=""):
# check if we have an identifier
try:
self.retrieve_identifier_claim(claims, raise_on_empty=True)
except MissingIdentifierClaim as exc:
logger.error(
"%s not in OIDC claims, cannot proceed with authentication",
identifier_claim_name,
"'%s' not in OIDC claims, cannot proceed with authentication",
" > ".join(exc.claim_bits),
exc_info=exc,
)
return False
return True
Expand All @@ -199,76 +207,79 @@ def update_user(self, user, claims):

return user

def _retrieve_groups_claim(self, claims: dict[str, Any]) -> list[str]:
groups_claim_bits = self.config.groups_claim
return glom(claims, Path(*groups_claim_bits), default=[])

def update_user_superuser_status(self, user, claims) -> None:
"""
Assigns superuser status to the user if the user is a member of at least one
specific group. Superuser status is explicitly removed if the user is not or
no longer member of at least one of these groups.
"""
groups_claim = self.config.groups_claim
# can't do an isinstance check here
superuser_group_names = cast(list[str], self.config.superuser_group_names)

if not superuser_group_names:
return

claim_groups = glom(claims, groups_claim, default=[])
claim_groups = self._retrieve_groups_claim(claims)
if set(superuser_group_names) & set(claim_groups):
user.is_superuser = True
else:
user.is_superuser = False
user.save()

def update_user_groups(self, user, claims):
def update_user_groups(self, user, claims) -> None:
"""
Updates user group memberships based on the group_claim setting.

Copied and modified from: https://github.com/snok/django-auth-adfs/blob/master/django_auth_adfs/backend.py
"""
groups_claim = self.config.groups_claim

if groups_claim:
# Update the user's group memberships
django_groups = [group.name for group in user.groups.all()]
claim_groups = glom(claims, groups_claim, default=[])
if claim_groups:
if not isinstance(claim_groups, list):
claim_groups = [
group_claim_bits: list[str] = self.config.groups_claim
if not group_claim_bits:
return

claim_groups = self._retrieve_groups_claim(claims)

# Update the user's group memberships
django_groups = [group.name for group in user.groups.all()]
if claim_groups:
if not isinstance(claim_groups, list):
claim_groups = [
claim_groups,
]
else:
logger.debug(
"The configured groups claim '%s' was not found in the access token",
" > ".join(group_claim_bits),
)
claim_groups = []
if sorted(claim_groups) != sorted(django_groups):
existing_groups = list(
Group.objects.filter(name__in=claim_groups).iterator()
)
existing_group_names = frozenset(group.name for group in existing_groups)
new_groups = []
if self.config.sync_groups:
# Only sync groups that match the supplied glob pattern
new_groups = [
Group.objects.get_or_create(name=name)[0]
for name in fnmatch.filter(
claim_groups,
]
self.config.sync_groups_glob_pattern,
)
if name not in existing_group_names
]
else:
logger.debug(
"The configured groups claim '%s' was not found in the access token",
groups_claim,
)
claim_groups = []
if sorted(claim_groups) != sorted(django_groups):
existing_groups = list(
Group.objects.filter(name__in=claim_groups).iterator()
)
existing_group_names = frozenset(
group.name for group in existing_groups
)
new_groups = []
if self.config.sync_groups:
# Only sync groups that match the supplied glob pattern
new_groups = [
Group.objects.get_or_create(name=name)[0]
for name in fnmatch.filter(
claim_groups,
self.config.sync_groups_glob_pattern,
)
if name not in existing_group_names
]
else:
for name in claim_groups:
if name not in existing_group_names:
try:
group = Group.objects.get(name=name)
new_groups.append(group)
except ObjectDoesNotExist:
pass
user.groups.set(existing_groups + new_groups)
for name in claim_groups:
if name not in existing_group_names:
try:
group = Group.objects.get(name=name)
new_groups.append(group)
except ObjectDoesNotExist:
pass
user.groups.set(existing_groups + new_groups)

def update_user_default_groups(self, user):
"""
Expand Down
16 changes: 16 additions & 0 deletions mozilla_django_oidc_db/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.db import models
from django.utils.translation import gettext_lazy as _

from django_jsonform.models.fields import ArrayField


class ClaimField(ArrayField):
sergei-maertens marked this conversation as resolved.
Show resolved Hide resolved
"""
A field to store a path to claims holding the desired value(s).

Each item is a segment in the path from the root to leaf for nested claims.
"""

def __init__(self, *args, **kwargs):
kwargs["base_field"] = models.CharField(_("claim path segment"), max_length=50)
super().__init__(*args, **kwargs)
Loading