Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 78b99de

Browse files
authored
Prefer make_awaitable over defer.succeed in tests (#12505)
When configuring the return values of mocks, prefer awaitables from `make_awaitable` over `defer.succeed`. `Deferred`s are only awaitable once, so it is inappropriate for a mock to return the same `Deferred` multiple times. Also update `run_in_background` to support functions that return arbitrary awaitables. Signed-off-by: Sean Quah <seanq@element.io>
1 parent 5ef673d commit 78b99de

14 files changed

+72
-69
lines changed

changelog.d/12505.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests.

synapse/logging/context.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,11 @@ def nested_logging_context(suffix: str) -> LoggingContext:
722722
R = TypeVar("R")
723723

724724

725+
async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R:
726+
"""Unwraps an arbitrary awaitable by awaiting it."""
727+
return await awaitable
728+
729+
725730
@overload
726731
def preserve_fn( # type: ignore[misc]
727732
f: Callable[P, Awaitable[R]],
@@ -802,17 +807,20 @@ def run_in_background( # type: ignore[misc]
802807
# by synchronous exceptions, so let's turn them into Failures.
803808
return defer.fail()
804809

810+
# `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain
811+
# value. Convert it to a `Deferred`.
805812
if isinstance(res, typing.Coroutine):
813+
# Wrap the coroutine in a `Deferred`.
806814
res = defer.ensureDeferred(res)
807-
808-
# At this point we should have a Deferred, if not then f was a synchronous
809-
# function, wrap it in a Deferred for consistency.
810-
if not isinstance(res, defer.Deferred):
811-
# `res` is not a `Deferred` and not a `Coroutine`.
812-
# There are no other types of `Awaitable`s we expect to encounter in Synapse.
813-
assert not isinstance(res, Awaitable)
814-
815-
return defer.succeed(res)
815+
elif isinstance(res, defer.Deferred):
816+
pass
817+
elif isinstance(res, Awaitable):
818+
# `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
819+
# or `Future` from `make_awaitable`.
820+
res = defer.ensureDeferred(_unwrap_awaitable(res))
821+
else:
822+
# `res` is a plain value. Wrap it in a `Deferred`.
823+
res = defer.succeed(res)
816824

817825
if res.called and not res.paused:
818826
# The function should have maintained the logcontext, so we can

tests/federation/test_federation_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_get_room_state(self):
8383
)
8484

8585
# mock up the response, and have the agent return it
86-
self._mock_agent.request.return_value = defer.succeed(
86+
self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
8787
_mock_response(
8888
{
8989
"pdus": [

tests/federation/test_federation_sender.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def test_dont_send_device_updates_for_remote_users(self):
226226
# Send the server a device list EDU for the other user, this will cause
227227
# it to try and resync the device lists.
228228
self.hs.get_federation_transport_client().query_user_devices.return_value = (
229-
defer.succeed(
229+
make_awaitable(
230230
{
231231
"stream_id": "1",
232232
"user_id": "@user2:host2",

tests/handlers/test_e2e_keys.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from parameterized import parameterized
2020
from signedjson import key as key, sign as sign
2121

22-
from twisted.internet import defer
2322
from twisted.test.proto_helpers import MemoryReactor
2423

2524
from synapse.api.constants import RoomEncryptionAlgorithms
@@ -704,7 +703,7 @@ def test_query_devices_remote_no_sync(self) -> None:
704703
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
705704

706705
self.hs.get_federation_client().query_client_keys = mock.Mock(
707-
return_value=defer.succeed(
706+
return_value=make_awaitable(
708707
{
709708
"device_keys": {remote_user_id: {}},
710709
"master_keys": {
@@ -777,14 +776,14 @@ def test_query_devices_remote_sync(self) -> None:
777776
# Pretend we're sharing a room with the user we're querying. If not,
778777
# `_query_devices_for_destination` will return early.
779778
self.store.get_rooms_for_user = mock.Mock(
780-
return_value=defer.succeed({"some_room_id"})
779+
return_value=make_awaitable({"some_room_id"})
781780
)
782781

783782
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
784783
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
785784

786785
self.hs.get_federation_client().query_user_devices = mock.Mock(
787-
return_value=defer.succeed(
786+
return_value=make_awaitable(
788787
{
789788
"user_id": remote_user_id,
790789
"stream_id": 1,

tests/handlers/test_password_providers.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
from typing import Any, Type, Union
1818
from unittest.mock import Mock
1919

20-
from twisted.internet import defer
21-
2220
import synapse
2321
from synapse.api.constants import LoginType
2422
from synapse.api.errors import Codes
@@ -190,7 +188,7 @@ def password_only_auth_provider_login_test_body(self):
190188
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
191189

192190
# check_password must return an awaitable
193-
mock_password_provider.check_password.return_value = defer.succeed(True)
191+
mock_password_provider.check_password.return_value = make_awaitable(True)
194192
channel = self._send_password_login("u", "p")
195193
self.assertEqual(channel.code, 200, channel.result)
196194
self.assertEqual("@u:test", channel.json_body["user_id"])
@@ -226,13 +224,13 @@ def password_only_auth_provider_ui_auth_test_body(self):
226224
self.get_success(module_api.register_user("u"))
227225

228226
# log in twice, to get two devices
229-
mock_password_provider.check_password.return_value = defer.succeed(True)
227+
mock_password_provider.check_password.return_value = make_awaitable(True)
230228
tok1 = self.login("u", "p")
231229
self.login("u", "p", device_id="dev2")
232230
mock_password_provider.reset_mock()
233231

234232
# have the auth provider deny the request to start with
235-
mock_password_provider.check_password.return_value = defer.succeed(False)
233+
mock_password_provider.check_password.return_value = make_awaitable(False)
236234

237235
# make the initial request which returns a 401
238236
session = self._start_delete_device_session(tok1, "dev2")
@@ -246,7 +244,7 @@ def password_only_auth_provider_ui_auth_test_body(self):
246244
mock_password_provider.reset_mock()
247245

248246
# Finally, check the request goes through when we allow it
249-
mock_password_provider.check_password.return_value = defer.succeed(True)
247+
mock_password_provider.check_password.return_value = make_awaitable(True)
250248
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
251249
self.assertEqual(channel.code, 200)
252250
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@@ -260,7 +258,7 @@ def local_user_fallback_login_test_body(self):
260258
self.register_user("localuser", "localpass")
261259

262260
# check_password must return an awaitable
263-
mock_password_provider.check_password.return_value = defer.succeed(False)
261+
mock_password_provider.check_password.return_value = make_awaitable(False)
264262
channel = self._send_password_login("u", "p")
265263
self.assertEqual(channel.code, 403, channel.result)
266264

@@ -277,7 +275,7 @@ def local_user_fallback_ui_auth_test_body(self):
277275
self.register_user("localuser", "localpass")
278276

279277
# have the auth provider deny the request
280-
mock_password_provider.check_password.return_value = defer.succeed(False)
278+
mock_password_provider.check_password.return_value = make_awaitable(False)
281279

282280
# log in twice, to get two devices
283281
tok1 = self.login("localuser", "localpass")
@@ -320,7 +318,7 @@ def no_local_user_fallback_login_test_body(self):
320318
self.register_user("localuser", "localpass")
321319

322320
# check_password must return an awaitable
323-
mock_password_provider.check_password.return_value = defer.succeed(False)
321+
mock_password_provider.check_password.return_value = make_awaitable(False)
324322
channel = self._send_password_login("localuser", "localpass")
325323
self.assertEqual(channel.code, 403)
326324
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -342,7 +340,7 @@ def no_local_user_fallback_ui_auth_test_body(self):
342340
self.register_user("localuser", "localpass")
343341

344342
# allow login via the auth provider
345-
mock_password_provider.check_password.return_value = defer.succeed(True)
343+
mock_password_provider.check_password.return_value = make_awaitable(True)
346344

347345
# log in twice, to get two devices
348346
tok1 = self.login("localuser", "p")
@@ -359,7 +357,7 @@ def no_local_user_fallback_ui_auth_test_body(self):
359357
mock_password_provider.check_password.assert_not_called()
360358

361359
# now try deleting with the local password
362-
mock_password_provider.check_password.return_value = defer.succeed(False)
360+
mock_password_provider.check_password.return_value = make_awaitable(False)
363361
channel = self._authed_delete_device(
364362
tok1, "dev2", session, "localuser", "localpass"
365363
)
@@ -413,7 +411,7 @@ def custom_auth_provider_login_test_body(self):
413411
self.assertEqual(channel.code, 400, channel.result)
414412
mock_password_provider.check_auth.assert_not_called()
415413

416-
mock_password_provider.check_auth.return_value = defer.succeed(
414+
mock_password_provider.check_auth.return_value = make_awaitable(
417415
("@user:bz", None)
418416
)
419417
channel = self._send_login("test.login_type", "u", test_field="y")
@@ -427,7 +425,7 @@ def custom_auth_provider_login_test_body(self):
427425
# try a weird username. Again, it's unclear what we *expect* to happen
428426
# in these cases, but at least we can guard against the API changing
429427
# unexpectedly
430-
mock_password_provider.check_auth.return_value = defer.succeed(
428+
mock_password_provider.check_auth.return_value = make_awaitable(
431429
("@ MALFORMED! :bz", None)
432430
)
433431
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
@@ -477,7 +475,7 @@ def custom_auth_provider_ui_auth_test_body(self):
477475
mock_password_provider.reset_mock()
478476

479477
# right params, but authing as the wrong user
480-
mock_password_provider.check_auth.return_value = defer.succeed(
478+
mock_password_provider.check_auth.return_value = make_awaitable(
481479
("@user:bz", None)
482480
)
483481
body["auth"]["test_field"] = "foo"
@@ -490,7 +488,7 @@ def custom_auth_provider_ui_auth_test_body(self):
490488
mock_password_provider.reset_mock()
491489

492490
# and finally, succeed
493-
mock_password_provider.check_auth.return_value = defer.succeed(
491+
mock_password_provider.check_auth.return_value = make_awaitable(
494492
("@localuser:test", None)
495493
)
496494
channel = self._delete_device(tok1, "dev2", body)
@@ -508,9 +506,9 @@ def test_custom_auth_provider_callback(self):
508506
self.custom_auth_provider_callback_test_body()
509507

510508
def custom_auth_provider_callback_test_body(self):
511-
callback = Mock(return_value=defer.succeed(None))
509+
callback = Mock(return_value=make_awaitable(None))
512510

513-
mock_password_provider.check_auth.return_value = defer.succeed(
511+
mock_password_provider.check_auth.return_value = make_awaitable(
514512
("@user:bz", callback)
515513
)
516514
channel = self._send_login("test.login_type", "u", test_field="y")
@@ -646,7 +644,7 @@ def password_custom_auth_password_disabled_ui_auth_test_body(self):
646644
login is disabled"""
647645
# register the user and log in twice via the test login type to get two devices,
648646
self.register_user("localuser", "localpass")
649-
mock_password_provider.check_auth.return_value = defer.succeed(
647+
mock_password_provider.check_auth.return_value = make_awaitable(
650648
("@localuser:test", None)
651649
)
652650
channel = self._send_login("test.login_type", "localuser", test_field="")

tests/handlers/test_typing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
6565
# we mock out the keyring so as to skip the authentication check on the
6666
# federation API call.
6767
mock_keyring = Mock(spec=["verify_json_for_server"])
68-
mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
68+
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
6969

7070
# we mock out the federation client too
7171
mock_federation_client = Mock(spec=["put_json"])
72-
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
72+
mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
7373

7474
# the tests assume that we are starting at unix time 1000
7575
reactor.pump((1000,))
@@ -98,7 +98,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
9898

9999
self.datastore = hs.get_datastores().main
100100
self.datastore.get_destination_retry_timings = Mock(
101-
return_value=defer.succeed(None)
101+
return_value=make_awaitable(None)
102102
)
103103

104104
self.datastore.get_device_updates_by_remote = Mock(

tests/handlers/test_user_directory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from unittest.mock import Mock, patch
1616
from urllib.parse import quote
1717

18-
from twisted.internet import defer
1918
from twisted.test.proto_helpers import MemoryReactor
2019

2120
import synapse.rest.admin
@@ -30,6 +29,7 @@
3029

3130
from tests import unittest
3231
from tests.storage.test_user_directory import GetUserDirectoryTables
32+
from tests.test_utils import make_awaitable
3333
from tests.test_utils.event_injection import inject_member_event
3434
from tests.unittest import override_config
3535

@@ -439,7 +439,7 @@ def test_handle_user_deactivated_support_user(self) -> None:
439439
)
440440
)
441441

442-
mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
442+
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
443443
with patch.object(
444444
self.store, "remove_from_user_dir", mock_remove_from_user_dir
445445
):
@@ -454,7 +454,7 @@ def test_handle_user_deactivated_regular_user(self) -> None:
454454
self.store.register_user(user_id=r_user_id, password_hash=None)
455455
)
456456

457-
mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
457+
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
458458
with patch.object(
459459
self.store, "remove_from_user_dir", mock_remove_from_user_dir
460460
):

tests/rest/client/test_presence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from http import HTTPStatus
1515
from unittest.mock import Mock
1616

17-
from twisted.internet import defer
1817
from twisted.test.proto_helpers import MemoryReactor
1918

2019
from synapse.handlers.presence import PresenceHandler
@@ -24,6 +23,7 @@
2423
from synapse.util import Clock
2524

2625
from tests import unittest
26+
from tests.test_utils import make_awaitable
2727

2828

2929
class PresenceTestCase(unittest.HomeserverTestCase):
@@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
3737
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
3838

3939
presence_handler = Mock(spec=PresenceHandler)
40-
presence_handler.set_state.return_value = defer.succeed(None)
40+
presence_handler.set_state.return_value = make_awaitable(None)
4141

4242
hs = self.setup_test_homeserver(
4343
"red",

tests/rest/client/test_rooms.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from unittest.mock import Mock, call
2323
from urllib import parse as urlparse
2424

25-
from twisted.internet import defer
2625
from twisted.test.proto_helpers import MemoryReactor
2726

2827
import synapse.rest.admin
@@ -1426,9 +1425,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
14261425

14271426
def test_simple(self) -> None:
14281427
"Simple test for searching rooms over federation"
1429-
self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
1430-
{}
1431-
)
1428+
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
14321429

14331430
search_filter = {"generic_search_term": "foobar"}
14341431

@@ -1456,7 +1453,7 @@ def test_fallback(self) -> None:
14561453
# with a 404, when using search filters.
14571454
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
14581455
HttpResponseException(404, "Not Found", b""),
1459-
defer.succeed({}),
1456+
make_awaitable({}),
14601457
)
14611458

14621459
search_filter = {"generic_search_term": "foobar"}

0 commit comments

Comments
 (0)