Skip to content

Commit

Permalink
Enhancement#5 add threshold checker (#6)
Browse files Browse the repository at this point in the history
* added threshold checker

* updated Pipfile

* changed threshold to greater (instead of greater-equal)

* added tests for threshold_checker

* updated pipfile

* updated threshold_checker to work with __call__

* * added monitor_mode to base class
* moved input validation to base class
* name fix for utils test class
* attached tests to main.py

Co-authored-by: Naor Haba <naor.haba@toluna.com>
Co-authored-by: Roy Sadaka <roy.sadaka@toluna.com>
  • Loading branch information
3 people authored Sep 15, 2022
1 parent 8476f2a commit c44c8bf
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ tensorboard = "==2.3.0"
tqdm = "==4.51.0"

[packages]
numpy = "==1.19.2"
numpy = "*"
torch = "*"
torchvision = "*"
protobuf = "==3.20.*"
Expand Down
2 changes: 1 addition & 1 deletion lpd/callbacks/callback_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class CallbackBase():
State.EXTERNAL
Phase.PREDICT_END
Agrs:
Args:
apply_on_phase - (lpd.enums.Phase) the phase to invoke this callback
apply_on_states - (lpd.enums.State) state or list of states to invoke this parameter (under the relevant phase), None will invoke it on all states
round_values_on_print_to - optional, it will round the numerical values in the prints
Expand Down
28 changes: 16 additions & 12 deletions lpd/callbacks/callback_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,34 @@
from math import inf
import torch

class CallbackMonitor():
from lpd.utils.threshold_checker import ThresholdChecker, AbsoluteThresholdChecker


class CallbackMonitor:
"""
Will check if the desired metric improved with support for patience
Agrs:
Args:
patience - int or None (will be set to inf) track how many epochs/iterations without improvements in monitoring
(negative number will set to inf)
monitor_type - e.g lpd.enums.MonitorType.LOSS
stats_type - e.g lpd.enums.StatsType.VAL
monitor_mode - e.g. lpd.enums.MonitorMode.MIN, MIN will check if the metric decreased, MAX will check for increase
metric_name - in case of monitor_type=lpd.enums.MonitorType.METRIC, provide metric_name, otherwise, leave it None
"""
def __init__(self, monitor_type: MonitorType, stats_type: StatsType, monitor_mode: MonitorMode, patience: int=None, metric_name: Optional[str]=None):
def __init__(self, monitor_type: MonitorType, stats_type: StatsType, monitor_mode: MonitorMode,
threshold_checker: Optional[ThresholdChecker] = None, patience: int=None, metric_name: Optional[str]=None):
self.patience = inf if patience is None or patience < 0 else patience
self.patience_countdown = self.patience
self.monitor_type = monitor_type
self.stats_type = stats_type
self.monitor_mode = monitor_mode
self.threshold_checker = AbsoluteThresholdChecker(monitor_mode) if threshold_checker is None else threshold_checker
self.metric_name = metric_name
self.minimum = torch.tensor(inf)
self.maximum = torch.tensor(-inf)
self.previous = self._get_best()
self.description = self._get_description()
self._track_invoked = False
self._track_invoked = False

def _get_description(self):
desc = f'{self.monitor_mode}_{self.stats_type}_{self.monitor_type}'
Expand Down Expand Up @@ -82,29 +87,28 @@ def track(self, callback_context: CallbackContext):

