Skip to content

Commit

Permalink
Persist user interactive authentication sessions (matrix-org#7302)
Browse files Browse the repository at this point in the history
By persisting the user interactive authentication sessions to the database, this fixes
situations where a user hits different works throughout their auth session and also
allows sessions to persist through restarts of Synapse.
  • Loading branch information
clokep authored and phil-flex committed Jun 16, 2020
1 parent 6c90e71 commit da7d46f
Show file tree
Hide file tree
Showing 14 changed files with 434 additions and 125 deletions.
1 change: 1 addition & 0 deletions changelog.d/7302.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Persist user interactive authentication sessions across workers and Synapse restarts.
2 changes: 2 additions & 0 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
Expand Down Expand Up @@ -439,6 +440,7 @@ class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
UIAuthWorkerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedReceiptsStore,
Expand Down
175 changes: 61 additions & 114 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
from synapse.util.caches.expiringcache import ExpiringCache

from ._base import BaseHandler

Expand All @@ -69,15 +69,6 @@ def __init__(self, hs):

self.bcrypt_rounds = hs.config.bcrypt_rounds

# This is not a cache per se, but a store of all current sessions that
# expire after N hours
self.sessions = ExpiringCache(
cache_name="register_sessions",
clock=hs.get_clock(),
expiry_ms=self.SESSION_EXPIRE_MS,
reset_expiry_on_get=True,
)

account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
Expand Down Expand Up @@ -119,6 +110,15 @@ def __init__(self, hs):

self._clock = self.hs.get_clock()

# Expire old UI auth sessions after a period of time.
if hs.config.worker_app is None:
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
"expire_old_sessions",
self._expire_old_sessions,
)

# Load the SSO HTML templates.

# The following template is shown to the user during a client login via SSO,
Expand Down Expand Up @@ -301,16 +301,21 @@ async def check_auth(
if "session" in authdict:
sid = authdict["session"]

# Convert the URI and method to strings.
uri = request.uri.decode("utf-8")
method = request.uri.decode("utf-8")

# If there's no session ID, create a new session.
if not sid:
session = self._create_session(
clientdict, (request.uri, request.method, clientdict), description
session = await self.store.create_ui_auth_session(
clientdict, uri, method, description
)
session_id = session["id"]

else:
session = self._get_session_info(sid)
session_id = sid
try:
session = await self.store.get_ui_auth_session(sid)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (sid,))

if not clientdict:
# This was designed to allow the client to omit the parameters
Expand All @@ -322,36 +327,35 @@ async def check_auth(
# on a homeserver.
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbitrary.
clientdict = session["clientdict"]
clientdict = session.clientdict

# Ensure that the queried operation does not vary between stages of
# the UI authentication session. This is done by generating a stable
# comparator based on the URI, method, and body (minus the auth dict)
# and storing it during the initial query. Subsequent queries ensure
# that this comparator has not changed.
comparator = (request.uri, request.method, clientdict)
if session["ui_auth"] != comparator:
comparator = (uri, method, clientdict)
if (session.uri, session.method, session.clientdict) != comparator:
raise SynapseError(
403,
"Requested operation has changed during the UI authentication session.",
)

if not authdict:
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session_id)
self._auth_dict_for_flows(flows, session.session_id)
)

creds = session["creds"]

# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
if "type" in authdict:
login_type = authdict["type"] # type: str
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
await self.store.mark_ui_auth_stage_complete(
session.session_id, login_type, result
)
except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
Expand All @@ -367,6 +371,7 @@ async def check_auth(
# so that the client can have another go.
errordict = e.error_dict()

creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
Expand All @@ -380,9 +385,9 @@ async def check_auth(
list(clientdict),
)

return creds, clientdict, session_id
return creds, clientdict, session.session_id

ret = self._auth_dict_for_flows(flows, session_id)
ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
Expand All @@ -399,13 +404,11 @@ async def add_oob_auth(
if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)

sess = self._get_session_info(authdict["session"])
creds = sess["creds"]

result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
await self.store.mark_ui_auth_stage_complete(
authdict["session"], stagetype, result
)
return True
return False

Expand All @@ -427,7 +430,7 @@ def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
sid = authdict["session"]
return sid

def set_session_data(self, session_id: str, key: str, value: Any) -> None:
async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
Expand All @@ -438,11 +441,12 @@ def set_session_data(self, session_id: str, key: str, value: Any) -> None:
key: The key to store the data under
value: The data to store
"""
sess = self._get_session_info(session_id)
sess["serverdict"][key] = value
self._save_session(sess)
try:
await self.store.set_ui_auth_session_data(session_id, key, value)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))

def get_session_data(
async def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
"""
Expand All @@ -453,8 +457,18 @@ def get_session_data(
key: The key to store the data under
default: Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess["serverdict"].get(key, default)
try:
return await self.store.get_ui_auth_session_data(session_id, key, default)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))

