From 4a8f8150888b999f35535ffeb09e2559c48ca0c8 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 20 Apr 2022 06:51:30 +0800 Subject: [PATCH] 4116 Add support for advanced args of AMP (#4132) * [DLMED] fix typo in bundle scripts Signed-off-by: Nic Ma * [DLMED] add support for AMP args Signed-off-by: Nic Ma * [MONAI] python code formatting Signed-off-by: monai-bot * [DLMED] fix flake8 Signed-off-by: Nic Ma Co-authored-by: monai-bot --- monai/bundle/scripts.py | 4 ++-- monai/engines/evaluator.py | 16 ++++++++++++++-- monai/engines/trainer.py | 10 +++++++++- monai/engines/workflow.py | 4 ++++ tests/test_integration_workflows.py | 2 ++ 5 files changed, 31 insertions(+), 5 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 64172c4541..b741e40e8d 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -363,7 +363,7 @@ def ckpt_export( .. code-block:: bash - python -m monai.bundle export network --filepath --ckpt_file ... + python -m monai.bundle ckpt_export network --filepath --ckpt_file ... Args: net_id: ID name of the network component in the config, it must be `torch.nn.Module`. @@ -390,7 +390,7 @@ def ckpt_export( key_in_ckpt=key_in_ckpt, **override, ) - _log_input_summary(tag="export", args=_args) + _log_input_summary(tag="ckpt_export", args=_args) filepath_, ckpt_file_, config_file_, net_id_, meta_file_, key_in_ckpt_ = _pop_args( _args, "filepath", "ckpt_file", "config_file", net_id="", meta_file=None, key_in_ckpt="" ) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index f9dab35450..c69c0e0547 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -76,6 +76,8 @@ class Evaluator(Workflow): default to `True`. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. """ @@ -98,6 +100,7 @@ def __init__( event_to_attr: Optional[dict] = None, decollate: bool = True, to_kwargs: Optional[Dict] = None, + amp_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -117,6 +120,7 @@ def __init__( event_to_attr=event_to_attr, decollate=decollate, to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, ) mode = look_up_option(mode, ForwardMode) if mode == ForwardMode.EVAL: @@ -187,6 +191,8 @@ class SupervisedEvaluator(Evaluator): default to `True`. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. 
""" @@ -211,6 +217,7 @@ def __init__( event_to_attr: Optional[dict] = None, decollate: bool = True, to_kwargs: Optional[Dict] = None, + amp_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -230,6 +237,7 @@ def __init__( event_to_attr=event_to_attr, decollate=decollate, to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, ) self.network = network @@ -269,7 +277,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): # execute forward computation with self.mode(self.network): if self.amp: - with torch.cuda.amp.autocast(): + with torch.cuda.amp.autocast(**engine.amp_kwargs): # type: ignore engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore else: engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore @@ -326,6 +334,8 @@ class EnsembleEvaluator(Evaluator): default to `True`. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. """ @@ -351,6 +361,7 @@ def __init__( event_to_attr: Optional[dict] = None, decollate: bool = True, to_kwargs: Optional[Dict] = None, + amp_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -370,6 +381,7 @@ def __init__( event_to_attr=event_to_attr, decollate=decollate, to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, ) self.networks = ensure_tuple(networks) @@ -417,7 +429,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): for idx, network in enumerate(self.networks): with self.mode(network): if self.amp: - with torch.cuda.amp.autocast(): + with torch.cuda.amp.autocast(**engine.amp_kwargs): # type: ignore if isinstance(engine.state.output, dict): engine.state.output.update( {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)} diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 16c50d4fa2..12753765ef 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -107,6 +107,8 @@ class SupervisedTrainer(Trainer): more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. """ @@ -134,6 +136,7 @@ def __init__( decollate: bool = True, optim_set_to_none: bool = False, to_kwargs: Optional[Dict] = None, + amp_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -153,6 +156,7 @@ def __init__( event_to_attr=event_to_attr, decollate=decollate, to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, ) self.network = network @@ -202,7 +206,7 @@ def _compute_pred_loss(): self.optimizer.zero_grad(set_to_none=self.optim_set_to_none) if self.amp and self.scaler is not None: - with torch.cuda.amp.autocast(): + with torch.cuda.amp.autocast(**engine.amp_kwargs): # type: ignore _compute_pred_loss() self.scaler.scale(engine.state.output[Keys.LOSS]).backward() # type: ignore engine.fire_event(IterationEvents.BACKWARD_COMPLETED) @@ -275,6 +279,8 @@ class GanTrainer(Trainer): more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. 
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. """ @@ -307,6 +313,7 @@ def __init__( decollate: bool = True, optim_set_to_none: bool = False, to_kwargs: Optional[Dict] = None, + amp_kwargs: Optional[Dict] = None, ): if not isinstance(train_data_loader, DataLoader): raise ValueError("train_data_loader must be PyTorch DataLoader.") @@ -327,6 +334,7 @@ def __init__( postprocessing=postprocessing, decollate=decollate, to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, ) self.g_network = g_network self.g_optimizer = g_optimizer diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 4ea0a69d55..75123da153 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -96,6 +96,8 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona default to `True`. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. Raises: TypeError: When ``device`` is not a ``torch.Device``. @@ -124,6 +126,7 @@ def __init__( event_to_attr: Optional[dict] = None, decollate: bool = True, to_kwargs: Optional[Dict] = None, + amp_kwargs: Optional[Dict] = None, ) -> None: if iteration_update is not None: super().__init__(iteration_update) @@ -170,6 +173,7 @@ def set_sampler_epoch(engine: Engine): self.metric_cmp_fn = metric_cmp_fn self.amp = amp self.to_kwargs = {} if to_kwargs is None else to_kwargs + self.amp_kwargs = {} if amp_kwargs is None else amp_kwargs self.scaler: Optional[torch.cuda.amp.GradScaler] = None if event_names is None: diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index fafdf43522..688f664089 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -149,6 +149,7 @@ def _forward_completed(self, engine): val_handlers=val_handlers, amp=bool(amp), to_kwargs={"memory_format": torch.preserve_format}, + amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32}, ) train_postprocessing = Compose( @@ -204,6 +205,7 @@ def _model_completed(self, engine): amp=bool(amp), optim_set_to_none=True, to_kwargs={"memory_format": torch.preserve_format}, + amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32}, ) trainer.run()
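
Reviewer note (not part of the patch): a minimal, hypothetical sketch of how the new
`amp_kwargs` argument could be exercised once this change is applied. The toy dataset,
network and the particular autocast settings below are illustrative assumptions, not
taken from the PR; the only behaviour relied on is what the diff shows, namely that the
engines forward `amp_kwargs` unchanged to `torch.cuda.amp.autocast(**amp_kwargs)` during
each iteration and default it to an empty dict when omitted. PyTorch 1.10+ is assumed so
that `autocast` accepts the `dtype` and `cache_enabled` keywords.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader

    from monai.data import Dataset
    from monai.engines import SupervisedEvaluator

    # Toy validation data: dicts with a single-channel 2D "image" and matching "label",
    # which is what the engine's default `prepare_batch` expects.
    data = [
        {"image": torch.rand(1, 16, 16), "label": torch.randint(0, 2, (1, 16, 16)).float()}
        for _ in range(10)
    ]
    val_loader = DataLoader(Dataset(data=data), batch_size=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1).to(device)

    use_amp = torch.cuda.is_available()  # autocast only affects CUDA ops here
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        amp=use_amp,
        # New in this patch: forwarded unchanged to `torch.cuda.amp.autocast(**amp_kwargs)`
        # inside `_iteration`, e.g. to pick the autocast dtype or toggle the weight cache.
        amp_kwargs={"dtype": torch.float16, "cache_enabled": True} if use_amp else None,
    )
    evaluator.run()

The same keyword is accepted by `SupervisedTrainer`, `GanTrainer` and `EnsembleEvaluator`,
since it is stored on the shared `Workflow` base class (defaulting to `{}` when `None`),
as the integration test changes above show for the trainer/evaluator pair.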