if len(value_to_consider.shape) == 0 or \
(len(value_to_consider.shape) == 1 and value_to_consider.shape[0] == 1):
if self.monitor_mode == MonitorMode.MIN and value_to_consider < curr_minimum or \
self.monitor_mode == MonitorMode.MAX and value_to_consider > curr_maximum:
if self.threshold_checker(new_value=value_to_consider, old_value=curr_best):
did_improve = True
self.patience_countdown = self.patience
else:
if self.patience != inf:
raise ValueError("[CallbackMonitor] - can't monitor patience for metric that has multiple values")
return CallbackMonitorResult(did_improve=did_improve,
new_value=value_to_consider,

return CallbackMonitorResult(did_improve=did_improve,
new_value=value_to_consider,
prev_value=curr_previous,
new_best=new_best,
prev_best=curr_best,
change_from_previous=change_from_previous,
change_from_best=change_from_best,
patience_left=self.patience_countdown,
patience_left=self.patience_countdown,
description=self.description,
name = name)


class CallbackMonitorResult():
def __init__(self, did_improve: bool,
new_value: float,
def __init__(self, did_improve: bool,
new_value: float,
prev_value: float,
new_best: float,
prev_best: float,
Expand Down
2 changes: 1 addition & 1 deletion lpd/callbacks/scheduler_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class SchedulerStep(CallbackBase):
"""This callback will invoke a "step()" on the scheduler.
Agrs:
Args:
apply_on_phase - see in CallbackBase
apply_on_states - see in CallbackBase
scheduler_parameters_func - Since some schedulers takes parameters in step(param1, param2...)
Expand Down
72 changes: 72 additions & 0 deletions lpd/utils/threshold_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from abc import ABC, abstractmethod
from typing import Union
from torch import Tensor
from lpd.enums import MonitorMode


class ThresholdChecker(ABC):
    """
        Check if the current value is better than the previous best value according to a threshold criterion
        This is an abstract class meant to be inherited by different threshold checkers
        Can also be inherited by the user to create custom threshold checkers (implement __call__)
        Args:
            monitor_mode - (lpd.enums.MonitorMode) stored on the instance for the concrete checker to use
            threshold - the threshold used by the concrete checker (must be non-negative)
        Raises:
            ValueError - if threshold is negative
    """
    def __init__(self, monitor_mode: MonitorMode, threshold: float):
        self.monitor_mode = monitor_mode
        self.threshold = threshold
        # Fail fast: validate_input() used to be defined but never invoked, so a
        # negative threshold was silently accepted
        self.validate_input()

    def validate_input(self):
        """
            Validate the values stored on the instance
            Raises:
                ValueError - if self.threshold is negative
        """
        if self.threshold < 0:
            raise ValueError(f"Threshold must be non-negative, but got {self.threshold}")

    @abstractmethod
    def __call__(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
        """
            Decide whether new_value is an improvement over old_value
            Args:
                new_value - the value that was just observed
                old_value - the value to compare against
            Returns:
                True if new_value counts as an improvement, False otherwise
        """


class AbsoluteThresholdChecker(ThresholdChecker):
    """
        A threshold checker that reports an improvement when the difference between the new value and the
        previous best value is strictly greater than a given threshold
        (a difference that merely equals the threshold is NOT an improvement)
        Args:
            monitor_mode - MonitorMode.MIN (new value must be lower) or MonitorMode.MAX (new value must be higher)
            threshold - the absolute difference that must be exceeded (must be non-negative)
    """
    def __init__(self, monitor_mode: MonitorMode, threshold: float = 0.0):
        super().__init__(monitor_mode, threshold)

    def _is_new_value_lower(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
        """True when new_value is below old_value by more than the threshold"""
        return old_value - new_value > self.threshold

    def _is_new_value_higher(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
        """True when new_value is above old_value by more than the threshold"""
        return new_value - old_value > self.threshold

    def __call__(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
        """
            Args:
                new_value - the value that was just observed
                old_value - the value to compare against
            Returns:
                the result of the strict comparison for the configured monitor_mode
            Raises:
                ValueError - if monitor_mode is neither MonitorMode.MIN nor MonitorMode.MAX
        """
        if self.monitor_mode == MonitorMode.MIN:
            return self._is_new_value_lower(new_value, old_value)
        if self.monitor_mode == MonitorMode.MAX:
            return self._is_new_value_higher(new_value, old_value)
        # The original fell through here and implicitly returned None, which hid a misconfigured mode
        raise ValueError(f"[AbsoluteThresholdChecker] - unsupported monitor_mode: {self.monitor_mode}")


class RelativeThresholdChecker(ThresholdChecker):
    """
        A threshold checker that reports an improvement when the difference between the new value and the
        previous best value, relative to the magnitude of the previous best value, is strictly greater
        than a given threshold
        Values are expected to be scalars (python numbers or single-element tensors)
        Args:
            monitor_mode - MonitorMode.MIN (new value must be lower) or MonitorMode.MAX (new value must be higher)
            threshold - the relative difference that must be exceeded (must be non-negative)
    """
    def __init__(self, monitor_mode: MonitorMode, threshold: float = 0.0):
        super().__init__(monitor_mode, threshold)

    def _exceeds_relative_threshold(self, gain: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
        """
            Check whether gain (the change in the desired direction) is more than threshold * |old_value|
        """
        baseline = abs(old_value)
        if baseline == float("inf"):
            # gain / inf is nan (or 0), which made an infinite baseline (e.g. a best value
            # that starts at +/-inf) impossible to beat; any positive gain counts instead
            return gain > 0
        # Multiplying by |old_value| instead of dividing by old_value avoids ZeroDivisionError
        # on a zero baseline and keeps the direction correct when the baseline is negative
        return gain > self.threshold * baseline

    def _is_new_value_lower(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
        """True when new_value is below old_value by more than the relative threshold"""
        return self._exceeds_relative_threshold(old_value - new_value, old_value)

    def _is_new_value_higher(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
        """True when new_value is above old_value by more than the relative threshold"""
        return self._exceeds_relative_threshold(new_value - old_value, old_value)

    def __call__(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
        """
            Args:
                new_value - the value that was just observed
                old_value - the value to compare against
            Returns:
                the result of the strict relative comparison for the configured monitor_mode
            Raises:
                ValueError - if monitor_mode is neither MonitorMode.MIN nor MonitorMode.MAX
        """
        if self.monitor_mode == MonitorMode.MIN:
            return self._is_new_value_lower(new_value, old_value)
        if self.monitor_mode == MonitorMode.MAX:
            return self._is_new_value_higher(new_value, old_value)
        # The original fell through here and implicitly returned None, which hid a misconfigured mode
        raise ValueError(f"[RelativeThresholdChecker] - unsupported monitor_mode: {self.monitor_mode}")
1 change: 1 addition & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from tests.test_trainer import TestTrainer
from tests.test_predictor import TestPredictor
from tests.test_callbacks import TestCallbacks
from tests.test_utils import TestUtils
import unittest

import examples.multiple_inputs.train as multiple_inputs_example
Expand Down
48 changes: 48 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import unittest

from lpd.enums import MonitorMode


class TestUtils(unittest.TestCase):
    """Tests for the threshold checkers in lpd.utils.threshold_checker"""

    def _assert_checker_result(self, checker_cls, cases, expected):
        """
            Run checker_cls in MIN and in MAX mode over every case and assert the truthiness of its answer
            Args:
                checker_cls - threshold checker class, constructed as checker_cls(monitor_mode, threshold)
                cases - iterable of (threshold, higher_value, lower_value) tuples
                expected - True if every call should report an improvement, False if none should
        """
        check = self.assertTrue if expected else self.assertFalse
        for threshold, higher_value, lower_value in cases:
            # MIN mode: going from the higher value down to the lower value
            min_checker = checker_cls(MonitorMode.MIN, threshold)
            # subTest parameters are shown on failure, so the offending case can be identified
            with self.subTest(mode='MIN', threshold=threshold, new_value=lower_value, old_value=higher_value):
                check(min_checker(new_value=lower_value, old_value=higher_value))

            # MAX mode: going from the lower value up to the higher value
            max_checker = checker_cls(MonitorMode.MAX, threshold)
            with self.subTest(mode='MAX', threshold=threshold, new_value=higher_value, old_value=lower_value):
                check(max_checker(new_value=higher_value, old_value=lower_value))

    def test_absolute_threshold_checker__true(self):
        from lpd.utils.threshold_checker import AbsoluteThresholdChecker
        cases = [(0.0, 0.9, 0.899), (0.1, 0.9, 0.799)]
        self._assert_checker_result(AbsoluteThresholdChecker, cases, expected=True)

    def test_absolute_threshold_checker__false(self):
        from lpd.utils.threshold_checker import AbsoluteThresholdChecker
        cases = [(0.0, 0.9, 0.9), (0.1, 0.9, 0.81)]
        self._assert_checker_result(AbsoluteThresholdChecker, cases, expected=False)

    def test_relative_threshold_checker__true(self):
        from lpd.utils.threshold_checker import RelativeThresholdChecker
        cases = [(0.0, 0.9, 0.899), (0.1, 120.1, 100.0)]
        self._assert_checker_result(RelativeThresholdChecker, cases, expected=True)

    def test_relative_threshold_checker__false(self):
        from lpd.utils.threshold_checker import RelativeThresholdChecker
        cases = [(0.0, 0.9, 0.9), (0.1, 109.99, 100.0)]
        self._assert_checker_result(RelativeThresholdChecker, cases, expected=False)

0 comments on commit c44c8bf

Please sign in to comment.