From c55cdf75012393f6e6ae717aed1f2ed313286cc8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 21 Sep 2018 12:30:43 +0100 Subject: [PATCH 1/5] docstrings and unittests for storage.state I spent ages trying to figure out how I was going mad... --- synapse/storage/state.py | 30 +++++++++++++++++++++-------- tests/storage/test_state.py | 38 +++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 4b971efdbaf8..036b7550f30e 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -255,7 +255,17 @@ def _get_state_group_delta_txn(txn): ) @defer.inlineCallbacks - def get_state_groups_ids(self, room_id, event_ids): + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[(type, state_key), str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ if not event_ids: defer.returnValue({}) @@ -270,7 +280,7 @@ def get_state_groups_ids(self, room_id, event_ids): @defer.inlineCallbacks def get_state_ids_for_group(self, state_group): - """Get the state IDs for the given state group + """Get the event IDs of all the state in the given state group Args: state_group (int) @@ -286,7 +296,9 @@ def get_state_ids_for_group(self, state_group): def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids - The return value is a dict mapping group names to lists of events. + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. """ if not event_ids: defer.returnValue({}) @@ -324,7 +336,9 @@ def _get_state_groups_from_groups(self, groups, types, members=None): member events (if True), or to exclude member events (if False) Returns: - dictionary state_group -> (dict of (type, state_key) -> event id) + Returns: + Deferred[dict[int, dict[(type, state_key), str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ results = {} @@ -732,8 +746,8 @@ def _get_state_for_groups(self, groups, types=None, filtered_types=None): If None, `types` filtering is applied to all events. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + Deferred[dict[int, dict[(type, state_key), str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if types is not None: non_member_types = [t for t in types if t[0] != EventTypes.Member] @@ -788,8 +802,8 @@ def _get_state_for_groups_using_cache( If None, `types` filtering is applied to all events. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + Deferred[dict[int, dict[(type, state_key), str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if types: types = frozenset(types) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index b91096593216..1c7158c90633 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -74,6 +74,44 @@ def assertStateMapEqual(self, s1, s2): self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(len(s1), len(s2)) + @defer.inlineCallbacks + def test_get_state_groups_ids(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_map = list(state_group_map.values())[0] + self.assertDictEqual( + state_map, + { + (EventTypes.Create, ''): e1.event_id, + (EventTypes.Name, ''): e2.event_id, + }, + ) + + @defer.inlineCallbacks + def test_get_state_groups(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups( + self.room, [e1.event_id, e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_list = list(state_group_map.values())[0] + self.assertListEqual( + [ev.event_id for ev in state_list], + [e1.event_id, e2.event_id], + ) + @defer.inlineCallbacks def test_get_state_for_event(self): From a5ddf25be6fb220ea0f6b8ac984f7a313eef266f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 26 Sep 2018 09:35:32 +0100 Subject: [PATCH 2/5] More docstring fixes for StateWorkerStore --- synapse/storage/state.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 036b7550f30e..3f4cbd61c453 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -263,7 +263,7 @@ def get_state_groups_ids(self, _room_id, event_ids): event_ids (iterable[str]): ids of the events Returns: - Deferred[dict[int, dict[(type, state_key), str]]]: + Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: @@ -337,7 +337,7 @@ def _get_state_groups_from_groups(self, groups, types, members=None): Returns: Returns: - Deferred[dict[int, dict[(type, state_key), str]]]: + Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ results = {} @@ -746,7 +746,7 @@ def _get_state_for_groups(self, groups, types=None, filtered_types=None): If None, `types` filtering is applied to all events. Returns: - Deferred[dict[int, dict[(type, state_key), str]]]: + Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ if types is not None: @@ -802,7 +802,7 @@ def _get_state_for_groups_using_cache( If None, `types` filtering is applied to all events. Returns: - Deferred[dict[int, dict[(type, state_key), str]]]: + Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ if types: From 7c7c842275f76b88a95486980daccf5289ea8e1b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 26 Sep 2018 14:27:34 +0100 Subject: [PATCH 3/5] changelog --- changelog.d/3958.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/3958.misc diff --git a/changelog.d/3958.misc b/changelog.d/3958.misc new file mode 100644 index 000000000000..5931d06dcff0 --- /dev/null +++ b/changelog.d/3958.misc @@ -0,0 +1 @@ +Fix docstrings and add tests for state store methods From a0771af0a25e03ea548c000e6ae3c47ccf625222 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 26 Sep 2018 17:25:59 +0100 Subject: [PATCH 4/5] fix tests --- tests/storage/test_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 1c7158c90633..9ec0a8d71b35 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -104,7 +104,7 @@ def test_get_state_groups(self): ) state_group_map = yield self.store.get_state_groups( - self.room, [e1.event_id, e2.event_id]) + self.room, [e2.event_id]) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] self.assertListEqual( From d35bf6fce0aa1c7119b3baf31acfa6fca769e557 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 26 Sep 2018 19:22:58 +0100 Subject: [PATCH 5/5] fix tests, again --- tests/storage/test_state.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 9ec0a8d71b35..b9c5b39d5982 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -107,9 +107,10 @@ def test_get_state_groups(self): self.room, [e2.event_id]) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] - self.assertListEqual( - [ev.event_id for ev in state_list], - [e1.event_id, e2.event_id], + + self.assertEqual( + {ev.event_id for ev in state_list}, + {e1.event_id, e2.event_id}, ) @defer.inlineCallbacks