From daf66159bc0a777533566bb2e23b41a6095d9ae8 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 12 Apr 2022 20:26:00 +0800 Subject: [PATCH 1/5] [DLMED] add kwargs for to() API Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 15 ++++++- monai/engines/trainer.py | 14 ++++++- monai/engines/utils.py | 52 +++++++++++++++++++------ monai/engines/workflow.py | 4 ++ tests/test_integration_workflows.py | 2 + tests/test_integration_workflows_gan.py | 1 + 6 files changed, 72 insertions(+), 16 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index c3e8c456b7..0e9e2ddd31 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -74,6 +74,8 @@ class Evaluator(Workflow): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + to_kwargs: other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. """ @@ -95,6 +97,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, + **to_kwargs, ) -> None: super().__init__( device=device, @@ -181,6 +184,8 @@ class SupervisedEvaluator(Evaluator): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + to_kwargs: other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. """ @@ -204,6 +209,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, + **to_kwargs, ) -> None: super().__init__( device=device, @@ -222,6 +228,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, + **to_kwargs, ) self.network = network @@ -245,7 +252,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -314,6 +321,8 @@ class EnsembleEvaluator(Evaluator): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + to_kwargs: other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. """ @@ -338,6 +347,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, + **to_kwargs, ) -> None: super().__init__( device=device, @@ -356,6 +366,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, + **to_kwargs, ) self.networks = ensure_tuple(networks) @@ -387,7 +398,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 774e535e7f..9faf250f10 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -105,6 +105,8 @@ class SupervisedTrainer(Trainer): default to `True`. optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. + to_kwargs: other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. """ @@ -131,6 +133,7 @@ def __init__( event_to_attr: Optional[dict] = None, decollate: bool = True, optim_set_to_none: bool = False, + **to_kwargs, ) -> None: super().__init__( device=device, @@ -149,6 +152,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, + **to_kwargs, ) self.network = network @@ -176,7 +180,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -271,6 +275,8 @@ class GanTrainer(Trainer): default to `True`. optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. + to_kwargs: other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. """ @@ -302,6 +308,7 @@ def __init__( train_handlers: Optional[Sequence] = None, decollate: bool = True, optim_set_to_none: bool = False, + **to_kwargs, ): if not isinstance(train_data_loader, DataLoader): raise ValueError("train_data_loader must be PyTorch DataLoader.") @@ -321,6 +328,7 @@ def __init__( handlers=train_handlers, postprocessing=postprocessing, decollate=decollate, + **to_kwargs, ) self.g_network = g_network self.g_optimizer = g_optimizer @@ -353,13 +361,14 @@ def _iteration( if batchdata is None: raise ValueError("must provide batch data for current iteration.") - d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore + d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) # type: ignore batch_size = self.data_loader.batch_size # type: ignore g_input = self.g_prepare_batch( num_latents=batch_size, latent_size=self.latent_shape, device=engine.state.device, # type: ignore non_blocking=engine.non_blocking, # type: ignore + **engine.to_kwargs, # type: ignore ) g_output = self.g_inferer(g_input, self.g_network) @@ -383,6 +392,7 @@ def _iteration( latent_size=self.latent_shape, device=engine.state.device, # type: ignore non_blocking=engine.non_blocking, # type: ignore + **engine.to_kwargs, # type: ignore ) g_output = self.g_inferer(g_input, self.g_network) if not pytorch_after(1, 7): diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 726dfc8e98..6d0055970d 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -104,12 +104,16 @@ def get_devices_spec(devices: Optional[Sequence[torch.device]] = None) -> List[t def default_prepare_batch( - batchdata: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False + batchdata: Dict[str, torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + **kwargs, ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: """ Default function to prepare the data for current iteration. - Refer to ignite: https://pytorch.org/ignite/v0.4.5/generated/ignite.engine.create_supervised_trainer.html - #ignite.engine.create_supervised_trainer. + Args `batchdata`, `device`, `non_blocking` refer to the ignite API: + https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. + `kwargs` supports other args for `Tensor.to()` API. Returns: image, label(optional). @@ -119,18 +123,21 @@ def default_prepare_batch( raise AssertionError("default prepare_batch expects dictionary input data.") if isinstance(batchdata.get(CommonKeys.LABEL), torch.Tensor): return ( - batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), - batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking), + batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking, **kwargs), + batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking, **kwargs), ) if GanKeys.REALS in batchdata: - return batchdata[GanKeys.REALS].to(device=device, non_blocking=non_blocking) - return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), None + return batchdata[GanKeys.REALS].to(device=device, non_blocking=non_blocking, **kwargs) + return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking, **kwargs), None class PrepareBatch(ABC): """ Interface of customized prepare_batch in the trainer or evaluator workflows. It takes the data of current batch, target device and non_blocking flag as input. + Args `batchdata`, `device`, `non_blocking` refer to the ignite API: + https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. + `kwargs` supports other args for `Tensor.to()` API. """ @@ -140,6 +147,7 @@ def __call__( batchdata: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False, + **kwargs, ): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @@ -155,8 +163,15 @@ def __call__( batchdata: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False, + **kwargs, ): - return default_prepare_batch(batchdata, device, non_blocking) + """ + Args `batchdata`, `device`, `non_blocking` refer to the ignite API: + https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. + `kwargs` supports other args for `Tensor.to()` API. + + """ + return default_prepare_batch(batchdata, device, non_blocking, **kwargs) class PrepareBatchExtraInput(PrepareBatch): @@ -181,14 +196,23 @@ def __call__( batchdata: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False, + **kwargs, ): - image, label = default_prepare_batch(batchdata, device, non_blocking) + """ + Args `batchdata`, `device`, `non_blocking` refer to the ignite API: + https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. + `kwargs` supports other args for `Tensor.to()` API. + + """ + image, label = super().__call__(batchdata, device, non_blocking, **kwargs) args = list() kwargs = dict() def _get_data(key: str): data = batchdata[key] - return data.to(device=device, non_blocking=non_blocking) if isinstance(data, torch.Tensor) else data + return ( + data.to(device=device, non_blocking=non_blocking, **kwargs) if isinstance(data, torch.Tensor) else data + ) if isinstance(self.extra_keys, (str, list, tuple)): for k in ensure_tuple(self.extra_keys): @@ -201,9 +225,13 @@ def _get_data(key: str): def default_make_latent( - num_latents: int, latent_size: int, device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False + num_latents: int, + latent_size: int, + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + **kwargs, ) -> torch.Tensor: - return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking) + return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking, **kwargs) def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., Dict]): diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 65bb313e53..7b6e0ad79d 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -94,6 +94,8 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. + to_kwargs: other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. Raises: TypeError: When ``device`` is not a ``torch.Device``. @@ -121,6 +123,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, + **to_kwargs, ) -> None: if iteration_update is not None: super().__init__(iteration_update) @@ -166,6 +169,7 @@ def set_sampler_epoch(engine: Engine): self.prepare_batch = prepare_batch self.metric_cmp_fn = metric_cmp_fn self.amp = amp + self.to_kwargs = to_kwargs self.scaler: Optional[torch.cuda.amp.GradScaler] = None if event_names is None: diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 2e515772d3..6748e73590 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -148,6 +148,7 @@ def _forward_completed(self, engine): metric_cmp_fn=lambda cur, prev: cur >= prev, # if greater or equal, treat as new best metric val_handlers=val_handlers, amp=bool(amp), + memory_format=torch.preserve_format, ) train_postprocessing = Compose( @@ -202,6 +203,7 @@ def _model_completed(self, engine): train_handlers=train_handlers, amp=bool(amp), optim_set_to_none=True, + memory_format=torch.preserve_format, ) trainer.run() diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py index c9306b349f..8af55196e1 100644 --- a/tests/test_integration_workflows_gan.py +++ b/tests/test_integration_workflows_gan.py @@ -117,6 +117,7 @@ def generator_loss(gen_images): latent_shape=latent_size, key_train_metric=key_train_metric, train_handlers=train_handlers, + memory_format=torch.preserve_format, ) trainer.run() From 37371bf03d70cabaf9c35c05b50a676e5dd77f75 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 12 Apr 2022 12:38:07 +0000 Subject: [PATCH 2/5] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/engines/evaluator.py | 8 ++++++-- monai/engines/trainer.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 0e9e2ddd31..56a37c2b24 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -252,7 +252,9 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) # type: ignore + batch = self.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs + ) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -398,7 +400,9 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) # type: ignore + batch = self.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs + ) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 5c125bbdab..0d63876ed4 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -180,7 +180,9 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) # type: ignore + batch = self.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs + ) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -357,7 +359,9 @@ def _iteration( if batchdata is None: raise ValueError("must provide batch data for current iteration.") - d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) # type: ignore + d_input = self.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs + ) # type: ignore batch_size = self.data_loader.batch_size # type: ignore g_input = self.g_prepare_batch( num_latents=batch_size, From 0762e8000e9e21067e85dc0047ebf6545e09139c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 12 Apr 2022 22:44:10 +0800 Subject: [PATCH 3/5] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 1 + monai/engines/utils.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 0e9e2ddd31..23e31ce950 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -116,6 +116,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, + **to_kwargs, ) mode = look_up_option(mode, ForwardMode) if mode == ForwardMode.EVAL: diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 6d0055970d..8f3a57beda 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -204,9 +204,9 @@ def __call__( `kwargs` supports other args for `Tensor.to()` API. """ - image, label = super().__call__(batchdata, device, non_blocking, **kwargs) - args = list() - kwargs = dict() + image, label = default_prepare_batch(batchdata, device, non_blocking, **kwargs) + args_ = list() + kwargs_ = dict() def _get_data(key: str): data = batchdata[key] @@ -216,12 +216,12 @@ def _get_data(key: str): if isinstance(self.extra_keys, (str, list, tuple)): for k in ensure_tuple(self.extra_keys): - args.append(_get_data(k)) + args_.append(_get_data(k)) elif isinstance(self.extra_keys, dict): for k, v in self.extra_keys.items(): - kwargs.update({k: _get_data(v)}) + kwargs_.update({k: _get_data(v)}) - return image, label, tuple(args), kwargs + return image, label, tuple(args_), kwargs_ def default_make_latent( From e126877a07d8f505c6ccc5ef20a96ae697b0025a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 12 Apr 2022 23:54:08 +0800 Subject: [PATCH 4/5] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 8 ++++---- monai/engines/trainer.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 48927ccc8f..caeef19c1c 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -254,8 +254,8 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): if batchdata is None: raise ValueError("Must provide batch data for current iteration.") batch = self.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs - ) # type: ignore + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + ) if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -402,8 +402,8 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): if batchdata is None: raise ValueError("Must provide batch data for current iteration.") batch = self.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs - ) # type: ignore + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + ) if len(batch) == 2: inputs, targets = batch args: Tuple = () diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 0d63876ed4..03e383b85c 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -181,8 +181,8 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): if batchdata is None: raise ValueError("Must provide batch data for current iteration.") batch = self.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs - ) # type: ignore + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + ) if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -360,8 +360,8 @@ def _iteration( raise ValueError("must provide batch data for current iteration.") d_input = self.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs - ) # type: ignore + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + ) batch_size = self.data_loader.batch_size # type: ignore g_input = self.g_prepare_batch( num_latents=batch_size, From 461b209a96c02ef947636b13109b3a516d786964 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 13 Apr 2022 20:19:27 +0800 Subject: [PATCH 5/5] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 18 +++++++++--------- monai/engines/trainer.py | 12 ++++++------ monai/engines/workflow.py | 6 +++--- tests/test_integration_workflows.py | 4 ++-- tests/test_integration_workflows_gan.py | 2 +- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index caeef19c1c..f9dab35450 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -74,7 +74,7 @@ class Evaluator(Workflow): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. - to_kwargs: other args for `prepare_batch` API when converting the input data, except for + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. """ @@ -97,7 +97,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, - **to_kwargs, + to_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -116,7 +116,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, - **to_kwargs, + to_kwargs=to_kwargs, ) mode = look_up_option(mode, ForwardMode) if mode == ForwardMode.EVAL: @@ -185,7 +185,7 @@ class SupervisedEvaluator(Evaluator): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. - to_kwargs: other args for `prepare_batch` API when converting the input data, except for + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. """ @@ -210,7 +210,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, - **to_kwargs, + to_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -229,7 +229,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, - **to_kwargs, + to_kwargs=to_kwargs, ) self.network = network @@ -324,7 +324,7 @@ class EnsembleEvaluator(Evaluator): decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. - to_kwargs: other args for `prepare_batch` API when converting the input data, except for + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. """ @@ -350,7 +350,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, - **to_kwargs, + to_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -369,7 +369,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, - **to_kwargs, + to_kwargs=to_kwargs, ) self.networks = ensure_tuple(networks) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 03e383b85c..16c50d4fa2 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -105,7 +105,7 @@ class SupervisedTrainer(Trainer): default to `True`. optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. - to_kwargs: other args for `prepare_batch` API when converting the input data, except for + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. """ @@ -133,7 +133,7 @@ def __init__( event_to_attr: Optional[dict] = None, decollate: bool = True, optim_set_to_none: bool = False, - **to_kwargs, + to_kwargs: Optional[Dict] = None, ) -> None: super().__init__( device=device, @@ -152,7 +152,7 @@ def __init__( event_names=event_names, event_to_attr=event_to_attr, decollate=decollate, - **to_kwargs, + to_kwargs=to_kwargs, ) self.network = network @@ -273,7 +273,7 @@ class GanTrainer(Trainer): default to `True`. optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. - to_kwargs: other args for `prepare_batch` API when converting the input data, except for + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. """ @@ -306,7 +306,7 @@ def __init__( train_handlers: Optional[Sequence] = None, decollate: bool = True, optim_set_to_none: bool = False, - **to_kwargs, + to_kwargs: Optional[Dict] = None, ): if not isinstance(train_data_loader, DataLoader): raise ValueError("train_data_loader must be PyTorch DataLoader.") @@ -326,7 +326,7 @@ def __init__( handlers=train_handlers, postprocessing=postprocessing, decollate=decollate, - **to_kwargs, + to_kwargs=to_kwargs, ) self.g_network = g_network self.g_optimizer = g_optimizer diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 7b6e0ad79d..4ea0a69d55 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -94,7 +94,7 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. - to_kwargs: other args for `prepare_batch` API when converting the input data, except for + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. Raises: @@ -123,7 +123,7 @@ def __init__( event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, decollate: bool = True, - **to_kwargs, + to_kwargs: Optional[Dict] = None, ) -> None: if iteration_update is not None: super().__init__(iteration_update) @@ -169,7 +169,7 @@ def set_sampler_epoch(engine: Engine): self.prepare_batch = prepare_batch self.metric_cmp_fn = metric_cmp_fn self.amp = amp - self.to_kwargs = to_kwargs + self.to_kwargs = {} if to_kwargs is None else to_kwargs self.scaler: Optional[torch.cuda.amp.GradScaler] = None if event_names is None: diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 6748e73590..fafdf43522 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -148,7 +148,7 @@ def _forward_completed(self, engine): metric_cmp_fn=lambda cur, prev: cur >= prev, # if greater or equal, treat as new best metric val_handlers=val_handlers, amp=bool(amp), - memory_format=torch.preserve_format, + to_kwargs={"memory_format": torch.preserve_format}, ) train_postprocessing = Compose( @@ -203,7 +203,7 @@ def _model_completed(self, engine): train_handlers=train_handlers, amp=bool(amp), optim_set_to_none=True, - memory_format=torch.preserve_format, + to_kwargs={"memory_format": torch.preserve_format}, ) trainer.run() diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py index 8af55196e1..f65a30450a 100644 --- a/tests/test_integration_workflows_gan.py +++ b/tests/test_integration_workflows_gan.py @@ -117,7 +117,7 @@ def generator_loss(gen_images): latent_shape=latent_size, key_train_metric=key_train_metric, train_handlers=train_handlers, - memory_format=torch.preserve_format, + to_kwargs={"memory_format": torch.preserve_format, "dtype": torch.float32}, ) trainer.run()