4116 Add support for advanced args of AMP (#4132)
* [DLMED] fix typo in bundle scripts

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] add support for AMP args

Signed-off-by: Nic Ma <nma@nvidia.com>

* [MONAI] python code formatting

Signed-off-by: monai-bot <monai.miccai2019@gmail.com>

* [DLMED] fix flake8

Signed-off-by: Nic Ma <nma@nvidia.com>

Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
Nic-Ma and monai-bot authored Apr 19, 2022
1 parent e1a62f2 commit 4a8f815
Showing 5 changed files with 31 additions and 5 deletions.
4 changes: 2 additions & 2 deletions monai/bundle/scripts.py
@@ -363,7 +363,7 @@ def ckpt_export(
.. code-block:: bash
- python -m monai.bundle export network --filepath <export path> --ckpt_file <checkpoint path> ...
+ python -m monai.bundle ckpt_export network --filepath <export path> --ckpt_file <checkpoint path> ...
Args:
net_id: ID name of the network component in the config, it must be `torch.nn.Module`.
@@ -390,7 +390,7 @@ def ckpt_export(
key_in_ckpt=key_in_ckpt,
**override,
)
- _log_input_summary(tag="export", args=_args)
+ _log_input_summary(tag="ckpt_export", args=_args)
filepath_, ckpt_file_, config_file_, net_id_, meta_file_, key_in_ckpt_ = _pop_args(
_args, "filepath", "ckpt_file", "config_file", net_id="", meta_file=None, key_in_ckpt=""
)
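For reference, a minimal sketch of driving the corrected `ckpt_export` entry point from Python rather than the CLI; the file paths and bundle layout below are placeholders, not part of this diff:

from monai.bundle import ckpt_export

ckpt_export(
    net_id="network",                      # ID of the torch.nn.Module component in the config
    filepath="model.ts",                   # where the exported TorchScript file is written
    ckpt_file="model.pt",                  # checkpoint holding the trained weights
    config_file="configs/inference.json",  # bundle config that defines the network
    meta_file="configs/metadata.json",     # bundle metadata file
    key_in_ckpt="",                        # key of the state dict if the checkpoint nests it
)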
16 changes: 14 additions & 2 deletions monai/engines/evaluator.py
@@ -76,6 +76,8 @@ class Evaluator(Workflow):
default to `True`.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
"""

@@ -98,6 +100,7 @@ def __init__(
event_to_attr: Optional[dict] = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
+ amp_kwargs: Optional[Dict] = None,
) -> None:
super().__init__(
device=device,
@@ -117,6 +120,7 @@ def __init__(
event_to_attr=event_to_attr,
decollate=decollate,
to_kwargs=to_kwargs,
+ amp_kwargs=amp_kwargs,
)
mode = look_up_option(mode, ForwardMode)
if mode == ForwardMode.EVAL:
@@ -187,6 +191,8 @@ class SupervisedEvaluator(Evaluator):
default to `True`.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
"""

@@ -211,6 +217,7 @@ def __init__(
event_to_attr: Optional[dict] = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
+ amp_kwargs: Optional[Dict] = None,
) -> None:
super().__init__(
device=device,
@@ -230,6 +237,7 @@ def __init__(
event_to_attr=event_to_attr,
decollate=decollate,
to_kwargs=to_kwargs,
+ amp_kwargs=amp_kwargs,
)

self.network = network
@@ -269,7 +277,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
# execute forward computation
with self.mode(self.network):
if self.amp:
- with torch.cuda.amp.autocast():
+ with torch.cuda.amp.autocast(**engine.amp_kwargs): # type: ignore
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore
else:
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore
@@ -326,6 +334,8 @@ class EnsembleEvaluator(Evaluator):
default to `True`.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
"""

@@ -351,6 +361,7 @@ def __init__(
event_to_attr: Optional[dict] = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
+ amp_kwargs: Optional[Dict] = None,
) -> None:
super().__init__(
device=device,
@@ -370,6 +381,7 @@ def __init__(
event_to_attr=event_to_attr,
decollate=decollate,
to_kwargs=to_kwargs,
+ amp_kwargs=amp_kwargs,
)

self.networks = ensure_tuple(networks)
@@ -417,7 +429,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
for idx, network in enumerate(self.networks):
with self.mode(network):
if self.amp:
- with torch.cuda.amp.autocast():
+ with torch.cuda.amp.autocast(**engine.amp_kwargs): # type: ignore
if isinstance(engine.state.output, dict):
engine.state.output.update(
{self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}
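To make the new argument concrete, a minimal sketch of passing `amp_kwargs` to a `SupervisedEvaluator`; `model` and `val_loader` are assumed to exist, and the chosen dtype is illustrative only:

import torch
from monai.engines import SupervisedEvaluator

evaluator = SupervisedEvaluator(
    device=torch.device("cuda:0"),
    val_data_loader=val_loader,           # assumed: a prepared DataLoader
    network=model,                        # assumed: a torch.nn.Module
    amp=True,
    amp_kwargs={"dtype": torch.float16},  # forwarded verbatim to torch.cuda.amp.autocast()
)
evaluator.run()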
10 changes: 9 additions & 1 deletion monai/engines/trainer.py
@@ -107,6 +107,8 @@ class SupervisedTrainer(Trainer):
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
"""

@@ -134,6 +136,7 @@ def __init__(
decollate: bool = True,
optim_set_to_none: bool = False,
to_kwargs: Optional[Dict] = None,
+ amp_kwargs: Optional[Dict] = None,
) -> None:
super().__init__(
device=device,
@@ -153,6 +156,7 @@ def __init__(
event_to_attr=event_to_attr,
decollate=decollate,
to_kwargs=to_kwargs,
+ amp_kwargs=amp_kwargs,
)

self.network = network
@@ -202,7 +206,7 @@ def _compute_pred_loss():
self.optimizer.zero_grad(set_to_none=self.optim_set_to_none)

if self.amp and self.scaler is not None:
- with torch.cuda.amp.autocast():
+ with torch.cuda.amp.autocast(**engine.amp_kwargs): # type: ignore
_compute_pred_loss()
self.scaler.scale(engine.state.output[Keys.LOSS]).backward() # type: ignore
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
@@ -275,6 +279,8 @@ class GanTrainer(Trainer):
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
"""

@@ -307,6 +313,7 @@ def __init__(
decollate: bool = True,
optim_set_to_none: bool = False,
to_kwargs: Optional[Dict] = None,
+ amp_kwargs: Optional[Dict] = None,
):
if not isinstance(train_data_loader, DataLoader):
raise ValueError("train_data_loader must be PyTorch DataLoader.")
@@ -327,6 +334,7 @@ def __init__(
postprocessing=postprocessing,
decollate=decollate,
to_kwargs=to_kwargs,
+ amp_kwargs=amp_kwargs,
)
self.g_network = g_network
self.g_optimizer = g_optimizer
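The trainer forwards the same kwargs into autocast around its forward/loss computation; a minimal sketch, assuming `model`, `optimizer`, `loss_function`, and `train_loader` are already defined:

import torch
from monai.engines import SupervisedTrainer

trainer = SupervisedTrainer(
    device=torch.device("cuda:0"),
    max_epochs=5,
    train_data_loader=train_loader,       # assumed: a prepared DataLoader
    network=model,                        # assumed: a torch.nn.Module
    optimizer=optimizer,                  # assumed: a torch.optim.Optimizer
    loss_function=loss_function,          # assumed: a loss callable
    amp=True,
    amp_kwargs={"dtype": torch.float16},  # used as torch.cuda.amp.autocast(**amp_kwargs)
)
trainer.run()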
4 changes: 4 additions & 0 deletions monai/engines/workflow.py
@@ -96,6 +96,8 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona
default to `True`.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
Raises:
TypeError: When ``device`` is not a ``torch.Device``.
@@ -124,6 +126,7 @@ def __init__(
event_to_attr: Optional[dict] = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
+ amp_kwargs: Optional[Dict] = None,
) -> None:
if iteration_update is not None:
super().__init__(iteration_update)
@@ -170,6 +173,7 @@ def set_sampler_epoch(engine: Engine):
self.metric_cmp_fn = metric_cmp_fn
self.amp = amp
self.to_kwargs = {} if to_kwargs is None else to_kwargs
+ self.amp_kwargs = {} if amp_kwargs is None else amp_kwargs
self.scaler: Optional[torch.cuda.amp.GradScaler] = None

if event_names is None:
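The base Workflow only stores the dict (falling back to an empty one), so any engine that checks `self.amp` can splat it into autocast; a small illustration of that convention, with purely illustrative names:

import torch

def run_forward(engine, inputs, network):
    # Illustrative helper mirroring how the engines wrap their forward pass.
    if engine.amp:
        # engine.amp_kwargs defaults to {}, so this is a plain autocast() unless configured.
        with torch.cuda.amp.autocast(**engine.amp_kwargs):
            return network(inputs)
    return network(inputs)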
2 changes: 2 additions & 0 deletions tests/test_integration_workflows.py
@@ -149,6 +149,7 @@ def _forward_completed(self, engine):
val_handlers=val_handlers,
amp=bool(amp),
to_kwargs={"memory_format": torch.preserve_format},
+ amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32},
)

train_postprocessing = Compose(
@@ -204,6 +205,7 @@ def _model_completed(self, engine):
amp=bool(amp),
optim_set_to_none=True,
to_kwargs={"memory_format": torch.preserve_format},
+ amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32},
)
trainer.run()

