-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enhancement#5 add threshold checker (#6)
* added threshold checker * updated Pipfile * changed threshold to greater (instead of greater-equal) * added tests for threshold_checker * updated pipfile * updated threshold_checker to work with __call__ * * added monitor_mode to base class * moved input validation to base class * name fix for utils test class * attached tests to main.py Co-authored-by: Naor Haba <naor.haba@toluna.com> Co-authored-by: Roy Sadaka <roy.sadaka@toluna.com>
- Loading branch information
1 parent
8476f2a
commit c44c8bf
Showing
7 changed files
with
140 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Union | ||
from torch import Tensor | ||
from lpd.enums import MonitorMode | ||
|
||
|
||
class ThresholdChecker(ABC): | ||
""" | ||
Check if the current value is better than the previous best value according to different threshold criteria | ||
This is an abstract class meant to be inherited by different threshold checkers | ||
Can also be inherited by the user to create custom threshold checkers | ||
""" | ||
def __init__(self, monitor_mode: MonitorMode, threshold: float): | ||
self.monitor_mode = monitor_mode | ||
self.threshold = threshold | ||
|
||
def validate_input(self): | ||
if self.threshold < 0: | ||
raise ValueError(f"Threshold must be non-negative, but got {self.threshold}") | ||
|
||
@abstractmethod | ||
def __call__(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool: | ||
pass | ||
|
||
|
||
class AbsoluteThresholdChecker(ThresholdChecker): | ||
""" | ||
A threshold checker that checks if the difference between the current value and the previous best value | ||
is greater than or equal to a given threshold | ||
Args: | ||
monitor_mode: MIN or MAX | ||
threshold - the threshold to check (must be non-negative) | ||
""" | ||
def __init__(self, monitor_mode: MonitorMode, threshold: float = 0.0): | ||
super(AbsoluteThresholdChecker, self).__init__(monitor_mode, threshold) | ||
|
||
def _is_new_value_lower(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool: | ||
return old_value - new_value > self.threshold | ||
|
||
def _is_new_value_higher(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool: | ||
return new_value - old_value > self.threshold | ||
|
||
def __call__(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool: | ||
if self.monitor_mode == MonitorMode.MIN: | ||
return self._is_new_value_lower(new_value, old_value) | ||
if self.monitor_mode == MonitorMode.MAX: | ||
return self._is_new_value_higher(new_value, old_value) | ||
|
||
|
||
class RelativeThresholdChecker(ThresholdChecker): | ||
""" | ||
A threshold checker that checks if the relative difference between the current value and the previous best value | ||
is greater than or equal to a given threshold | ||
Args: | ||
threshold - the threshold to check (must be non-negative) | ||
""" | ||
def __init__(self, monitor_mode: MonitorMode, threshold: float = 0.0): | ||
super(RelativeThresholdChecker, self).__init__(monitor_mode, threshold) | ||
|
||
def _is_new_value_lower(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool: | ||
return (old_value - new_value) / old_value > self.threshold | ||
|
||
def _is_new_value_higher(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool: | ||
return (new_value - old_value) / old_value > self.threshold | ||
|
||
def __call__(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool: | ||
if self.monitor_mode == MonitorMode.MIN: | ||
return self._is_new_value_lower(new_value, old_value) | ||
if self.monitor_mode == MonitorMode.MAX: | ||
return self._is_new_value_higher(new_value, old_value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import unittest | ||
|
||
from lpd.enums import MonitorMode | ||
|
||
|
||
class TestUtils(unittest.TestCase): | ||
|
||
def test_absolute_threshold_checker__true(self): | ||
from lpd.utils.threshold_checker import AbsoluteThresholdChecker | ||
for (threshold, higher_value, lower_value) in [(0.0, 0.9, 0.899), (0.1, 0.9, 0.799)]: | ||
min_checker = AbsoluteThresholdChecker(MonitorMode.MIN, threshold) | ||
with self.subTest(): | ||
self.assertTrue(min_checker(new_value=lower_value, old_value=higher_value)) | ||
|
||
max_checker = AbsoluteThresholdChecker(MonitorMode.MAX, threshold) | ||
with self.subTest(): | ||
self.assertTrue(max_checker(new_value=higher_value, old_value=lower_value)) | ||
|
||
def test_absolute_threshold_checker__false(self): | ||
from lpd.utils.threshold_checker import AbsoluteThresholdChecker | ||
for (threshold, higher_value, lower_value) in [(0.0, 0.9, 0.9), (0.1, 0.9, 0.81)]: | ||
min_checker = AbsoluteThresholdChecker(MonitorMode.MIN, threshold) | ||
with self.subTest(): | ||
self.assertFalse(min_checker(new_value=lower_value, old_value=higher_value)) | ||
|
||
max_checker = AbsoluteThresholdChecker(MonitorMode.MAX, threshold) | ||
with self.subTest(): | ||
self.assertFalse(max_checker(new_value=higher_value, old_value=lower_value)) | ||
|
||
def test_relative_threshold_checker__true(self): | ||
from lpd.utils.threshold_checker import RelativeThresholdChecker | ||
for (threshold, higher_value, lower_value) in [(0.0, 0.9, 0.899), (0.1, 120.1, 100.0)]: | ||
min_checker = RelativeThresholdChecker(MonitorMode.MIN, threshold) | ||
with self.subTest(): | ||
self.assertTrue(min_checker(new_value=lower_value, old_value=higher_value)) | ||
max_checker = RelativeThresholdChecker(MonitorMode.MAX, threshold) | ||
with self.subTest(): | ||
self.assertTrue(max_checker(new_value=higher_value, old_value=lower_value)) | ||
|
||
def test_relative_threshold_checker__false(self): | ||
from lpd.utils.threshold_checker import RelativeThresholdChecker | ||
for (threshold, higher_value, lower_value) in [(0.0, 0.9, 0.9), (0.1, 109.99, 100.0)]: | ||
min_checker = RelativeThresholdChecker(MonitorMode.MIN, threshold) | ||
with self.subTest(): | ||
self.assertFalse(min_checker(new_value=lower_value, old_value=higher_value)) | ||
max_checker = RelativeThresholdChecker(MonitorMode.MAX, threshold) | ||
with self.subTest(): | ||
self.assertFalse(max_checker(new_value=higher_value, old_value=lower_value)) |