Skip to content

Commit

Permalink
rate_limit: Stop wrapping rate limited functions.
Browse files Browse the repository at this point in the history
This refactors `rate_limit` so that we no longer use it as a decorator.
This is a workaround to python/mypy#12909 as
`rate_limit` previous expects different parameters than its callers.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
  • Loading branch information
PIG208 committed Jul 28, 2022
1 parent b828d66 commit 72ea021
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 45 deletions.
63 changes: 22 additions & 41 deletions zerver/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,8 @@ def _wrapped_view_func(
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
) -> HttpResponse:
process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
return rate_limit(view_func)(request, *args, **kwargs)
rate_limit(request)
return view_func(request, *args, **kwargs)

return _wrapped_view_func

Expand Down Expand Up @@ -716,10 +717,8 @@ def _wrapped_func_arguments(
) -> HttpResponse:
user_profile = validate_api_key(request, None, api_key, False)
if not skip_rate_limiting:
limited_func = rate_limit(view_func)
else:
limited_func = view_func
return limited_func(request, user_profile, *args, **kwargs)
rate_limit(request)
return view_func(request, user_profile, *args, **kwargs)

return _wrapped_func_arguments

Expand Down Expand Up @@ -780,10 +779,8 @@ def _wrapped_func_arguments(
try:
if not skip_rate_limiting:
# Apply rate limiting
target_view_func = rate_limit(view_func)
else:
target_view_func = view_func
return target_view_func(request, profile, *args, **kwargs)
rate_limit(request)
return view_func(request, profile, *args, **kwargs)
except Exception as err:
if not webhook_client_name:
raise err
Expand Down Expand Up @@ -857,9 +854,7 @@ def authenticate_log_and_execute_json(
**kwargs: object,
) -> HttpResponse:
if not skip_rate_limiting:
limited_view_func = rate_limit(view_func)
else:
limited_view_func = view_func
rate_limit(request)

if not request.user.is_authenticated:
if not allow_unauthenticated:
Expand All @@ -870,7 +865,7 @@ def authenticate_log_and_execute_json(
is_browser_view=True,
query=view_func.__name__,
)
return limited_view_func(request, request.user, *args, **kwargs)
return view_func(request, request.user, *args, **kwargs)

user_profile = request.user
validate_account_and_subdomain(request, user_profile)
Expand All @@ -879,7 +874,7 @@ def authenticate_log_and_execute_json(
raise JsonableError(_("Webhook bots can only access webhooks"))

process_client(request, user_profile, is_browser_view=True, query=view_func.__name__)
return limited_view_func(request, user_profile, *args, **kwargs)
return view_func(request, user_profile, *args, **kwargs)


# Checks if the user is logged in. If not, return an error (the
Expand Down Expand Up @@ -1062,36 +1057,22 @@ def rate_limit_remote_server(
raise e


def rate_limit(func: ViewFuncT) -> ViewFuncT:
"""Rate-limits a view."""

@wraps(func)
def wrapped_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse:

# It is really tempting to not even wrap our original function
# when settings.RATE_LIMITING is False, but it would make
# for awkward unit testing in some situations.
if not settings.RATE_LIMITING:
return func(request, *args, **kwargs)

if client_is_exempt_from_rate_limiting(request):
return func(request, *args, **kwargs)

user = request.user
remote_server = RequestNotes.get_notes(request).remote_server
def rate_limit(request: HttpRequest) -> None:
if not settings.RATE_LIMITING:
return

if settings.ZILENCER_ENABLED and remote_server is not None:
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
elif not user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip")
return func(request, *args, **kwargs)
else:
assert isinstance(user, UserProfile)
rate_limit_user(request, user, domain="api_by_user")
if client_is_exempt_from_rate_limiting(request):
return

return func(request, *args, **kwargs)
remote_server = RequestNotes.get_notes(request).remote_server

return cast(ViewFuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
if settings.ZILENCER_ENABLED and remote_server is not None:
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
elif not request.user.is_authenticated:
rate_limit_request_by_ip(request, domain="api_by_ip")
else:
assert isinstance(request.user, UserProfile)
rate_limit_user(request, request.user, domain="api_by_user")


def return_success_on_head_request(
Expand Down
7 changes: 3 additions & 4 deletions zerver/tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def my_view(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
request.method = "POST"
request.user = self.example_user("hamlet")
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
result = my_unlimited_view(request)
result = my_unlimited_view(request, request.user)

self.assert_json_success(result)
self.assertFalse(rate_limit_mock.called)
Expand All @@ -528,7 +528,7 @@ def my_view(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
request.method = "POST"
request.user = self.example_user("hamlet")
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
result = my_rate_limited_view(request)
result = my_rate_limited_view(request, request.user)

# Don't assert json_success, since it'll be the rate_limit mock object
self.assertTrue(rate_limit_mock.called)
Expand Down Expand Up @@ -630,10 +630,9 @@ def test_authenticated_rest_api_view_errors(self) -> None:
class RateLimitTestCase(ZulipTestCase):
def get_ratelimited_view(self) -> Callable[..., HttpResponse]:
def f(req: Any) -> HttpResponse:
rate_limit(req)
return json_response(msg="some value")

f = rate_limit(f)

return f

def errors_disallowed(self) -> Any:
Expand Down

0 comments on commit 72ea021

Please sign in to comment.