Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support trying multiple localparts for OpenID Connect. #8801

Merged
merged 17 commits into from
Nov 25, 2020
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8801.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for re-trying generation of a localpart for OpenID Connect mapping providers.
11 changes: 10 additions & 1 deletion docs/sso_mapping_providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,22 @@ A custom mapping provider must specify the following methods:
information from.
- This method must return a string, which is the unique identifier for the
user. Commonly the ``sub`` claim of the response.
* `map_user_attributes(self, userinfo, token)`
* `map_user_attributes(self, userinfo, token, failures)`
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
- `token` - A dictionary which includes information necessary to make
further requests to the OpenID provider.
- `failures` - An `int` that represents the amount of times the returned
mxid localpart mapping has failed. This should be used
to create a deduplicated mxid localpart which should be
returned instead. For example, if this method returns
`john.doe` as the value of `localpart` in the returned
dict, and that is already taken on the homeserver, this
method will be called again with the same parameters but
with failures=1. The method should then return a different
`localpart` value, such as `john.doe1`.
- Returns a dictionary with two keys:
- localpart: A required string, used to generate the Matrix ID.
- displayname: An optional string, the display name for the user.
Expand Down
119 changes: 49 additions & 70 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
Expand All @@ -35,15 +36,10 @@

from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import (
JsonDict,
UserID,
contains_invalid_mxid_characters,
map_username_to_mxid_localpart,
)
from synapse.types import JsonDict, map_username_to_mxid_localpart
from synapse.util import json_decoder

if TYPE_CHECKING:
Expand Down Expand Up @@ -869,73 +865,50 @@ async def _map_userinfo_to_user(
# to be strings.
remote_user_id = str(remote_user_id)

# first of all, check if we already have a mapping for this user
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
self._auth_provider_id, remote_user_id,
# Older mapping providers don't accept the `failures` argument, so we
# try and detect support.
mapper_signature = inspect.signature(
self._user_mapping_provider.map_user_attributes
)
if previously_registered_user_id:
return previously_registered_user_id
supports_failures = "failures" in mapper_signature.parameters

# Otherwise, generate a new user.
try:
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token
)
except Exception as e:
raise MappingException(
"Could not extract user attributes from OIDC response: " + str(e)
)
async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
"""
Call the mapping provider to map the OIDC userinfo and token to user attributes.

logger.debug(
"Retrieved user attributes from user mapping provider: %r", attributes
)
This is backwards compatibility for abstraction for the SSO handler.
"""
if supports_failures:
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token, failures
)
else:
# If the mapping provider does not support processing failures,
# do not continually generate the same Matrix ID since this will
# likely continue to fail.
if failures:
raise RuntimeError(
"Mapping provider does not support de-duplicating Matrix IDs"
)

localpart = attributes["localpart"]
if not localpart:
raise MappingException(
"Error parsing OIDC response: OIDC mapping provider plugin "
"did not return a localpart value"
)
attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
userinfo, token
)

user_id = UserID(localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
if self._allow_existing_users:
if len(users) == 1:
registered_user_id = next(iter(users))
elif user_id in users:
registered_user_id = user_id
else:
raise MappingException(
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
user_id, list(users.keys())
)
)
else:
# This mxid is taken
raise MappingException("mxid '{}' is already taken".format(user_id))
else:
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(localpart):
raise MappingException("localpart is invalid: %s" % (localpart,))

# It's the first time this user is logging in and the mapped mxid was
# not taken, register the user
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=attributes["display_name"],
user_agent_ips=[(user_agent, ip_address)],
)
return UserAttributes(**attributes)

await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
oidc_response_to_user_attributes,
self._allow_existing_users,
)
return registered_user_id


