Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Replace simple_async_mock with AsyncMock (#16180)
Browse files Browse the repository at this point in the history
Python 3.8 has a native AsyncMock, use it instead of a custom
implementation.
  • Loading branch information
clokep authored and hughns committed Sep 4, 2023
1 parent 0904b35 commit 9c3f1c7
Show file tree
Hide file tree
Showing 15 changed files with 140 additions and 160 deletions.
1 change: 1 addition & 0 deletions changelog.d/16180.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `AsyncMock` instead of custom code.
97 changes: 49 additions & 48 deletions tests/api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock

import pymacaroons

Expand All @@ -35,7 +35,6 @@
from synapse.util import Clock

from tests import unittest
from tests.test_utils import simple_async_mock
from tests.unittest import override_config
from tests.utils import mock_getRawHeaders

Expand All @@ -60,16 +59,16 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)

self.store.insert_client_ip = simple_async_mock(None)
self.store.is_support_user = simple_async_mock(False)
self.store.insert_client_ip = AsyncMock(return_value=None)
self.store.is_support_user = AsyncMock(return_value=False)

def test_get_user_by_req_user_valid_token(self) -> None:
user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
self.store.get_user_by_access_token = simple_async_mock(user_info)
self.store.mark_access_token_as_used = simple_async_mock(None)
self.store.get_user_locked_status = simple_async_mock(False)
self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
self.store.get_user_locked_status = AsyncMock(return_value=False)

request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
Expand All @@ -78,7 +77,7 @@ def test_get_user_by_req_user_valid_token(self) -> None:
self.assertEqual(requester.user.to_string(), self.test_user)

def test_get_user_by_req_user_bad_token(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)

request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
Expand All @@ -91,7 +90,7 @@ def test_get_user_by_req_user_bad_token(self) -> None:

def test_get_user_by_req_user_missing_token(self) -> None:
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = simple_async_mock(user_info)
self.store.get_user_by_access_token = AsyncMock(return_value=user_info)

request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
Expand All @@ -106,7 +105,7 @@ def test_get_user_by_req_appservice_valid_token(self) -> None:
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)

request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
Expand All @@ -125,7 +124,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None:
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)

request = Mock(args={})
request.getClientAddress.return_value.host = "192.168.10.10"
Expand All @@ -144,7 +143,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None:
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)

request = Mock(args={})
request.getClientAddress.return_value.host = "131.111.8.42"
Expand All @@ -158,7 +157,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None:

def test_get_user_by_req_appservice_bad_token(self) -> None:
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)

request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
Expand All @@ -172,7 +171,7 @@ def test_get_user_by_req_appservice_bad_token(self) -> None:
def test_get_user_by_req_appservice_missing_token(self) -> None:
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)

request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
Expand All @@ -190,8 +189,8 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None:
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
self.store.get_user_by_access_token = AsyncMock(return_value=None)

request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
Expand All @@ -210,7 +209,7 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None:
)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)

request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
Expand All @@ -234,10 +233,10 @@ def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None:
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
self.store.get_user_by_access_token = AsyncMock(return_value=None)
# This also needs to just return a truth-y value
self.store.get_device = simple_async_mock({"hidden": False})
self.store.get_device = AsyncMock(return_value={"hidden": False})

request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
Expand Down Expand Up @@ -266,10 +265,10 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None:
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
self.store.get_user_by_access_token = AsyncMock(return_value=None)
# This also needs to just return a falsey value
self.store.get_device = simple_async_mock(None)
self.store.get_device = AsyncMock(return_value=None)

request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
Expand All @@ -283,18 +282,18 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None:
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)

def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(
self.store.get_user_by_access_token = AsyncMock(
return_value=TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
token_id=5,
token_owner="@admin:matrix.org",
token_used=True,
)
)
self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
self.store.get_user_locked_status = simple_async_mock(False)
self.store.insert_client_ip = AsyncMock(return_value=None)
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
self.store.get_user_locked_status = AsyncMock(return_value=False)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
Expand All @@ -304,18 +303,18 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> Non

