Skip to content

Commit

Permalink
4084 Add kwargs for Tensor.to() in engines (#4112)
Browse files Browse the repository at this point in the history
* [DLMED] add kwargs for to() API

Signed-off-by: Nic Ma <nma@nvidia.com>

* [MONAI] python code formatting

Signed-off-by: monai-bot <monai.miccai2019@gmail.com>

* [DLMED] fix typo

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] fix flake8

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <nma@nvidia.com>

Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
  • Loading branch information
Nic-Ma and monai-bot authored Apr 13, 2022
1 parent 1880d38 commit f8c2655
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 21 deletions.
20 changes: 18 additions & 2 deletions monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class Evaluator(Workflow):
decollate: whether to decollate the batch-first data to a list of data after model computation,
recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
defaults to `True`.
to_kwargs: dict of additional args for the `prepare_batch` API when converting the input data,
except for `device` and `non_blocking`.
"""

Expand All @@ -95,6 +97,7 @@ def __init__(
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
) -> None:
super().__init__(
device=device,
Expand All @@ -113,6 +116,7 @@ def __init__(
event_names=event_names,
event_to_attr=event_to_attr,
decollate=decollate,
to_kwargs=to_kwargs,
)
mode = look_up_option(mode, ForwardMode)
if mode == ForwardMode.EVAL:
Expand Down Expand Up @@ -181,6 +185,8 @@ class SupervisedEvaluator(Evaluator):
decollate: whether to decollate the batch-first data to a list of data after model computation,
recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
defaults to `True`.
to_kwargs: dict of additional args for the `prepare_batch` API when converting the input data,
except for `device` and `non_blocking`.
"""

