From 0da58f0c90d9895ff8af71fb3c5b3bd7aac1ecb4 Mon Sep 17 00:00:00 2001 From: Jaeguk Hyun Date: Thu, 12 Oct 2023 10:46:33 +0900 Subject: [PATCH] Re-introduce adaptive training (#2543) * Re-introduce adaptive patience for training * Revert unit tests --- .../mmcv/hooks/adaptive_training_hook.py | 30 +++++++++++++++++-- .../hooks/test_adaptive_training_hooks.py | 4 +-- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/otx/algorithms/common/adapters/mmcv/hooks/adaptive_training_hook.py b/src/otx/algorithms/common/adapters/mmcv/hooks/adaptive_training_hook.py index 64a7be8f338..f9e6fb345ff 100644 --- a/src/otx/algorithms/common/adapters/mmcv/hooks/adaptive_training_hook.py +++ b/src/otx/algorithms/common/adapters/mmcv/hooks/adaptive_training_hook.py @@ -23,9 +23,15 @@ class AdaptiveTrainSchedulingHook(Hook): """Adaptive Training Scheduling Hook. - Depending on the size of iteration per epoch, adaptively update the validation interval. + Depending on the size of iteration per epoch, adaptively update the validation interval and related values. Args: + base_lr_patience (int): The value of LR drop patience are expected in total epoch. + Patience used when interval is 1, Defaults to 5. + min_lr_patience (int): Minumum value of LR drop patience. + Defaults to 2. + base_es_patience (int): The value of Early-Stopping patience are expected in total epoch. + Patience used when interval is 1, Defaults to 10. max_interval (int): Maximum value of validation interval. Defaults to 5. decay (float): Parameter to control the interval. This value is set by manual manner. @@ -39,6 +45,10 @@ class AdaptiveTrainSchedulingHook(Hook): def __init__( self, max_interval=5, + base_lr_patience=5, + min_lr_patience=2, + base_es_patience=10, + min_es_patience=3, decay=-0.025, enable_adaptive_interval_hook=False, enable_eval_before_run=False, @@ -47,6 +57,10 @@ def __init__( super().__init__(**kwargs) self.max_interval = max_interval + self.base_lr_patience = base_lr_patience + self.min_lr_patience = min_lr_patience + self.base_es_patience = base_es_patience + self.min_es_patience = min_es_patience self.decay = decay self.enable_adaptive_interval_hook = enable_adaptive_interval_hook self.enable_eval_before_run = enable_eval_before_run @@ -84,13 +98,23 @@ def before_train_iter(self, runner): logger.info(f"Update EvalHook interval: {hook.interval} -> {adaptive_interval}") hook.interval = adaptive_interval elif isinstance(hook, LrUpdaterHook): + patience = max( + math.ceil((self.base_lr_patience / adaptive_interval)), + self.min_lr_patience, + ) if hasattr(hook, "interval") and hasattr(hook, "patience"): hook.interval = adaptive_interval - logger.info(f"Update LrUpdaterHook interval: {hook.interval} -> {adaptive_interval}") + hook.patience = patience + logger.info(f"Update LrUpdaterHook patience: {hook.patience} -> {patience}") elif isinstance(hook, EarlyStoppingHook): - logger.info(f"Update EarlyStoppingHook interval: {hook.interval} -> {adaptive_interval}") + patience = max( + math.ceil((self.base_es_patience / adaptive_interval)), + self.min_es_patience, + ) + logger.info(f"Update EarlyStoppingHook patience: {hook.patience} -> {patience}") hook.start = adaptive_interval hook.interval = adaptive_interval + hook.patience = patience elif isinstance(hook, CheckpointHook): # make sure checkpoint is saved at last limit = runner.max_epochs if hook.by_epoch else runner.max_iters diff --git a/tests/unit/algorithms/common/adapters/mmcv/hooks/test_adaptive_training_hooks.py b/tests/unit/algorithms/common/adapters/mmcv/hooks/test_adaptive_training_hooks.py index 71716effd62..51a30756b97 100644 --- a/tests/unit/algorithms/common/adapters/mmcv/hooks/test_adaptive_training_hooks.py +++ b/tests/unit/algorithms/common/adapters/mmcv/hooks/test_adaptive_training_hooks.py @@ -86,7 +86,7 @@ def test_before_train_iter(self) -> None: assert hook._original_interval is None assert eval_hook.interval == 4 assert lr_hook.interval == 4 - assert lr_hook.patience == 1 + assert lr_hook.patience == 2 assert early_hook.interval == 4 - assert early_hook.patience == 1 + assert early_hook.patience == 3 assert ckpt_hook.interval == 4