def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
self.auth._track_puppeted_user_ips = True
self.store.get_user_by_access_token = simple_async_mock(
TokenLookupResult(
self.store.get_user_by_access_token = AsyncMock(
return_value=TokenLookupResult(
user_id="@baldrick:matrix.org",
device_id="device",
token_id=5,
token_owner="@admin:matrix.org",
token_used=True,
)
)
self.store.get_user_locked_status = simple_async_mock(False)
self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
self.store.get_user_locked_status = AsyncMock(return_value=False)
self.store.insert_client_ip = AsyncMock(return_value=None)
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
Expand All @@ -324,7 +323,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
self.assertEqual(self.store.insert_client_ip.call_count, 2)

def test_get_user_from_macaroon(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_access_token = AsyncMock(return_value=None)

user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
Expand All @@ -342,8 +341,8 @@ def test_get_user_from_macaroon(self) -> None:
)

def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = simple_async_mock({"is_guest": True})
self.store.get_user_by_access_token = simple_async_mock(None)
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
self.store.get_user_by_access_token = AsyncMock(return_value=None)

user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
Expand Down Expand Up @@ -373,7 +372,7 @@ def test_blocking_mau(self) -> None:

self.auth_blocking._limit_usage_by_mau = True

self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)

e = self.get_failure(
self.auth_blocking.check_auth_blocking(), ResourceLimitError
Expand All @@ -383,25 +382,27 @@ def test_blocking_mau(self) -> None:
self.assertEqual(e.value.code, 403)

# Ensure does not throw an error
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
self.store.get_monthly_active_count = AsyncMock(
return_value=small_number_of_users
)
self.get_success(self.auth_blocking.check_auth_blocking())

def test_blocking_mau__depending_on_user_type(self) -> None:
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True

self.store.get_monthly_active_count = simple_async_mock(100)
self.store.get_monthly_active_count = AsyncMock(return_value=100)
# Support users allowed
self.get_success(
self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
)
self.store.get_monthly_active_count = simple_async_mock(100)
self.store.get_monthly_active_count = AsyncMock(return_value=100)
# Bots not allowed
self.get_failure(
self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
ResourceLimitError,
)
self.store.get_monthly_active_count = simple_async_mock(100)
self.store.get_monthly_active_count = AsyncMock(return_value=100)
# Real users not allowed
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)

Expand All @@ -412,9 +413,9 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = False

self.store.get_monthly_active_count = simple_async_mock(100)
self.store.user_last_seen_monthly_active = simple_async_mock()
self.store.is_trial_user = simple_async_mock()
self.store.get_monthly_active_count = AsyncMock(return_value=100)
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
self.store.is_trial_user = AsyncMock(return_value=False)

appservice = ApplicationService(
"abcd",
Expand Down Expand Up @@ -443,9 +444,9 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = True

self.store.get_monthly_active_count = simple_async_mock(100)
self.store.user_last_seen_monthly_active = simple_async_mock()
self.store.is_trial_user = simple_async_mock()
self.store.get_monthly_active_count = AsyncMock(return_value=100)
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
self.store.is_trial_user = AsyncMock(return_value=False)

appservice = ApplicationService(
"abcd",
Expand Down Expand Up @@ -473,7 +474,7 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(
def test_reserved_threepid(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = simple_async_mock(2)
self.store.get_monthly_active_count = AsyncMock(return_value=2)
threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
Expand Down
31 changes: 16 additions & 15 deletions tests/appservice/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
# limitations under the License.
import re
from typing import Any, Generator
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock

from twisted.internet import defer

from synapse.appservice import ApplicationService, Namespace

from tests import unittest
from tests.test_utils import simple_async_mock


def _regex(regex: str, exclusive: bool = True) -> Namespace:
Expand All @@ -43,8 +42,8 @@ def setUp(self) -> None:
)

self.store = Mock()
self.store.get_aliases_for_room = simple_async_mock([])
self.store.get_local_users_in_room = simple_async_mock([])
self.store.get_aliases_for_room = AsyncMock(return_value=[])
self.store.get_local_users_in_room = AsyncMock(return_value=[])

@defer.inlineCallbacks
def test_regex_user_id_prefix_match(
Expand Down Expand Up @@ -127,10 +126,10 @@ def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, Non
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
self.store.get_aliases_for_room = simple_async_mock(
["#irc_foobar:matrix.org", "#athing:matrix.org"]
self.store.get_aliases_for_room = AsyncMock(
return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"]
)
self.store.get_local_users_in_room = simple_async_mock([])
self.store.get_local_users_in_room = AsyncMock(return_value=[])
self.assertTrue(
(
yield self.service.is_interested_in_event(
Expand Down Expand Up @@ -182,10 +181,10 @@ def test_regex_alias_no_match(
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
self.store.get_aliases_for_room = simple_async_mock(
["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
self.store.get_aliases_for_room = AsyncMock(
return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
)
self.store.get_local_users_in_room = simple_async_mock([])
self.store.get_local_users_in_room = AsyncMock(return_value=[])
self.assertFalse(
(
yield defer.ensureDeferred(
Expand All @@ -205,8 +204,10 @@ def test_regex_multiple_matches(
)
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"])
self.store.get_local_users_in_room = simple_async_mock([])
self.store.get_aliases_for_room = AsyncMock(
return_value=["#irc_barfoo:matrix.org"]
)
self.store.get_local_users_in_room = AsyncMock(return_value=[])
self.assertTrue(
(
yield self.service.is_interested_in_event(
Expand Down Expand Up @@ -235,10 +236,10 @@ def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, No
def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
# Note that @irc_fo:here is the AS user.
self.store.get_local_users_in_room = simple_async_mock(
["@alice:here", "@irc_fo:here", "@bob:here"]
self.store.get_local_users_in_room = AsyncMock(
return_value=["@alice:here", "@irc_fo:here", "@bob:here"]
)
self.store.get_aliases_for_room = simple_async_mock([])
self.store.get_aliases_for_room = AsyncMock(return_value=[])

self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(
Expand Down
Loading

0 comments on commit 9c3f1c7

Please sign in to comment.