From a2a8ba0fb65fedfc39b7b7b9d639ba90a5a740df Mon Sep 17 00:00:00 2001 From: Florian Strzelecki Date: Sun, 5 Jun 2022 18:23:25 +0200 Subject: [PATCH] rules: properly track rate limits of rules --- sopel/plugins/rules.py | 116 +++++++++++++++++++---------- test/plugins/test_plugins_rules.py | 27 +++++++ test/test_bot.py | 4 +- 3 files changed, 107 insertions(+), 40 deletions(-) diff --git a/sopel/plugins/rules.py b/sopel/plugins/rules.py index f31aa59f96..9709664a77 100644 --- a/sopel/plugins/rules.py +++ b/sopel/plugins/rules.py @@ -24,7 +24,7 @@ import logging import re import threading -from typing import Generator, Iterable, Optional, Type, TypeVar +import typing from urllib.parse import urlparse @@ -33,6 +33,12 @@ COMMAND_DEFAULT_HELP_PREFIX, COMMAND_DEFAULT_PREFIX, URL_DEFAULT_SCHEMES) +if typing.TYPE_CHECKING: + from typing import Any, Dict, Generator, Iterable, Optional, Type + + from sopel.tools.identifiers import Identifier + + __all__ = [ 'Manager', 'Rule', @@ -44,7 +50,7 @@ 'URLCallback', ] -TypedRule = TypeVar('TypedRule', bound='AbstractRule') +TypedRule = typing.TypeVar('TypedRule', bound='AbstractRule') LOGGER = logging.getLogger(__name__) @@ -445,6 +451,57 @@ def check_url_callback(self, bot, url): ) +class RuleMetrics: + """Tracker of a rule's usage.""" + def __init__(self) -> None: + self.started_at: Optional[datetime.datetime] = None + self.ended_at: Optional[datetime.datetime] = None + self.last_return_value: Any = None + + def start(self) -> None: + """Record a starting time (before execution).""" + self.started_at = datetime.datetime.utcnow() + + def end(self) -> None: + """Record a ending time (after execution).""" + self.ended_at = datetime.datetime.utcnow() + + def set_return_value(self, value: Any) -> None: + """Set the last return value of a rule.""" + self.last_return_value = value + + def is_limited( + self, + time_limit: datetime.datetime, + ) -> bool: + """Determine if the rule hits the time limit.""" + if not self.started_at and not self.ended_at: + # not even started, so not limited + return False + + # detect if we just started something or if it ended + last_time = self.started_at + if self.ended_at and self.started_at < self.ended_at: + last_time = self.ended_at + # since it ended, check the return value + if self.last_return_value == IGNORE_RATE_LIMIT: + return False + + return last_time > time_limit + + def __enter__(self) -> RuleMetrics: + self.start() + return self + + def __exit__( + self, + type: Optional[Any] = None, + value: Optional[Any] = None, + traceback: Optional[Any] = None, + ) -> None: + self.end() + + class AbstractRule(abc.ABC): """Abstract definition of a plugin's rule. @@ -900,9 +957,9 @@ def __init__(self, self._global_rate_limit = global_rate_limit # metrics - self._metrics_nick = {} - self._metrics_sender = {} - self._metrics_global = None + self._metrics_nick: Dict[Identifier, RuleMetrics] = {} + self._metrics_sender: Dict[Identifier, RuleMetrics] = {} + self._metrics_global = RuleMetrics() # docs & tests self._usages = usages or tuple() @@ -1030,56 +1087,37 @@ def is_unblockable(self): return self._unblockable def is_rate_limited(self, nick): - metrics = self._metrics_nick.get(nick) - if metrics is None: - return False - last_usage_at, exit_code = metrics - - if exit_code == IGNORE_RATE_LIMIT: - return False - + metrics: RuleMetrics = self._metrics_nick.get(nick, RuleMetrics()) now = datetime.datetime.utcnow() rate_limit = datetime.timedelta(seconds=self._rate_limit) - return (now - last_usage_at) <= rate_limit + return metrics.is_limited(now - rate_limit) def is_channel_rate_limited(self, channel): - metrics = self._metrics_sender.get(channel) - if metrics is None: - return False - last_usage_at, exit_code = metrics - - if exit_code == IGNORE_RATE_LIMIT: - return False - + metrics: RuleMetrics = self._metrics_sender.get(channel, RuleMetrics()) now = datetime.datetime.utcnow() rate_limit = datetime.timedelta(seconds=self._channel_rate_limit) - return (now - last_usage_at) <= rate_limit + return metrics.is_limited(now - rate_limit) def is_global_rate_limited(self): - metrics = self._metrics_global - if metrics is None: - return False - last_usage_at, exit_code = metrics - - if exit_code == IGNORE_RATE_LIMIT: - return False - now = datetime.datetime.utcnow() rate_limit = datetime.timedelta(seconds=self._global_rate_limit) - return (now - last_usage_at) <= rate_limit + return self._metrics_global.is_limited(now - rate_limit) def execute(self, bot, trigger): if not self._handler: raise RuntimeError('Improperly configured rule: no handler') - # execute the handler - exit_code = self._handler(bot, trigger) + user_metrics: RuleMetrics = self._metrics_nick.setdefault( + trigger.nick, RuleMetrics()) + sender_metrics: RuleMetrics = self._metrics_sender.setdefault( + trigger.sender, RuleMetrics()) - # register metrics - now = datetime.datetime.utcnow() - self._metrics_nick[trigger.nick] = (now, exit_code) - self._metrics_sender[trigger.sender] = (now, exit_code) - self._metrics_global = (now, exit_code) + # execute the handler + with user_metrics, sender_metrics, self._metrics_global: + exit_code = self._handler(bot, trigger) + user_metrics.set_return_value(exit_code) + sender_metrics.set_return_value(exit_code) + self._metrics_global.set_return_value(exit_code) # return exit code return exit_code diff --git a/test/plugins/test_plugins_rules.py b/test/plugins/test_plugins_rules.py index 9bb9b3dc5a..388be6be32 100644 --- a/test/plugins/test_plugins_rules.py +++ b/test/plugins/test_plugins_rules.py @@ -1,6 +1,7 @@ """Tests for the ``sopel.plugins.rules`` module.""" from __future__ import annotations +import datetime import re import pytest @@ -463,6 +464,31 @@ def test_manager_has_action_command_aliases(): assert not manager.has_action_command('unknown') +# ----------------------------------------------------------------------------- +# tests for :class:`Manager` + +def test_rulemetrics(): + now = datetime.datetime.utcnow() + time_window = datetime.timedelta(seconds=3600) + metrics = rules.RuleMetrics() + + # never executed, so not limited + assert not metrics.is_limited(now) + + # test limit while running + with metrics: + assert metrics.is_limited(now - time_window) + assert not metrics.is_limited(now + time_window) + + # test limit after + assert metrics.is_limited(now - time_window) + assert not metrics.is_limited(now + time_window) + + # test with NO LIMIT on the return value + metrics.set_return_value(rules.IGNORE_RATE_LIMIT) + assert not metrics.is_limited(now - time_window) + assert not metrics.is_limited(now + time_window) + # ----------------------------------------------------------------------------- # tests for :class:`Rule` @@ -1488,6 +1514,7 @@ def handler(bot, trigger): rate_limit=20, global_rate_limit=20, channel_rate_limit=20, + threaded=False, # make sure there is no race-condition here ) assert rule.is_rate_limited(mocktrigger.nick) is False assert rule.is_channel_rate_limited(mocktrigger.sender) is False diff --git a/test/test_bot.py b/test/test_bot.py index bdea59c91f..67659ec14c 100644 --- a/test/test_bot.py +++ b/test/test_bot.py @@ -654,7 +654,9 @@ def testrule(bot, trigger): plugin='testplugin', label='testrule', handler=testrule, - rate_limit=100) + rate_limit=100, + threaded=False, + ) # trigger line = ':Test!test@example.com PRIVMSG #channel :hello'