From f8c265531c29e9add7ecf38f887474d2b815447d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 13 Apr 2022 21:01:31 +0800 Subject: [PATCH] 4084 Add kwargs for `Tensor.to()` in engines (#4112) * [DLMED] add kwargs for to() API Signed-off-by: Nic Ma * [MONAI] python code formatting Signed-off-by: monai-bot * [DLMED] fix typo Signed-off-by: Nic Ma * [DLMED] fix flake8 Signed-off-by: Nic Ma * [DLMED] update according to comments Signed-off-by: Nic Ma Co-authored-by: monai-bot --- monai/engines/evaluator.py | 20 +++++++- monai/engines/trainer.py | 18 ++++++- monai/engines/utils.py | 62 ++++++++++++++++++------- monai/engines/workflow.py | 4 ++ tests/test_integration_workflows.py | 2 + tests/test_integration_workflows_gan.py | 1 + 6 files changed, 86 insertions(+), 21 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index c3e8c456b7..f9dab35450 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -74,6 +74,8 @@ class Evaluator(Workflow): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. """ @@ -95,6 +97,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, + to_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -113,6 +116,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, + to_kwargs=to_kwargs, ) mode = look_up_option(mode, ForwardMode) if mode == ForwardMode.EVAL: @@ -181,6 +185,8 @@ class SupervisedEvaluator(Evaluator): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. """ @@ -204,6 +210,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, + to_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -222,6 +229,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, + to_kwargs=to_kwargs, ) self.network = network @@ -245,7 +253,9 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore + batch = self.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + ) if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -314,6 +324,8 @@ class EnsembleEvaluator(Evaluator): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. """ @@ -338,6 +350,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, + to_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -356,6 +369,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, + to_kwargs=to_kwargs, ) self.networks = ensure_tuple(networks) @@ -387,7 +401,9 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore + batch = self.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + ) if len(batch) == 2: inputs, targets = batch args: Tuple = () diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index a58387a5ef..16c50d4fa2 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -105,6 +105,8 @@ class SupervisedTrainer(Trainer): default to `True`. optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. """ @@ -131,6 +133,7 @@ def __init__( event_to_attr: Optional[dict] = None, decollate: bool = True, optim_set_to_none: bool = False, + to_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -149,6 +152,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, + to_kwargs=to_kwargs, ) self.network = network @@ -176,7 +180,9 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore + batch = self.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + ) if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -267,6 +273,8 @@ class GanTrainer(Trainer): default to `True`. optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. """ @@ -298,6 +306,7 @@ def __init__( train_handlers: Optional[Sequence] = None, decollate: bool = True, optim_set_to_none: bool = False, + to_kwargs: Optional[Dict] = None, ): if not isinstance(train_data_loader, DataLoader): raise ValueError("train_data_loader must be PyTorch DataLoader.") @@ -317,6 +326,7 @@ def __init__( handlers=train_handlers, postprocessing=postprocessing, decollate=decollate, + to_kwargs=to_kwargs, ) self.g_network = g_network self.g_optimizer = g_optimizer @@ -349,13 +359,16 @@ def _iteration( if batchdata is None: raise ValueError("must provide batch data for current iteration.") - d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore + d_input = self.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + ) batch_size = self.data_loader.batch_size # type: ignore g_input = self.g_prepare_batch( num_latents=batch_size, latent_size=self.latent_shape, device=engine.state.device, # type: ignore non_blocking=engine.non_blocking, # type: ignore + **engine.to_kwargs, # type: ignore ) g_output = self.g_inferer(g_input, self.g_network) @@ -375,6 +388,7 @@ def _iteration( latent_size=self.latent_shape, device=engine.state.device, # type: ignore non_blocking=engine.non_blocking, # type: ignore + **engine.to_kwargs, # type: ignore ) g_output = self.g_inferer(g_input, self.g_network) self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 726dfc8e98..8f3a57beda 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -104,12 +104,16 @@ def get_devices_spec(devices: Optional[Sequence[torch.device]] = None) -> List[t def default_prepare_batch( - batchdata: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False + batchdata: Dict[str, torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + **kwargs, ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: """ Default function to prepare the data for current iteration. - Refer to ignite: https://pytorch.org/ignite/v0.4.5/generated/ignite.engine.create_supervised_trainer.html - #ignite.engine.create_supervised_trainer. + Args `batchdata`, `device`, `non_blocking` refer to the ignite API: + https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. + `kwargs` supports other args for `Tensor.to()` API. Returns: image, label(optional). @@ -119,18 +123,21 @@ def default_prepare_batch( raise AssertionError("default prepare_batch expects dictionary input data.") if isinstance(batchdata.get(CommonKeys.LABEL), torch.Tensor): return ( - batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), - batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking), + batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking, **kwargs), + batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking, **kwargs), ) if GanKeys.REALS in batchdata: - return batchdata[GanKeys.REALS].to(device=device, non_blocking=non_blocking) - return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), None + return batchdata[GanKeys.REALS].to(device=device, non_blocking=non_blocking, **kwargs) + return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking, **kwargs), None class PrepareBatch(ABC): """ Interface of customized prepare_batch in the trainer or evaluator workflows. It takes the data of current batch, target device and non_blocking flag as input. + Args `batchdata`, `device`, `non_blocking` refer to the ignite API: + https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. + `kwargs` supports other args for `Tensor.to()` API. """ @@ -140,6 +147,7 @@ def __call__( batchdata: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False, + **kwargs, ): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @@ -155,8 +163,15 @@ def __call__( batchdata: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False, + **kwargs, ): - return default_prepare_batch(batchdata, device, non_blocking) + """ + Args `batchdata`, `device`, `non_blocking` refer to the ignite API: + https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. + `kwargs` supports other args for `Tensor.to()` API. + + """ + return default_prepare_batch(batchdata, device, non_blocking, **kwargs) class PrepareBatchExtraInput(PrepareBatch): @@ -181,29 +196,42 @@ def __call__( batchdata: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False, + **kwargs, ): - image, label = default_prepare_batch(batchdata, device, non_blocking) - args = list() - kwargs = dict() + """ + Args `batchdata`, `device`, `non_blocking` refer to the ignite API: + https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. + `kwargs` supports other args for `Tensor.to()` API. + + """ + image, label = default_prepare_batch(batchdata, device, non_blocking, **kwargs) + args_ = list() + kwargs_ = dict() def _get_data(key: str): data = batchdata[key] - return data.to(device=device, non_blocking=non_blocking) if isinstance(data, torch.Tensor) else data + return ( + data.to(device=device, non_blocking=non_blocking, **kwargs) if isinstance(data, torch.Tensor) else data + ) if isinstance(self.extra_keys, (str, list, tuple)): for k in ensure_tuple(self.extra_keys): - args.append(_get_data(k)) + args_.append(_get_data(k)) elif isinstance(self.extra_keys, dict): for k, v in self.extra_keys.items(): - kwargs.update({k: _get_data(v)}) + kwargs_.update({k: _get_data(v)}) - return image, label, tuple(args), kwargs + return image, label, tuple(args_), kwargs_ def default_make_latent( - num_latents: int, latent_size: int, device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False + num_latents: int, + latent_size: int, + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + **kwargs, ) -> torch.Tensor: - return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking) + return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking, **kwargs) def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., Dict]): diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 65bb313e53..4ea0a69d55 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -94,6 +94,8 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. Raises: TypeError: When ``device`` is not a ``torch.Device``. @@ -121,6 +123,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, + to_kwargs: Optional[Dict] = None, ) -> None: if iteration_update is not None: super().__init__(iteration_update) @@ -166,6 +169,7 @@ def set_sampler_epoch(engine: Engine): self.prepare_batch = prepare_batch self.metric_cmp_fn = metric_cmp_fn self.amp = amp + self.to_kwargs = {} if to_kwargs is None else to_kwargs self.scaler: Optional[torch.cuda.amp.GradScaler] = None if event_names is None: diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 2e515772d3..fafdf43522 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -148,6 +148,7 @@ def _forward_completed(self, engine): metric_cmp_fn=lambda cur, prev: cur >= prev, # if greater or equal, treat as new best metric val_handlers=val_handlers, amp=bool(amp), + to_kwargs={"memory_format": torch.preserve_format}, ) train_postprocessing = Compose( @@ -202,6 +203,7 @@ def _model_completed(self, engine): train_handlers=train_handlers, amp=bool(amp), optim_set_to_none=True, + to_kwargs={"memory_format": torch.preserve_format}, ) trainer.run() diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py index c9306b349f..f65a30450a 100644 --- a/tests/test_integration_workflows_gan.py +++ b/tests/test_integration_workflows_gan.py @@ -117,6 +117,7 @@ def generator_loss(gen_images): latent_shape=latent_size, key_train_metric=key_train_metric, train_handlers=train_handlers, + to_kwargs={"memory_format": torch.preserve_format, "dtype": torch.float32}, ) trainer.run()