UserAttribute = TypedDict(
"UserAttribute", {"localpart": str, "display_name": Optional[str]}
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")

Expand Down Expand Up @@ -978,13 +951,15 @@ def get_remote_user_id(self, userinfo: UserInfo) -> str:
raise NotImplementedError()

async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict:
"""Map a `UserInfo` object into user attributes.

Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
failures: How many times a call to this function with this
UserInfo has resulted in a failure.

Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
Expand Down Expand Up @@ -1084,13 +1059,17 @@ def get_remote_user_id(self, userinfo: UserInfo) -> str:
return userinfo[self._config.subject_claim]

async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
self, userinfo: UserInfo, token: Token, failures: int
clokep marked this conversation as resolved.
Show resolved Hide resolved
) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip()

# Ensure only valid characters are included in the MXID.
localpart = map_username_to_mxid_localpart(localpart)

# Append suffix integer if last call to this function failed to produce
# a usable mxid.
localpart += str(failures) if failures else ""
Copy link
Member

@erikjohnston erikjohnston Nov 25, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am quite hesitant about this as a) it'll explode if we have more users with same localpart than the number of retries, and b) there isn't a way of writing this particular logic correctly with this API? I know that this is what we do for other SSO, but it feels like a footgun that will one day hit us (especially if we decided to reduce the number of retries to a sensible number). I don't suggest we fix this here, but might be worth at least making an issue for it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed #8813


display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
Expand All @@ -1100,7 +1079,7 @@ async def map_user_attributes(
if display_name == "":
display_name = None

return UserAttribute(localpart=localpart, display_name=display_name)
return UserAttributeDict(localpart=localpart, display_name=display_name)

async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
Expand Down
91 changes: 28 additions & 63 deletions synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
from synapse.types import (
UserID,
contains_invalid_mxid_characters,
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
Expand Down Expand Up @@ -250,14 +249,26 @@ async def _map_saml_response_to_user(
"Failed to extract remote user id from SAML response"
)

with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
self._auth_provider_id, remote_user_id,
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
"""
Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.

This is backwards compatibility for abstraction for the SSO handler.
"""
# Call the mapping provider.
result = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, failures, client_redirect_url
)
# Remap some of the results.
return UserAttributes(
localpart=result.get("mxid_localpart"),
display_name=result.get("displayname"),
emails=result.get("emails"),
)
if previously_registered_user_id:
return previously_registered_user_id

with (await self._mapping_lock.queue(self._auth_provider_id)):
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
Expand All @@ -284,59 +295,13 @@ async def _map_saml_response_to_user(
)
return registered_user_id

# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, i, client_redirect_url=client_redirect_url,
)

logger.debug(
"Retrieved SAML attributes from user mapping provider: %s "
"(attempt %d)",
attribute_dict,
i,
)

localpart = attribute_dict.get("mxid_localpart")
if not localpart:
raise MappingException(
"Error parsing SAML2 response: SAML mapping provider plugin "
"did not return a mxid_localpart value"
)

displayname = attribute_dict.get("displayname")
emails = attribute_dict.get("emails", [])

# Check if this mxid already exists
if not await self.store.get_users_by_id_case_insensitive(
UserID(localpart, self.server_name).to_string()
):
# This mxid is free
break
else:
# Unable to generate a username in 1000 iterations
# Break and return error to the user
raise MappingException(
"Unable to generate a Matrix ID from the SAML response"
)

# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(localpart):
raise MappingException("localpart is invalid: %s" % (localpart,))

logger.debug("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
user_agent_ips=[(user_agent, ip_address)],
)

await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
saml_response_to_remapped_user_attributes,
)
return registered_user_id

def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
Expand Down Expand Up @@ -451,11 +416,11 @@ def saml_response_to_user_attributes(
)

# Use the configured mapper for this mxid_source
base_mxid_localpart = self._mxid_mapper(mxid_source)
localpart = self._mxid_mapper(mxid_source)

# Append suffix integer if last call to this function failed to produce
# a usable mxid
localpart = base_mxid_localpart + (str(failures) if failures else "")
# a usable mxid.
localpart += str(failures) if failures else ""

# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead
Expand Down
Loading