Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve robustness when handling a perspective key response by deduplicating received server keys. #15423

Merged
merged 2 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ async def get_server_verify_key_v2_indirect(
)

keys: Dict[str, Dict[str, FetchKeyResult]] = {}
added_keys: List[Tuple[str, str, FetchKeyResult]] = []
added_keys: Dict[Tuple[str, str], FetchKeyResult] = {}

time_now_ms = self.clock.time_msec()

Expand Down Expand Up @@ -752,9 +752,27 @@ async def get_server_verify_key_v2_indirect(
# we continue to process the rest of the response
continue

added_keys.extend(
(server_name, key_id, key) for key_id, key in processed_response.items()
)
for key_id, key in processed_response.items():
dict_key = (server_name, key_id)
if dict_key in added_keys:
already_present_key = added_keys[dict_key]
logger.warning(
"Duplicate server keys for %s (%s) from perspective %s (%r, %r)",
server_name,
key_id,
perspective_name,
already_present_key,
key,
)

if already_present_key.valid_until_ts > key.valid_until_ts:
# Favour the entry with the largest valid_until_ts,
# as `old_verify_keys` are also collected from this
# response.
continue

added_keys[dict_key] = key

Comment on lines -755 to +775
Copy link
Contributor

@DMRobertson DMRobertson Apr 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To check I understand, the failure mode here is:

  • trusted key server gives us back multiple entries for the same (server name, key id) pair
    • possibly with different values?
  • we try to insert all such records into the db
  • db rejects this
  • guessing: the txn gets cancelled and we don't persist any keys whatsoever?

Have we seen this in practice or is this cautious, defensive programming?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct — there have been at least 2 such cases (see #12736), the most recent one is preventing someone's join to #synapse.

keys.setdefault(server_name, {}).update(processed_response)

await self.store.store_server_verify_keys(
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import itertools
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

from signedjson.key import decode_verify_key_bytes

Expand Down Expand Up @@ -95,7 +95,7 @@ async def store_server_verify_keys(
self,
from_server: str,
ts_added_ms: int,
verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
verify_keys: Mapping[Tuple[str, str], FetchKeyResult],
) -> None:
"""Stores NACL verification keys for remote servers.
Args:
Expand All @@ -108,7 +108,7 @@ async def store_server_verify_keys(
key_values = []
value_values = []
invalidations = []
for server_name, key_id, fetch_result in verify_keys:
for (server_name, key_id), fetch_result in verify_keys.items():
key_values.append((server_name, key_id))
value_values.append(
(
Expand Down
4 changes: 2 additions & 2 deletions tests/crypto/test_keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_verify_json_for_server(self) -> None:
r = self.hs.get_datastores().main.store_server_verify_keys(
"server9",
int(time.time() * 1000),
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), 1000)},
)
self.get_success(r)

Expand Down Expand Up @@ -291,7 +291,7 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
# None is not a valid value in FetchKeyResult, but we're abusing this
# API to insert null values into the database. The nulls get converted
# to 0 when fetched in KeyStore.get_server_verify_keys.
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], # type: ignore[arg-type]
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type]
)
self.get_success(r)

Expand Down
18 changes: 9 additions & 9 deletions tests/storage/test_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def test_get_server_verify_keys(self) -> None:
store.store_server_verify_keys(
"from_server",
10,
[
("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
],
{
("server1", key_id_1): FetchKeyResult(KEY_1, 100),
("server1", key_id_2): FetchKeyResult(KEY_2, 200),
},
)
)

Expand Down Expand Up @@ -90,10 +90,10 @@ def test_cache(self) -> None:
store.store_server_verify_keys(
"from_server",
0,
[
("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
],
{
("srv1", key_id_1): FetchKeyResult(KEY_1, 100),
("srv1", key_id_2): FetchKeyResult(KEY_2, 200),
},
)
)

Expand All @@ -119,7 +119,7 @@ def test_cache(self) -> None:
signedjson.key.generate_signing_key("key2")
)
d = store.store_server_verify_keys(
"from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
"from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)}
)
self.get_success(d)

Expand Down
16 changes: 6 additions & 10 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,16 +793,12 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
hs.get_datastores().main.store_server_verify_keys(
from_server=self.OTHER_SERVER_NAME,
ts_added_ms=clock.time_msec(),
verify_keys=[
(
self.OTHER_SERVER_NAME,
verify_key_id,
FetchKeyResult(
verify_key=verify_key,
valid_until_ts=clock.time_msec() + 10000,
),
)
],
verify_keys={
(self.OTHER_SERVER_NAME, verify_key_id): FetchKeyResult(
verify_key=verify_key,
valid_until_ts=clock.time_msec() + 10000,
),
},
)
)

Expand Down