From 313b13841a6804b99abb64b4fcf19118f85cfbe2 Mon Sep 17 00:00:00 2001
From: Tian Luan
Date: Sat, 17 Feb 2024 02:30:55 -0500
Subject: [PATCH] Implement river drift detector

---
Review notes (ignored by git am):
- river_drift_detector.py: annotate with numpy.ndarray instead of the
  numpy.array factory function, pass numpy.mean directly as the default
  aggregation, group imports stdlib / third-party / local, add class and
  constructor docstrings, end the file with a newline.
- The blob "index" line for river_drift_detector.py is omitted because the
  revised content no longer matches the originally recorded hash.

 src/drift_detector/river_drift_detector.py    | 59 +++++++++++++++++++
 tests/drift_detector/.gitkeep                 |  0
 .../test_river_drift_detector.py              | 27 +++++++++
 3 files changed, 86 insertions(+)
 create mode 100644 src/drift_detector/river_drift_detector.py
 delete mode 100644 tests/drift_detector/.gitkeep
 create mode 100644 tests/drift_detector/test_river_drift_detector.py

diff --git a/src/drift_detector/river_drift_detector.py b/src/drift_detector/river_drift_detector.py
new file mode 100644
--- /dev/null
+++ b/src/drift_detector/river_drift_detector.py
@@ -0,0 +1,59 @@
+from typing import Callable
+
+from numpy import mean, ndarray
+from river import drift
+
+from drift_detector.base_drift_detector import BaseDriftDetector
+
+
+class RiverDriftDetector(BaseDriftDetector):
+    """Drift detector backed by a detector from ``river.drift``.
+
+    Each feature vector is reduced to a single value by ``agg_func`` and
+    that value is fed to the underlying river detector.
+    """
+
+    def __init__(
+        self,
+        drift_detect_algo: str = 'ADWIN',
+        agg_func: Callable[[ndarray], float] = mean,
+    ) -> None:
+        """
+        Set up the underlying detector and the aggregation function.
+
+        Parameters:
+            drift_detect_algo (str): Name of the river drift detection
+                algorithm. Only 'ADWIN' is supported.
+            agg_func (Callable): Reduces a feature vector to the single
+                value passed to the detector. Defaults to numpy's mean.
+
+        Raises:
+            ValueError: If drift_detect_algo is not supported.
+            TypeError: If agg_func is not callable.
+        """
+        super().__init__()
+        if drift_detect_algo == 'ADWIN':
+            self.drift_detector = drift.ADWIN()
+        else:
+            raise ValueError(f"Support for algorithm {drift_detect_algo} not implemented yet")
+
+        if not callable(agg_func):
+            raise TypeError("Aggregation function must be a callable.")
+        self.agg_func = agg_func
+
+    def is_drifted(self, feat_vec: ndarray) -> bool:
+        """
+        Check if the given feature vector indicates drift.
+
+        The vector is aggregated with agg_func and the result is passed to
+        the underlying detector's update() before its drift flag is read.
+
+        Parameters:
+            feat_vec (ndarray): The feature vector to be checked for drift.
+
+        Returns:
+            bool: True if the feature vector indicates drift, False otherwise.
+        """
+        val = self.agg_func(feat_vec)
+        self.drift_detector.update(val)
+        return self.drift_detector.drift_detected
diff --git a/tests/drift_detector/.gitkeep b/tests/drift_detector/.gitkeep
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/drift_detector/test_river_drift_detector.py b/tests/drift_detector/test_river_drift_detector.py
new file mode 100644
index 0000000..f0a7c45
--- /dev/null
+++ b/tests/drift_detector/test_river_drift_detector.py
@@ -0,0 +1,27 @@
+from unittest.mock import patch
+from numpy import mean, array
+from drift_detector.river_drift_detector import RiverDriftDetector
+
+def test_initialization():
+    with patch('river.drift.ADWIN') as MockADWIN:
+        RiverDriftDetector()
+        MockADWIN.assert_called_once()
+
+def test_agg_func_usage():
+    test_data = array([1, 2, 3, 4, 5])
+    custom_agg_func = lambda x: sum(x) / len(x)  # Same as mean
+    detector = RiverDriftDetector(agg_func=custom_agg_func)
+    assert detector.agg_func(test_data) == mean(test_data), "Aggregation function doesn't work as expected."
+
+@patch('river.drift.ADWIN')
+def test_is_drifted(MockADWIN):
+    mock_adwin_instance = MockADWIN.return_value
+    mock_adwin_instance.drift_detected = False
+    detector = RiverDriftDetector()
+
+    # No drift
+    assert not detector.is_drifted(array([1, 2, 3])), "Shouldn't detect drift."
+
+    # Drift exists
+    mock_adwin_instance.drift_detected = True
+    assert detector.is_drifted(array([4, 5, 6])), "Should detect drift."