Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add basic authentication to replication endpoints. #8853

Merged
merged 4 commits into from
Dec 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8853.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add optional HTTP authentication to replication endpoints.
7 changes: 7 additions & 0 deletions docs/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2587,6 +2587,13 @@ opentracing:
#
#run_background_tasks_on: worker1

# A shared secret used by the replication APIs to authenticate HTTP requests
# from workers.
#
# By default this is unused and traffic is not authenticated.
#
#worker_replication_secret: ""


# Configuration for Redis when using workers. This *must* be enabled when
# using workers (unless using old style direct TCP configuration).
Expand Down
6 changes: 5 additions & 1 deletion docs/workers.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ shared configuration file.
Normally, only a couple of changes are needed to make an existing configuration
file suitable for use with workers. First, you need to enable an "HTTP replication
listener" for the main process; and secondly, you need to enable redis-based
replication. For example:
replication. Optionally, a shared secret can be used to authenticate HTTP
traffic between workers. For example:


```yaml
Expand All @@ -103,6 +104,9 @@ listeners:
resources:
- names: [replication]

# Add a random shared secret to authenticate traffic.
worker_replication_secret: ""

redis:
enabled: true
```
Expand Down
1 change: 0 additions & 1 deletion synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ def __init__(self, hs):
super().__init__(hs)
self.hs = hs
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unrelated, but unused.


self._presence_enabled = hs.config.use_presence

Expand Down
10 changes: 10 additions & 0 deletions synapse/config/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def read_config(self, config, **kwargs):
# The port on the main synapse for HTTP replication endpoint
self.worker_replication_http_port = config.get("worker_replication_http_port")

# The shared secret used for authentication when connecting to the main synapse.
self.worker_replication_secret = config.get("worker_replication_secret", None)

self.worker_name = config.get("worker_name", self.worker_app)

self.worker_main_http_uri = config.get("worker_main_http_uri", None)
Expand Down Expand Up @@ -185,6 +188,13 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
# data). If not provided this defaults to the main process.
#
#run_background_tasks_on: worker1

# A shared secret used by the replication APIs to authenticate HTTP requests
# from workers.
#
# By default this is unused and traffic is not authenticated.
#
#worker_replication_secret: ""
"""

def read_arguments(self, args):
Expand Down
47 changes: 41 additions & 6 deletions synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,25 @@ def __init__(self, hs):

assert self.METHOD in ("PUT", "POST", "GET")

self._replication_secret = None
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret

def _check_auth(self, request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

if len(auth_headers) > 1:
raise RuntimeError("Too many Authorization headers.")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case the client will receive a 500 error. I'm not sure if there's a "better" exception to raise that would get passed back. But since this is a misconfiguration it isn't necessarily that the client was unauthorized or such.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, and it will send exceptions to sentry like this 👍

parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2:
received_secret = parts[1].decode("ascii")
if self._replication_secret == received_secret:
# Success!
return

raise RuntimeError("Invalid Authorization header.")

@abc.abstractmethod
async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request.
Expand Down Expand Up @@ -150,6 +169,12 @@ def make_client(cls, hs):

outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)

replication_secret = None
if hs.config.worker.worker_replication_secret:
replication_secret = hs.config.worker.worker_replication_secret.encode(
"ascii"
)

@trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress()
async def send_request(instance_name="master", **kwargs):
Expand Down Expand Up @@ -202,6 +227,9 @@ async def send_request(instance_name="master", **kwargs):
# the master, and so whether we should clean up or not.
while True:
headers = {} # type: Dict[bytes, List[bytes]]
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure whether to use Bearer or Basic here, although it doesn't really matter since we just care about the shared secret part.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Basic requires a certain format for the token, so Bearer is probably the right thing here

inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
Expand Down Expand Up @@ -236,28 +264,35 @@ def register(self, http_server):
"""

url_args = list(self.PATH_ARGS)
handler = self._handle_request
method = self.METHOD

if self.CACHE:
handler = self._cached_handler # type: ignore
url_args.append("txn_id")

args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))

http_server.register_paths(
method, [pattern], handler, self.__class__.__name__,
method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
)