async def _expire_old_sessions(self):
"""
Invalidate any user interactive authentication sessions that have expired.
"""
now = self._clock.time_msec()
expiration_time = now - self.SESSION_EXPIRE_MS
await self.store.delete_old_ui_auth_sessions(expiration_time)

async def _check_auth_dict(
self, authdict: Dict[str, Any], clientip: str
Expand Down Expand Up @@ -534,67 +548,6 @@ def _auth_dict_for_flows(
"params": params,
}

def _create_session(
self,
clientdict: Dict[str, Any],
ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
description: str,
) -> dict:
"""
Creates a new user interactive authentication session.
The session can be used to track data across multiple requests, e.g. for
interactive authentication.
Each session has the following keys:
id:
A unique identifier for this session. Passed back to the client
and returned for each stage.
clientdict:
The dictionary from the client root level, not the 'auth' key.
ui_auth:
A tuple which is checked at each stage of the authentication to
ensure that the asked for operation has not changed.
creds:
A map, which maps each auth-type (str) to the relevant identity
authenticated by that auth-type (mostly str, but for captcha, bool).
serverdict:
A map of data that is stored server-side and cannot be modified
by the client.
description:
A string description of the operation that the current
authentication is authorising.
Returns:
The newly created session.
"""
session_id = None
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)

self.sessions[session_id] = {
"id": session_id,
"clientdict": clientdict,
"ui_auth": ui_auth,
"creds": {},
"serverdict": {},
"description": description,
}

return self.sessions[session_id]

def _get_session_info(self, session_id: str) -> dict:
"""
Gets a session given a session ID.
The session can be used to track data across multiple requests, e.g. for
interactive authentication.
"""
try:
return self.sessions[session_id]
except KeyError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))

async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
Expand Down Expand Up @@ -994,13 +947,6 @@ async def delete_threepid(
await self.store.user_delete_threepid(user_id, medium, address)
return result

def _save_session(self, session: Dict[str, Any]) -> None:
"""Update the last used time on the session to now and add it back to the session store."""
# TODO: Persistent storage
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session

async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
Expand Down Expand Up @@ -1052,7 +998,7 @@ def _do_validate_hash():
else:
return False

def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
Expand All @@ -1063,12 +1009,15 @@ def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
Returns:
The HTML to render.
"""
session = self._get_session_info(session_id)
try:
session = await self.store.get_ui_auth_session(session_id)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
return self._sso_auth_confirm_template.render(
description=session["description"], redirect_url=redirect_url,
description=session.description, redirect_url=redirect_url,
)

def complete_sso_ui_auth(
async def complete_sso_ui_auth(
self, registered_user_id: str, session_id: str, request: SynapseRequest,
):
"""Having figured out a mxid for this user, complete the HTTP request
Expand All @@ -1080,13 +1029,11 @@ def complete_sso_ui_auth(
process.
"""
# Mark the stage of the authentication as successful.
sess = self._get_session_info(session_id)
creds = sess["creds"]

# Save the user who authenticated with SSO, this will be used to ensure
# that the account be modified is also the person who logged in.
creds[LoginType.SSO] = registered_user_id
self._save_session(sess)
await self.store.mark_ui_auth_stage_complete(
session_id, LoginType.SSO, registered_user_id
)

# Render the HTML and return.
html_bytes = self._sso_auth_success_template.encode("utf-8")
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/cas_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ async def handle_ticket(
registered_user_id = await self._auth_handler.check_user_exists(user_id)

if session:
self._auth_handler.complete_sso_ui_auth(
await self._auth_handler.complete_sso_ui_auth(
registered_user_id, session, request,
)

Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ async def handle_saml_response(self, request):

# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
self._auth_handler.complete_sso_ui_auth(
await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)

Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/v2_alpha/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(self, hs):
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url

def on_GET(self, request, stagetype):
async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
Expand Down Expand Up @@ -180,7 +180,7 @@ def on_GET(self, request, stagetype):
else:
raise SynapseError(400, "Homeserver not configured for SSO.")

html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)

else:
raise SynapseError(404, "Unknown auth stage type")
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ async def on_POST(self, request):
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = self.auth_handler.get_session_data(
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)

Expand Down Expand Up @@ -598,7 +598,7 @@ async def on_POST(self, request):

# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
self.auth_handler.set_session_data(
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)

Expand Down
Loading

0 comments on commit da7d46f

Please sign in to comment.