Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

7107 Add support to validate at training start #7108

Merged
merged 9 commits into from
Oct 10, 2023
10 changes: 9 additions & 1 deletion monai/handlers/validation_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,19 @@ class ValidationHandler:
"""

def __init__(self, interval: int, validator: Evaluator | None = None, epoch_level: bool = True) -> None:
def __init__(
self, interval: int, validator: Evaluator | None = None, epoch_level: bool = True, exec_at_start: bool = False
) -> None:
"""
Args:
interval: do validation every N epochs or every N iterations during training.
validator: run the validator when trigger validation, suppose to be Evaluator.
if None, should call `set_validator()` before training.
epoch_level: execute validation every N epochs or N iterations.
`True` is epoch level, `False` is iteration level.
exec_at_start: whether to execute a validation first when starting the training.
Nic-Ma marked this conversation as resolved.
Show resolved Hide resolved
default to `False`. It can be useful especially for some transfer-learning cases
to validate the initial model before training.
Raises:
TypeError: When ``validator`` is not a ``monai.engines.evaluator.Evaluator``.
Expand All @@ -49,6 +54,7 @@ def __init__(self, interval: int, validator: Evaluator | None = None, epoch_leve
self.validator = validator
self.interval = interval
self.epoch_level = epoch_level
self.exec_at_start = exec_at_start

def set_validator(self, validator: Evaluator) -> None:
"""
Expand All @@ -67,6 +73,8 @@ def attach(self, engine: Engine) -> None:
engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.interval), self)
else:
engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self)
if self.exec_at_start:
engine.add_event_handler(Events.STARTED, self)

def __call__(self, engine: Engine) -> None:
"""
Expand Down
7 changes: 5 additions & 2 deletions tests/test_handler_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ def _train_func(engine, batch):
# set up testing handler
val_data_loader = torch.utils.data.DataLoader(Dataset(data))
evaluator = TestEvaluator(torch.device("cpu:0"), val_data_loader)
saver = ValidationHandler(interval=2, validator=evaluator)
saver.attach(engine)
ValidationHandler(interval=2, validator=evaluator, exec_at_start=True).attach(engine)
# test execution at start
engine.run(data, max_epochs=1)
self.assertEqual(evaluator.state.max_epochs, 0)
self.assertEqual(evaluator.state.epoch_length, 8)

engine.run(data, max_epochs=5)
self.assertEqual(evaluator.state.max_epochs, 4)
Expand Down