Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add a callback to allow modules to deny 3PID #11854

Merged
merged 8 commits into from
Feb 8, 2022
18 changes: 18 additions & 0 deletions docs/modules/password_auth_provider_callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,24 @@ any of the subsequent implementations of this callback. If every callback return
the username provided by the user is used, if any (otherwise one is automatically
generated).

## `is_3pid_allowed`

_First introduced in Synapse v1.52.0_
babolivier marked this conversation as resolved.
Show resolved Hide resolved

```python
async def is_3pid_allowed(self, medium: str, address: str) -> bool
```

Called when attempting to bind a third-party identifier (i.e. an email address or a phone
number). The module is given the medium of the third-party identifier (which is `email` if
the identifier is an email address, or `msisdn` if the identifier is a phone number). The
module must return a boolean indicating whether the identifier can be allowed to be bound
to an account on the local homeserver.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
callback that does not return `True` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.

## Example

Expand Down
38 changes: 38 additions & 0 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,6 +2064,7 @@ def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str], Awaitable[bool]]


class PasswordAuthProvider:
Expand All @@ -2079,6 +2080,7 @@ def __init__(self) -> None:
self.get_username_for_registration_callbacks: List[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = []
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []

# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {}
Expand All @@ -2090,6 +2092,7 @@ def register_password_auth_provider_callbacks(
self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None,
Expand Down Expand Up @@ -2145,6 +2148,9 @@ def register_password_auth_provider_callbacks(
get_username_for_registration,
)

if is_3pid_allowed is not None:
self.is_3pid_allowed_callbacks.append(is_3pid_allowed)

def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider

Expand Down Expand Up @@ -2343,3 +2349,35 @@ async def get_username_for_registration(
raise SynapseError(code=500, msg="Internal Server Error")

return None

async def is_3pid_allowed(self, medium: str, address: str) -> bool:
"""Check if the user can be allowed to bind a 3PID on this homeserver.

Args:
medium: The medium of the 3PID.
address: The address of the 3PID.

Returns:
Whether the 3PID is allowed to be bound on this homeserver
"""
for callback in self.is_3pid_allowed_callbacks:
try:
res = await callback(medium, address)

if res is False:
return res
elif not isinstance(res, bool):
# mypy complains that this line is unreachable because it assumes the
# data returned by the module fits the expected type. We just want
# to make sure this is the case.
logger.warning( # type: ignore[unreachable]
"Ignoring non-string value returned by"
" is_3pid_allowed callback %s: %s",
callback,
res,
)
except Exception as e:
logger.error("Module raised an exception in is_3pid_allowed: %s", e)
raise SynapseError(code=500, msg="Internal Server Error")

return True
4 changes: 2 additions & 2 deletions synapse/rest/client/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param

if not check_3pid_allowed(self.hs, "email", email):
if not await check_3pid_allowed(self.hs, "email", email):
raise SynapseError(
403,
"Your email domain is not authorized on this server",
Expand Down Expand Up @@ -468,7 +468,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

msisdn = phone_number_to_msisdn(country, phone_number)

if not check_3pid_allowed(self.hs, "msisdn", msisdn):
if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Expand Down
6 changes: 3 additions & 3 deletions synapse/rest/client/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param

if not check_3pid_allowed(self.hs, "email", email):
if not await check_3pid_allowed(self.hs, "email", email):
raise SynapseError(
403,
"Your email domain is not authorized to register on this server",
Expand Down Expand Up @@ -192,7 +192,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

msisdn = phone_number_to_msisdn(country, phone_number)

if not check_3pid_allowed(self.hs, "msisdn", msisdn):
if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
Expand Down Expand Up @@ -617,7 +617,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]

if not check_3pid_allowed(self.hs, medium, address):
if not await check_3pid_allowed(self.hs, medium, address):
raise SynapseError(
403,
"Third party identifiers (email/phone numbers)"
Expand Down
4 changes: 3 additions & 1 deletion synapse/util/threepids.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
MAX_EMAIL_ADDRESS_LENGTH = 500


def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
async def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
"""Checks whether a given format of 3PID is allowed to be used on this HS

Args:
Expand All @@ -43,6 +43,8 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
Returns:
bool: whether the 3PID medium/address is allowed to be added to this HS
"""
if not await hs.get_password_auth_provider().is_3pid_allowed(medium, address):
return False

if hs.config.registration.allowed_local_3pids:
for constraint in hs.config.registration.allowed_local_3pids:
Expand Down
59 changes: 58 additions & 1 deletion tests/handlers/test_password_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@

import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login, logout, register
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID

from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.unittest import override_config

# (possibly experimental) login flows we expect to appear in the list after the normal
Expand Down Expand Up @@ -158,6 +160,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
devices.register_servlets,
logout.register_servlets,
register.register_servlets,
account.register_servlets,
]

def setUp(self):
Expand Down Expand Up @@ -803,6 +806,60 @@ def test_username_uia(self):
# Check that the callback has been called.
m.assert_called_once()

# Set some email configuration so the test doesn't fail because of its absence.
@override_config({"email": {"notif_from": "noreply@test"}})
def test_3pid_allowed(self):
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
to bind the new 3PID, and that one allong a 3PID makes Synapse accept to bind the
3PID.
"""
self.hs.get_identity_handler().send_threepid_validation = Mock(
return_value=make_awaitable(0),
)

m = Mock(return_value=make_awaitable(False))
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]

self.register_user("rin", "password")
tok = self.login("rin", "password")

channel = self.make_request(
"POST",
"/account/3pid/email/requestToken",
{
"client_secret": "foo",
"email": "foo@test.com",
"send_attempt": 0,
},
access_token=tok,
)
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.THREEPID_DENIED,
channel.json_body,
)

m.assert_called_once_with("email", "foo@test.com")

m = Mock(return_value=make_awaitable(True))
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]

channel = self.make_request(
"POST",
"/account/3pid/email/requestToken",
{
"client_secret": "foo",
"email": "bar@test.com",
"send_attempt": 0,
},
access_token=tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertIn("sid", channel.json_body)

m.assert_called_once_with("email", "bar@test.com")

def _setup_get_username_for_registration(self) -> Mock:
"""Registers a get_username_for_registration callback that appends "-foo" to the
username the client is trying to register.
Expand Down