Expand All @@ -204,6 +210,7 @@ def __init__(
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
) -> None:
super().__init__(
device=device,
Expand All @@ -222,6 +229,7 @@ def __init__(
event_names=event_names,
event_to_attr=event_to_attr,
decollate=decollate,
to_kwargs=to_kwargs,
)

self.network = network
Expand All @@ -245,7 +253,9 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
"""
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore
batch = self.prepare_batch(
batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore
)
if len(batch) == 2:
inputs, targets = batch
args: Tuple = ()
Expand Down Expand Up @@ -314,6 +324,8 @@ class EnsembleEvaluator(Evaluator):
decollate: whether to decollate the batch-first data to a list of data after model computation,
recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
defaults to `True`.
to_kwargs: dict of additional args for the `prepare_batch` API when converting the input data,
except for `device` and `non_blocking`.
"""

Expand All @@ -338,6 +350,7 @@ def __init__(
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
) -> None:
super().__init__(
device=device,
Expand All @@ -356,6 +369,7 @@ def __init__(
event_names=event_names,
event_to_attr=event_to_attr,
decollate=decollate,
to_kwargs=to_kwargs,
)

self.networks = ensure_tuple(networks)
Expand Down Expand Up @@ -387,7 +401,9 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
"""
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore
batch = self.prepare_batch(
batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore
)
if len(batch) == 2:
inputs, targets = batch
args: Tuple = ()
Expand Down
18 changes: 16 additions & 2 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class SupervisedTrainer(Trainer):
defaults to `True`.
optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
to_kwargs: dict of additional args for the `prepare_batch` API when converting the input data,
except for `device` and `non_blocking`.
"""

Expand All @@ -131,6 +133,7 @@ def __init__(
event_to_attr: Optional[dict] = None,
decollate: bool = True,
optim_set_to_none: bool = False,
to_kwargs: Optional[Dict] = None,
) -> None:
super().__init__(
device=device,
Expand All @@ -149,6 +152,7 @@ def __init__(
event_names=event_names,
event_to_attr=event_to_attr,
decollate=decollate,
to_kwargs=to_kwargs,
)

self.network = network
Expand Down Expand Up @@ -176,7 +180,9 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
"""
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore
batch = self.prepare_batch(
batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore
)
if len(batch) == 2:
inputs, targets = batch
args: Tuple = ()
Expand Down Expand Up @@ -267,6 +273,8 @@ class GanTrainer(Trainer):
defaults to `True`.
optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
to_kwargs: dict of additional args for the `prepare_batch` API when converting the input data,
except for `device` and `non_blocking`.
"""

Expand Down Expand Up @@ -298,6 +306,7 @@ def __init__(
train_handlers: Optional[Sequence] = None,
decollate: bool = True,
optim_set_to_none: bool = False,
to_kwargs: Optional[Dict] = None,
):
if not isinstance(train_data_loader, DataLoader):
raise ValueError("train_data_loader must be PyTorch DataLoader.")
Expand All @@ -317,6 +326,7 @@ def __init__(
handlers=train_handlers,
postprocessing=postprocessing,
decollate=decollate,
to_kwargs=to_kwargs,
)
self.g_network = g_network
self.g_optimizer = g_optimizer
Expand Down Expand Up @@ -349,13 +359,16 @@ def _iteration(
if batchdata is None:
raise ValueError("must provide batch data for current iteration.")

d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore
d_input = self.prepare_batch(
batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore
)
batch_size = self.data_loader.batch_size # type: ignore
g_input = self.g_prepare_batch(
num_latents=batch_size,
latent_size=self.latent_shape,
device=engine.state.device, # type: ignore
non_blocking=engine.non_blocking, # type: ignore
**engine.to_kwargs, # type: ignore
)
g_output = self.g_inferer(g_input, self.g_network)

Expand All @@ -375,6 +388,7 @@ def _iteration(
latent_size=self.latent_shape,
device=engine.state.device, # type: ignore
non_blocking=engine.non_blocking, # type: ignore
**engine.to_kwargs, # type: ignore
)
g_output = self.g_inferer(g_input, self.g_network)
self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
Expand Down
62 changes: 45 additions & 17 deletions monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,16 @@ def get_devices_spec(devices: Optional[Sequence[torch.device]] = None) -> List[t


def default_prepare_batch(
batchdata: Dict[str, torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False
batchdata: Dict[str, torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
**kwargs,
) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
"""
Default function to prepare the data for current iteration.
Refer to ignite: https://pytorch.org/ignite/v0.4.5/generated/ignite.engine.create_supervised_trainer.html
#ignite.engine.create_supervised_trainer.
Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
`kwargs` supports other args for `Tensor.to()` API.
Returns:
image, label(optional).
Expand All @@ -119,18 +123,21 @@ def default_prepare_batch(
raise AssertionError("default prepare_batch expects dictionary input data.")
if isinstance(batchdata.get(CommonKeys.LABEL), torch.Tensor):
return (
batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking),
batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking),
batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking, **kwargs),
batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking, **kwargs),
)
if GanKeys.REALS in batchdata:
return batchdata[GanKeys.REALS].to(device=device, non_blocking=non_blocking)
return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), None
return batchdata[GanKeys.REALS].to(device=device, non_blocking=non_blocking, **kwargs)
return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking, **kwargs), None


class PrepareBatch(ABC):
"""
Interface of customized prepare_batch in the trainer or evaluator workflows.
It takes the data of current batch, target device and non_blocking flag as input.
Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
`kwargs` supports other args for `Tensor.to()` API.
"""

Expand All @@ -140,6 +147,7 @@ def __call__(
batchdata: Dict[str, torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
**kwargs,
):
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

Expand All @@ -155,8 +163,15 @@ def __call__(
batchdata: Dict[str, torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
**kwargs,
):
return default_prepare_batch(batchdata, device, non_blocking)
"""
Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
`kwargs` supports other args for `Tensor.to()` API.
"""
return default_prepare_batch(batchdata, device, non_blocking, **kwargs)


class PrepareBatchExtraInput(PrepareBatch):
Expand All @@ -181,29 +196,42 @@ def __call__(
batchdata: Dict[str, torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
**kwargs,
):
image, label = default_prepare_batch(batchdata, device, non_blocking)
args = list()
kwargs = dict()
"""
Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
`kwargs` supports other args for `Tensor.to()` API.
"""
image, label = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
args_ = list()
kwargs_ = dict()

def _get_data(key: str):
data = batchdata[key]
return data.to(device=device, non_blocking=non_blocking) if isinstance(data, torch.Tensor) else data
return (
data.to(device=device, non_blocking=non_blocking, **kwargs) if isinstance(data, torch.Tensor) else data
)

if isinstance(self.extra_keys, (str, list, tuple)):
for k in ensure_tuple(self.extra_keys):
args.append(_get_data(k))
args_.append(_get_data(k))
elif isinstance(self.extra_keys, dict):
for k, v in self.extra_keys.items():
kwargs.update({k: _get_data(v)})
kwargs_.update({k: _get_data(v)})

return image, label, tuple(args), kwargs
return image, label, tuple(args_), kwargs_


def default_make_latent(
num_latents: int, latent_size: int, device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False
num_latents: int,
latent_size: int,
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
**kwargs,
) -> torch.Tensor:
return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking)
return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking, **kwargs)


def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., Dict]):
Expand Down
4 changes: 4 additions & 0 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona
decollate: whether to decollate the batch-first data to a list of data after model computation,
recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
default to `True`.
to_kwargs: dict of additional args for the `prepare_batch` API when converting the input data,
except for `device` and `non_blocking`.
Raises:
TypeError: When ``device`` is not a ``torch.Device``.
Expand Down Expand Up @@ -121,6 +123,7 @@ def __init__(
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
) -> None:
if iteration_update is not None:
super().__init__(iteration_update)
Expand Down Expand Up @@ -166,6 +169,7 @@ def set_sampler_epoch(engine: Engine):
self.prepare_batch = prepare_batch
self.metric_cmp_fn = metric_cmp_fn
self.amp = amp
self.to_kwargs = {} if to_kwargs is None else to_kwargs
self.scaler: Optional[torch.cuda.amp.GradScaler] = None

if event_names is None:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_integration_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def _forward_completed(self, engine):
metric_cmp_fn=lambda cur, prev: cur >= prev, # if greater or equal, treat as new best metric
val_handlers=val_handlers,
amp=bool(amp),
to_kwargs={"memory_format": torch.preserve_format},
)

train_postprocessing = Compose(
Expand Down Expand Up @@ -202,6 +203,7 @@ def _model_completed(self, engine):
train_handlers=train_handlers,
amp=bool(amp),
optim_set_to_none=True,
to_kwargs={"memory_format": torch.preserve_format},
)
trainer.run()

Expand Down
1 change: 1 addition & 0 deletions tests/test_integration_workflows_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def generator_loss(gen_images):
latent_shape=latent_size,
key_train_metric=key_train_metric,
train_handlers=train_handlers,
to_kwargs={"memory_format": torch.preserve_format, "dtype": torch.float32},
)
trainer.run()

Expand Down

0 comments on commit f8c2655

Please sign in to comment.