-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add basic authentication to replication endpoints. #8853
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Add optional HTTP authentication to replication endpoints. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -106,6 +106,25 @@ def __init__(self, hs): | |
|
||
assert self.METHOD in ("PUT", "POST", "GET") | ||
|
||
self._replication_secret = None | ||
if hs.config.worker.worker_replication_secret: | ||
self._replication_secret = hs.config.worker.worker_replication_secret | ||
|
||
def _check_auth(self, request) -> None: | ||
# Get the authorization header. | ||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") | ||
|
||
if len(auth_headers) > 1: | ||
raise RuntimeError("Too many Authorization headers.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case the client will receive a 500 error. I'm not sure if there's a "better" exception to raise that would get passed back. But since this is a misconfiguration it isn't necessarily that the client was unauthorized or such. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, and it will send exceptions to sentry like this 👍 |
||
parts = auth_headers[0].split(b" ") | ||
if parts[0] == b"Bearer" and len(parts) == 2: | ||
received_secret = parts[1].decode("ascii") | ||
if self._replication_secret == received_secret: | ||
# Success! | ||
return | ||
|
||
raise RuntimeError("Invalid Authorization header.") | ||
|
||
@abc.abstractmethod | ||
async def _serialize_payload(**kwargs): | ||
"""Static method that is called when creating a request. | ||
|
@@ -150,6 +169,12 @@ def make_client(cls, hs): | |
|
||
outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME) | ||
|
||
replication_secret = None | ||
if hs.config.worker.worker_replication_secret: | ||
replication_secret = hs.config.worker.worker_replication_secret.encode( | ||
"ascii" | ||
) | ||
|
||
@trace(opname="outgoing_replication_request") | ||
@outgoing_gauge.track_inprogress() | ||
async def send_request(instance_name="master", **kwargs): | ||
|
@@ -202,6 +227,9 @@ async def send_request(instance_name="master", **kwargs): | |
# the master, and so whether we should clean up or not. | ||
while True: | ||
headers = {} # type: Dict[bytes, List[bytes]] | ||
# Add an authorization header, if configured. | ||
if replication_secret: | ||
headers[b"Authorization"] = [b"Bearer " + replication_secret] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wasn't sure whether to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||
inject_active_span_byte_dict(headers, None, check_destination=False) | ||
try: | ||
result = await request_func(uri, data, headers=headers) | ||
|
@@ -236,28 +264,35 @@ def register(self, http_server): | |
""" | ||
|
||
url_args = list(self.PATH_ARGS) | ||
handler = self._handle_request | ||
method = self.METHOD | ||
|
||
if self.CACHE: | ||
handler = self._cached_handler # type: ignore | ||
url_args.append("txn_id") | ||
|
||
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) | ||
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) | ||
|
||
http_server.register_paths( | ||
method, [pattern], handler, self.__class__.__name__, | ||
method, [pattern], self._check_auth_and_handle, self.__class__.__name__, | ||
) | ||
|
||
def _cached_handler(self, request, txn_id, **kwargs): | ||
def _check_auth_and_handle(self, request, **kwargs): | ||
"""Called on new incoming requests when caching is enabled. Checks | ||
if there is a cached response for the request and returns that, | ||
otherwise calls `_handle_request` and caches its response. | ||
""" | ||
# We just use the txn_id here, but we probably also want to use the | ||
# other PATH_ARGS as well. | ||
|
||
assert self.CACHE | ||
# Check the authorization headers before handling the request. | ||
if self._replication_secret: | ||
self._check_auth(request) | ||
|
||
if self.CACHE: | ||
txn_id = kwargs.pop("txn_id") | ||
|
||
return self.response_cache.wrap( | ||
txn_id, self._handle_request, request, **kwargs | ||
) | ||
|
||
return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs) | ||
return self._handle_request(request, **kwargs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2020 The Matrix.org Foundation C.I.C. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import logging | ||
from typing import Tuple | ||
|
||
from synapse.http.site import SynapseRequest | ||
from synapse.rest.client.v2_alpha import register | ||
|
||
from tests.replication._base import BaseMultiWorkerStreamTestCase | ||
from tests.server import FakeChannel, make_request | ||
from tests.unittest import override_config | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): | ||
"""Test the authentication of HTTP calls between workers.""" | ||
|
||
servlets = [register.register_servlets] | ||
|
||
def make_homeserver(self, reactor, clock): | ||
config = self.default_config() | ||
# This isn't a real configuration option but is used to provide the main | ||
# homeserver and worker homeserver different options. | ||
main_replication_secret = config.pop("main_replication_secret", None) | ||
if main_replication_secret: | ||
config["worker_replication_secret"] = main_replication_secret | ||
return self.setup_test_homeserver(config=config) | ||
|
||
def _get_worker_hs_config(self) -> dict: | ||
config = self.default_config() | ||
config["worker_app"] = "synapse.app.client_reader" | ||
config["worker_replication_host"] = "testserv" | ||
config["worker_replication_http_port"] = "8765" | ||
|
||
return config | ||
|
||
def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]: | ||
"""Run the actual test: | ||
|
||
1. Create a worker homeserver. | ||
2. Start registration by providing a user/password. | ||
3. Complete registration by providing dummy auth (this hits the main synapse). | ||
4. Return the final request. | ||
|
||
""" | ||
worker_hs = self.make_worker_hs("synapse.app.client_reader") | ||
site = self._hs_to_site[worker_hs] | ||
|
||
request_1, channel_1 = make_request( | ||
self.reactor, | ||
site, | ||
"POST", | ||
"register", | ||
{"username": "user", "type": "m.login.password", "password": "bar"}, | ||
) # type: SynapseRequest, FakeChannel | ||
self.assertEqual(request_1.code, 401) | ||
|
||
# Grab the session | ||
session = channel_1.json_body["session"] | ||
|
||
# also complete the dummy auth | ||
return make_request( | ||
self.reactor, | ||
site, | ||
"POST", | ||
"register", | ||
{"auth": {"session": session, "type": "m.login.dummy"}}, | ||
) | ||
|
||
def test_no_auth(self): | ||
"""With no authentication the request should finish. | ||
""" | ||
request, channel = self._test_register() | ||
self.assertEqual(request.code, 200) | ||
|
||
# We're given a registered user. | ||
self.assertEqual(channel.json_body["user_id"], "@user:test") | ||
|
||
@override_config({"main_replication_secret": "my-secret"}) | ||
def test_missing_auth(self): | ||
"""If the main process expects a secret that is not provided, an error results. | ||
""" | ||
request, channel = self._test_register() | ||
self.assertEqual(request.code, 500) | ||
|
||
@override_config( | ||
{ | ||
"main_replication_secret": "my-secret", | ||
"worker_replication_secret": "wrong-secret", | ||
} | ||
) | ||
def test_unauthorized(self): | ||
"""If the main process receives the wrong secret, an error results. | ||
""" | ||
request, channel = self._test_register() | ||
self.assertEqual(request.code, 500) | ||
|
||
@override_config({"worker_replication_secret": "my-secret"}) | ||
def test_authorized(self): | ||
"""The request should finish when the worker provides the authentication header. | ||
""" | ||
request, channel = self._test_register() | ||
self.assertEqual(request.code, 200) | ||
|
||
# We're given a registered user. | ||
self.assertEqual(channel.json_body["user_id"], "@user:test") |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,27 +14,20 @@ | |
# limitations under the License. | ||
import logging | ||
|
||
from synapse.api.constants import LoginType | ||
from synapse.http.site import SynapseRequest | ||
from synapse.rest.client.v2_alpha import register | ||
|
||
from tests.replication._base import BaseMultiWorkerStreamTestCase | ||
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker | ||
from tests.server import FakeChannel, make_request | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was the test I cribbed from and noticed a couple of unnecessary things in it, figured I'd clean it up while here. |
||
"""Base class for tests of the replication streams""" | ||
"""Test using one or more client readers for registration.""" | ||
|
||
servlets = [register.register_servlets] | ||
|
||
def prepare(self, reactor, clock, hs): | ||
self.recaptcha_checker = DummyRecaptchaChecker(hs) | ||
auth_handler = hs.get_auth_handler() | ||
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker | ||
|
||
def _get_worker_hs_config(self) -> dict: | ||
config = self.default_config() | ||
config["worker_app"] = "synapse.app.client_reader" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is unrelated, but unused.