def _cached_handler(self, request, txn_id, **kwargs):
def _check_auth_and_handle(self, request, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
"""
# We just use the txn_id here, but we probably also want to use the
# other PATH_ARGS as well.

assert self.CACHE
# Check the authorization headers before handling the request.
if self._replication_secret:
self._check_auth(request)

if self.CACHE:
txn_id = kwargs.pop("txn_id")

return self.response_cache.wrap(
txn_id, self._handle_request, request, **kwargs
)

return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
return self._handle_request(request, **kwargs)
119 changes: 119 additions & 0 deletions tests/replication/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Tuple

from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register

from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, make_request
from tests.unittest import override_config

logger = logging.getLogger(__name__)


class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
    """Exercise the shared-secret check on HTTP calls made between workers."""

    servlets = [register.register_servlets]

    def make_homeserver(self, reactor, clock):
        """Build the main homeserver, translating the test-only secret key."""
        hs_config = self.default_config()

        # "main_replication_secret" is not a real option: it exists only so a
        # test can give the main process a different secret from the worker.
        secret_for_main = hs_config.pop("main_replication_secret", None)
        if secret_for_main:
            hs_config["worker_replication_secret"] = secret_for_main

        return self.setup_test_homeserver(config=hs_config)

    def _get_worker_hs_config(self) -> dict:
        """Return the configuration used for the worker homeserver."""
        worker_config = self.default_config()
        worker_config.update(
            {
                "worker_app": "synapse.app.client_reader",
                "worker_replication_host": "testserv",
                "worker_replication_http_port": "8765",
            }
        )
        return worker_config

    def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]:
        """Register a user through a worker and return the final request.

        Steps:

        1. Create a worker homeserver.
        2. Start registration by providing a user/password.
        3. Complete registration by providing dummy auth (this hits the main
           synapse).
        4. Return the request/channel pair of that final call.
        """
        worker = self.make_worker_hs("synapse.app.client_reader")
        worker_site = self._hs_to_site[worker]

        credentials = {"username": "user", "type": "m.login.password", "password": "bar"}
        first_request, first_channel = make_request(
            self.reactor, worker_site, "POST", "register", credentials
        )  # type: SynapseRequest, FakeChannel
        self.assertEqual(first_request.code, 401)

        # The 401 response carries the session to continue with.
        session_id = first_channel.json_body["session"]

        # Finish the flow with the dummy auth stage.
        dummy_auth = {"auth": {"session": session_id, "type": "m.login.dummy"}}
        return make_request(self.reactor, worker_site, "POST", "register", dummy_auth)

    def test_no_auth(self):
        """Registration completes when no secret is configured anywhere."""
        final_request, final_channel = self._test_register()
        self.assertEqual(final_request.code, 200)

        # We're given a registered user.
        self.assertEqual(final_channel.json_body["user_id"], "@user:test")

    @override_config({"main_replication_secret": "my-secret"})
    def test_missing_auth(self):
        """An error results when the main process expects a secret the worker omits."""
        final_request, _ = self._test_register()
        self.assertEqual(final_request.code, 500)

    @override_config(
        {
            "main_replication_secret": "my-secret",
            "worker_replication_secret": "wrong-secret",
        }
    )
    def test_unauthorized(self):
        """An error results when the worker sends a secret the main process rejects."""
        final_request, _ = self._test_register()
        self.assertEqual(final_request.code, 500)

    @override_config({"worker_replication_secret": "my-secret"})
    def test_authorized(self):
        """Registration completes when the worker sends the expected secret."""
        final_request, final_channel = self._test_register()
        self.assertEqual(final_request.code, 200)

        # We're given a registered user.
        self.assertEqual(final_channel.json_body["user_id"], "@user:test")
9 changes: 1 addition & 8 deletions tests/replication/test_client_reader_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,20 @@
# limitations under the License.
import logging

from synapse.api.constants import LoginType
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register

from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel, make_request

logger = logging.getLogger(__name__)


class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the test I cribbed from and noticed a couple of unnecessary things in it, figured I'd clean it up while here.

"""Base class for tests of the replication streams"""
"""Test using one or more client readers for registration."""

servlets = [register.register_servlets]

def prepare(self, reactor, clock, hs):
self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker

def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
Expand Down