Skip to content

Commit

Permalink
rules: properly track rate limits of rules
Browse files Browse the repository at this point in the history
  • Loading branch information
Exirel committed Jun 5, 2022
1 parent 98eb805 commit a2a8ba0
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 40 deletions.
116 changes: 77 additions & 39 deletions sopel/plugins/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import logging
import re
import threading
from typing import Generator, Iterable, Optional, Type, TypeVar
import typing
from urllib.parse import urlparse


Expand All @@ -33,6 +33,12 @@
COMMAND_DEFAULT_HELP_PREFIX, COMMAND_DEFAULT_PREFIX, URL_DEFAULT_SCHEMES)


if typing.TYPE_CHECKING:
from typing import Any, Dict, Generator, Iterable, Optional, Type

from sopel.tools.identifiers import Identifier


__all__ = [
'Manager',
'Rule',
Expand All @@ -44,7 +50,7 @@
'URLCallback',
]

TypedRule = TypeVar('TypedRule', bound='AbstractRule')
TypedRule = typing.TypeVar('TypedRule', bound='AbstractRule')

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -445,6 +451,57 @@ def check_url_callback(self, bot, url):
)


class RuleMetrics:
"""Tracker of a rule's usage."""
def __init__(self) -> None:
self.started_at: Optional[datetime.datetime] = None
self.ended_at: Optional[datetime.datetime] = None
self.last_return_value: Any = None

def start(self) -> None:
"""Record a starting time (before execution)."""
self.started_at = datetime.datetime.utcnow()

def end(self) -> None:
"""Record a ending time (after execution)."""
self.ended_at = datetime.datetime.utcnow()

def set_return_value(self, value: Any) -> None:
"""Set the last return value of a rule."""
self.last_return_value = value

def is_limited(
self,
time_limit: datetime.datetime,
) -> bool:
"""Determine if the rule hits the time limit."""
if not self.started_at and not self.ended_at:
# not even started, so not limited
return False

# detect if we just started something or if it ended
last_time = self.started_at
if self.ended_at and self.started_at < self.ended_at:
last_time = self.ended_at
# since it ended, check the return value
if self.last_return_value == IGNORE_RATE_LIMIT:
return False

return last_time > time_limit

def __enter__(self) -> RuleMetrics:
self.start()
return self

def __exit__(
self,
type: Optional[Any] = None,
value: Optional[Any] = None,
traceback: Optional[Any] = None,
) -> None:
self.end()


class AbstractRule(abc.ABC):
"""Abstract definition of a plugin's rule.
Expand Down Expand Up @@ -900,9 +957,9 @@ def __init__(self,
self._global_rate_limit = global_rate_limit

# metrics
self._metrics_nick = {}
self._metrics_sender = {}
self._metrics_global = None
self._metrics_nick: Dict[Identifier, RuleMetrics] = {}
self._metrics_sender: Dict[Identifier, RuleMetrics] = {}
self._metrics_global = RuleMetrics()

# docs & tests
self._usages = usages or tuple()
Expand Down Expand Up @@ -1030,56 +1087,37 @@ def is_unblockable(self):
return self._unblockable

def is_rate_limited(self, nick):
metrics = self._metrics_nick.get(nick)
if metrics is None:
return False
last_usage_at, exit_code = metrics

if exit_code == IGNORE_RATE_LIMIT:
return False

metrics: RuleMetrics = self._metrics_nick.get(nick, RuleMetrics())
now = datetime.datetime.utcnow()
rate_limit = datetime.timedelta(seconds=self._rate_limit)
return (now - last_usage_at) <= rate_limit
return metrics.is_limited(now - rate_limit)

def is_channel_rate_limited(self, channel):
metrics = self._metrics_sender.get(channel)
if metrics is None:
return False
last_usage_at, exit_code = metrics

if exit_code == IGNORE_RATE_LIMIT:
return False

metrics: RuleMetrics = self._metrics_sender.get(channel, RuleMetrics())
now = datetime.datetime.utcnow()
rate_limit = datetime.timedelta(seconds=self._channel_rate_limit)
return (now - last_usage_at) <= rate_limit
return metrics.is_limited(now - rate_limit)

def is_global_rate_limited(self):
metrics = self._metrics_global
if metrics is None:
return False
last_usage_at, exit_code = metrics

if exit_code == IGNORE_RATE_LIMIT:
return False

now = datetime.datetime.utcnow()
rate_limit = datetime.timedelta(seconds=self._global_rate_limit)
return (now - last_usage_at) <= rate_limit
return self._metrics_global.is_limited(now - rate_limit)

def execute(self, bot, trigger):
if not self._handler:
raise RuntimeError('Improperly configured rule: no handler')

# execute the handler
exit_code = self._handler(bot, trigger)
user_metrics: RuleMetrics = self._metrics_nick.setdefault(
trigger.nick, RuleMetrics())
sender_metrics: RuleMetrics = self._metrics_sender.setdefault(
trigger.sender, RuleMetrics())

# register metrics
now = datetime.datetime.utcnow()
self._metrics_nick[trigger.nick] = (now, exit_code)
self._metrics_sender[trigger.sender] = (now, exit_code)
self._metrics_global = (now, exit_code)
# execute the handler
with user_metrics, sender_metrics, self._metrics_global:
exit_code = self._handler(bot, trigger)
user_metrics.set_return_value(exit_code)
sender_metrics.set_return_value(exit_code)
self._metrics_global.set_return_value(exit_code)

# return exit code
return exit_code
Expand Down
27 changes: 27 additions & 0 deletions test/plugins/test_plugins_rules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the ``sopel.plugins.rules`` module."""
from __future__ import annotations

import datetime
import re

import pytest
Expand Down Expand Up @@ -463,6 +464,31 @@ def test_manager_has_action_command_aliases():
assert not manager.has_action_command('unknown')


# -----------------------------------------------------------------------------
# tests for :class:`Manager`

def test_rulemetrics():
now = datetime.datetime.utcnow()
time_window = datetime.timedelta(seconds=3600)
metrics = rules.RuleMetrics()

# never executed, so not limited
assert not metrics.is_limited(now)

# test limit while running
with metrics:
assert metrics.is_limited(now - time_window)
assert not metrics.is_limited(now + time_window)

# test limit after
assert metrics.is_limited(now - time_window)
assert not metrics.is_limited(now + time_window)

# test with NO LIMIT on the return value
metrics.set_return_value(rules.IGNORE_RATE_LIMIT)
assert not metrics.is_limited(now - time_window)
assert not metrics.is_limited(now + time_window)

# -----------------------------------------------------------------------------
# tests for :class:`Rule`

Expand Down Expand Up @@ -1488,6 +1514,7 @@ def handler(bot, trigger):
rate_limit=20,
global_rate_limit=20,
channel_rate_limit=20,
threaded=False, # make sure there is no race-condition here
)
assert rule.is_rate_limited(mocktrigger.nick) is False
assert rule.is_channel_rate_limited(mocktrigger.sender) is False
Expand Down
4 changes: 3 additions & 1 deletion test/test_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,9 @@ def testrule(bot, trigger):
plugin='testplugin',
label='testrule',
handler=testrule,
rate_limit=100)
rate_limit=100,
threaded=False,
)

# trigger
line = ':Test!test@example.com PRIVMSG #channel :hello'
Expand Down

0 comments on commit a2a8ba0

Please sign in to comment.