From 84fd85bbea9cfbe7e6ee5dcc45e740048c463ba3 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 17 Feb 2022 17:54:16 +0100 Subject: [PATCH 1/5] Allow modules to set a display name on registration (#12009) Co-authored-by: Patrick Cloke --- changelog.d/12009.feature | 1 + .../password_auth_provider_callbacks.md | 35 ++++- synapse/handlers/auth.py | 58 +++++++++ synapse/module_api/__init__.py | 5 + synapse/rest/client/register.py | 20 +-- tests/handlers/test_password_providers.py | 123 +++++++++++++----- 6 files changed, 194 insertions(+), 48 deletions(-) create mode 100644 changelog.d/12009.feature diff --git a/changelog.d/12009.feature b/changelog.d/12009.feature new file mode 100644 index 000000000..c8a531481 --- /dev/null +++ b/changelog.d/12009.feature @@ -0,0 +1 @@ +Enable modules to set a custom display name when registering a user. diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index 3f34c7977..8c5489889 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the authentication is denied. ### `on_logged_out` @@ -162,10 +162,38 @@ return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the username provided by the user is used, if any (otherwise one is automatically generated). +### `get_displayname_for_registration` + +_First introduced in Synapse v1.54.0_ + +```python +async def get_displayname_for_registration( + uia_results: Dict[str, Any], + params: Dict[str, Any], +) -> Optional[str] +``` + +Called when registering a new user. The module can return a display name to set for the +user being registered by returning it as a string, or `None` if it doesn't wish to force a +display name for this user. + +This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +has been completed by the user. It is not called when registering a user via SSO. It is +passed two dictionaries, which include the information that the user has provided during +the registration process. These dictionaries are identical to the ones passed to +[`get_username_for_registration`](#get_username_for_registration), so refer to the +documentation of this callback for more information about them. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `None`, Synapse falls through to the next one. The value of the first +callback that does not return `None` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. If every callback returns `None`, +the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`). + ## `is_3pid_allowed` _First introduced in Synapse v1.53.0_ @@ -194,8 +222,7 @@ The example module below implements authentication checkers for two different lo - Is checked by the method: `self.check_my_login` - `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based)) - Expects a `password` field to be sent to `/login` - - Is checked by the method: `self.check_pass` - + - Is checked by the method: `self.check_pass` ```python from typing import Awaitable, Callable, Optional, Tuple diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6959d1aa7..572f54b1e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2064,6 +2064,10 @@ def run(*args: Tuple, **kwargs: Dict) -> Awaitable: [JsonDict, JsonDict], Awaitable[Optional[str]], ] +GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] @@ -2080,6 +2084,9 @@ def __init__(self) -> None: self.get_username_for_registration_callbacks: List[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = [] + self.get_displayname_for_registration_callbacks: List[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = [] self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] # Mapping from login type to login parameters @@ -2099,6 +2106,9 @@ def register_password_auth_provider_callbacks( get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: @@ -2148,6 +2158,11 @@ def register_password_auth_provider_callbacks( get_username_for_registration, ) + if get_displayname_for_registration is not None: + self.get_displayname_for_registration_callbacks.append( + get_displayname_for_registration, + ) + if is_3pid_allowed is not None: self.is_3pid_allowed_callbacks.append(is_3pid_allowed) @@ -2350,6 +2365,49 @@ async def get_username_for_registration( return None + async def get_displayname_for_registration( + self, + uia_results: JsonDict, + params: JsonDict, + ) -> Optional[str]: + """Defines the display name to use when registering the user, using the + credentials and parameters provided during the UIA flow. + + Stops at the first callback that returns a tuple containing at least one string. + + Args: + uia_results: The credentials provided during the UIA flow. + params: The parameters provided by the registration request. + + Returns: + A tuple which first element is the display name, and the second is an MXC URL + to the user's avatar. + """ + for callback in self.get_displayname_for_registration_callbacks: + try: + res = await callback(uia_results, params) + + if isinstance(res, str): + return res + elif res is not None: + # mypy complains that this line is unreachable because it assumes the + # data returned by the module fits the expected type. We just want + # to make sure this is the case. + logger.warning( # type: ignore[unreachable] + "Ignoring non-string value returned by" + " get_displayname_for_registration callback %s: %s", + callback, + res, + ) + except Exception as e: + logger.error( + "Module raised an exception in get_displayname_for_registration: %s", + e, + ) + raise SynapseError(code=500, msg="Internal Server Error") + + return None + async def is_3pid_allowed( self, medium: str, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 59a6918b4..295b0be30 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -71,6 +71,7 @@ from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK, GET_USERNAME_FOR_REGISTRATION_CALLBACK, IS_3PID_ALLOWED_CALLBACK, ON_LOGGED_OUT_CALLBACK, @@ -317,6 +318,9 @@ def register_password_auth_provider_callbacks( get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -328,6 +332,7 @@ def register_password_auth_provider_callbacks( is_3pid_allowed=is_3pid_allowed, auth_checkers=auth_checkers, get_username_for_registration=get_username_for_registration, + get_displayname_for_registration=get_displayname_for_registration, ) def register_background_update_controller_callbacks( diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 1514c43c5..07b788422 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -695,26 +695,18 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: session_id ) - # TODO: This won't be needed anymore once https://github.com/matrix-org/matrix-dinsic/issues/793 - # is resolved. - desired_display_name = body.get("display_name") - if auth_result: - if LoginType.EMAIL_IDENTITY in auth_result: - address = auth_result[LoginType.EMAIL_IDENTITY]["address"] - if ( - self.hs.config.registration.register_just_use_email_for_display_name - ): - desired_display_name = address - else: - # Custom mapping between email address and display name - desired_display_name = _map_email_to_displayname(address) + display_name = await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params + ) + ) registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, - default_display_name=desired_display_name, threepid=threepid, + default_display_name=display_name, address=client_addr, user_agent_ips=entries, ) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 4740dd0a6..49d832de8 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -84,7 +84,7 @@ def parse_config(self): def __init__(self, config, api: ModuleApi): api.register_password_auth_provider_callbacks( - auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + auth_checkers={("test.login_type", ("test_field",)): self.check_auth} ) def check_auth(self, *args): @@ -122,7 +122,7 @@ def __init__(self, config, api: ModuleApi): auth_checkers={ ("test.login_type", ("test_field",)): self.check_auth, ("m.login.password", ("password",)): self.check_auth, - }, + } ) pass @@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): account.register_servlets, ] + CALLBACK_USERNAME = "get_username_for_registration" + CALLBACK_DISPLAYNAME = "get_displayname_for_registration" + def setUp(self): # we use a global mock device, so make sure we are starting with a clean slate mock_password_provider.reset_mock() @@ -754,7 +757,9 @@ def test_username(self): """Tests that the get_username_for_registration callback can define the username of a user when registering. """ - self._setup_get_username_for_registration() + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, + ) username = "rin" channel = self.make_request( @@ -777,30 +782,14 @@ def test_username_uia(self): """Tests that the get_username_for_registration callback is only called at the end of the UIA flow. """ - m = self._setup_get_username_for_registration() - - # Initiate the UIA flow. - username = "rin" - channel = self.make_request( - "POST", - "register", - {"username": username, "type": "m.login.password", "password": "bar"}, + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, ) - self.assertEqual(channel.code, 401) - self.assertIn("session", channel.json_body) - # Check that the callback hasn't been called yet. - m.assert_not_called() + username = "rin" + res = self._do_uia_assert_mock_not_called(username, m) - # Finish the UIA flow. - session = channel.json_body["session"] - channel = self.make_request( - "POST", - "register", - {"auth": {"session": session, "type": LoginType.DUMMY}}, - ) - self.assertEqual(channel.code, 200, channel.json_body) - mxid = channel.json_body["user_id"] + mxid = res["user_id"] self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") # Check that the callback has been called. @@ -817,6 +806,56 @@ def test_3pid_allowed(self): self._test_3pid_allowed("rin", False) self._test_3pid_allowed("kitay", True) + def test_displayname(self): + """Tests that the get_displayname_for_registration callback can define the + display name of a user when registering. + """ + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + channel = self.make_request( + "POST", + "/register", + { + "username": username, + "password": "bar", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(channel.code, 200) + + # Our callback takes the username and appends "-foo" to it, check that's what we + # have. + user_id = UserID.from_string(channel.json_body["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + + def test_displayname_uia(self): + """Tests that the get_displayname_for_registration callback is only called at the + end of the UIA flow. + """ + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + res = self._do_uia_assert_mock_not_called(username, m) + + user_id = UserID.from_string(res["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + + # Check that the callback has been called. + m.assert_called_once() + def _test_3pid_allowed(self, username: str, registration: bool): """Tests that the "is_3pid_allowed" module callback is called correctly, using either /register or /account URLs depending on the arguments. @@ -877,23 +916,47 @@ def _test_3pid_allowed(self, username: str, registration: bool): m.assert_called_once_with("email", "bar@test.com", registration) - def _setup_get_username_for_registration(self) -> Mock: - """Registers a get_username_for_registration callback that appends "-foo" to the - username the client is trying to register. + def _setup_get_name_for_registration(self, callback_name: str) -> Mock: + """Registers either a get_username_for_registration callback or a + get_displayname_for_registration callback that appends "-foo" to the username the + client is trying to register. """ - async def get_username_for_registration(uia_results, params): + async def callback(uia_results, params): self.assertIn(LoginType.DUMMY, uia_results) username = params["username"] return username + "-foo" - m = Mock(side_effect=get_username_for_registration) + m = Mock(side_effect=callback) password_auth_provider = self.hs.get_password_auth_provider() - password_auth_provider.get_username_for_registration_callbacks.append(m) + getattr(password_auth_provider, callback_name + "_callbacks").append(m) return m + def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict: + # Initiate the UIA flow. + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Check that the callback hasn't been called yet. + m.assert_not_called() + + # Finish the UIA flow. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + return channel.json_body + def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) From de584834615f9bdd1302e0bfaf7646fde51f85ae Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 22 Mar 2022 15:42:43 +0000 Subject: [PATCH 2/5] Remove dead code --- synapse/config/registration.py | 3 -- synapse/rest/client/register.py | 54 --------------------------------- 2 files changed, 57 deletions(-) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 31482b425..a99fcb2a7 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -43,9 +43,6 @@ def read_config(self, config, **kwargs): "registration_requires_token", False ) self.registration_shared_secret = config.get("registration_shared_secret") - self.register_just_use_email_for_display_name = config.get( - "register_just_use_email_for_display_name", False - ) self.bcrypt_rounds = config.get("bcrypt_rounds", 12) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 07b788422..6f4b3b5dd 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -868,60 +868,6 @@ async def _do_guest_registration( return 200, result -def cap(name: str) -> str: - """Capitalise parts of a name containing different words, including those - separated by hyphens. - For example, 'John-Doe' - - Args: - The name to parse - """ - if not name: - return name - - # Split the name by whitespace then hyphens, capitalizing each part then - # joining it back together. - capatilized_name = " ".join( - "-".join(part.capitalize() for part in space_part.split("-")) - for space_part in name.split() - ) - return capatilized_name - - -def _map_email_to_displayname(address: str) -> str: - """Custom mapping from an email address to a user displayname - - Args: - address: The email address to process - Returns: - The new displayname - """ - # Split the part before and after the @ in the email. - # Replace all . with spaces in the first part - parts = address.replace(".", " ").split("@") - - # Figure out which org this email address belongs to - org_parts = parts[1].split(" ") - - # If this is a ...matrix.org email, mark them as an Admin - if org_parts[-2] == "matrix" and org_parts[-1] == "org": - org = "Tchap Admin" - - # Is this is a ...gouv.fr address, set the org to whatever is before - # gouv.fr. If there isn't anything (a @gouv.fr email) simply mark their - # org as "gouv" - elif org_parts[-2] == "gouv" and org_parts[-1] == "fr": - org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2] - - # Otherwise, mark their org as the email's second-level domain name - else: - org = org_parts[-2] - - desired_display_name = cap(parts[0]) + " [" + cap(org) + "]" - - return desired_display_name - - def _calculate_registration_flows( config: HomeServerConfig, auth_handler: AuthHandler ) -> List[List[str]]: From b8b7c861696565189a13200757a16c410b6317bf Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 22 Mar 2022 15:50:08 +0000 Subject: [PATCH 3/5] Remove more dead code --- tests/handlers/test_register.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index c010a1407..4e7dca43e 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -23,7 +23,7 @@ SynapseError, ) from synapse.events.spamcheck import load_legacy_spam_checkers -from synapse.rest.client.register import _map_email_to_displayname, register_servlets +from synapse.rest.client.register import register_servlets from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, RoomID, UserID, create_requester @@ -691,26 +691,6 @@ def test_spam_checker_receives_sso_type(self): self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml") ) - def test_email_to_displayname_mapping(self): - """Test that custom emails are mapped to new user displaynames correctly""" - self._check_mapping( - "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]" - ) - - self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]") - - self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]") - - # Multibyte unicode characters - self._check_mapping( - "j\u030a\u0065an-poppy.seed@example.com", - "J\u030a\u0065an-Poppy Seed [Example]", - ) - - def _check_mapping(self, i, expected): - result = _map_email_to_displayname(i) - self.assertEqual(result, expected) - @override_config( { "bind_new_user_emails_to_sydent": "https://is.example.com", From 230ff6e2c36ae10b4310bc2dc5188ee54f23a2b0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 7 Feb 2022 10:06:52 -0500 Subject: [PATCH 4/5] Pass the proper type when uploading files. (#11927) The Content-Length header should be treated as an int, not a string. This shouldn't have any user-facing change. --- changelog.d/11927.misc | 1 + synapse/rest/media/v1/upload_resource.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 changelog.d/11927.misc diff --git a/changelog.d/11927.misc b/changelog.d/11927.misc new file mode 100644 index 000000000..22c58521c --- /dev/null +++ b/changelog.d/11927.misc @@ -0,0 +1 @@ +Use the proper type for the Content-Length header in the `UploadResource`. diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 8162094cf..fde28d08c 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -49,10 +49,14 @@ async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: async def _async_render_POST(self, request: SynapseRequest) -> None: requester = await self.auth.get_user_by_req(request) - content_length = request.getHeader("Content-Length") - if content_length is None: + raw_content_length = request.getHeader("Content-Length") + if raw_content_length is None: raise SynapseError(msg="Request must specify a Content-Length", code=400) - if int(content_length) > self.max_upload_size: + try: + content_length = int(raw_content_length) + except ValueError: + raise SynapseError(msg="Content-Length value is invalid", code=400) + if content_length > self.max_upload_size: raise SynapseError( msg="Upload request body is too large", code=413, @@ -66,7 +70,8 @@ async def _async_render_POST(self, request: SynapseRequest) -> None: upload_name: Optional[str] = upload_name_bytes.decode("utf8") except UnicodeDecodeError: raise SynapseError( - msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 + msg="Invalid UTF-8 filename parameter: %r" % (upload_name_bytes,), + code=400, ) # If the name is falsey (e.g. an empty byte string) ensure it is None. From fd8bdbf7ea31369a2a4cfd45b5ebdecbf746237f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 18 Feb 2022 12:38:48 +0100 Subject: [PATCH 5/5] Update the olddeps CI check to use an old version of markupsafe (#12025) --- changelog.d/12025.misc | 1 + tox.ini | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/12025.misc diff --git a/changelog.d/12025.misc b/changelog.d/12025.misc new file mode 100644 index 000000000..d9475a771 --- /dev/null +++ b/changelog.d/12025.misc @@ -0,0 +1 @@ +Update the `olddeps` CI job to use an old version of `markupsafe`. diff --git a/tox.ini b/tox.ini index d8fa3fb63..44cf87d18 100644 --- a/tox.ini +++ b/tox.ini @@ -119,6 +119,9 @@ usedevelop = false deps = Automat == 0.8.0 lxml + # markupsafe 2.1 introduced a change that breaks Jinja 2.x. Since we depend on + # Jinja >= 2.9, it means this test suite will fail if markupsafe >= 2.1 is installed. + markupsafe < 2.1 {[base]deps} commands =