From 3b26f256d190c592111e613866eff8f79183d241 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 7 Dec 2021 00:24:01 +0800 Subject: [PATCH 1/4] [DLMED] add BaseWorkflow Signed-off-by: Nic Ma --- docs/source/engines.rst | 10 ++++++---- monai/engines/__init__.py | 1 + monai/engines/workflow.py | 16 ++++++++++++++-- monai/networks/blocks/activation.py | 2 ++ 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/docs/source/engines.rst b/docs/source/engines.rst index cc0ec3c659..90c7be2a1e 100644 --- a/docs/source/engines.rst +++ b/docs/source/engines.rst @@ -15,16 +15,18 @@ Multi-GPU data parallel Workflows --------- -.. automodule:: monai.engines.workflow -.. currentmodule:: monai.engines.workflow +.. currentmodule:: monai.engines + +`BaseWorkflow` +~~~~~~~~~~~~~~ +.. autoclass:: BaseWorkflow + :members: `Workflow` ~~~~~~~~~~ .. autoclass:: Workflow :members: -.. currentmodule:: monai.engines - `Trainer` ~~~~~~~~~ .. autoclass:: Trainer diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index d04401829f..f24bc0fc37 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -24,3 +24,4 @@ engine_apply_transform, get_devices_spec, ) +from .workflow import BaseWorkflow, Workflow diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 48e2dc1774..66bae19922 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -10,6 +10,7 @@ # limitations under the License. import warnings +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Union import torch @@ -24,7 +25,6 @@ from .utils import engine_apply_transform -IgniteEngine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") State, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "State") Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") @@ -37,7 +37,19 @@ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") -class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import +class BaseWorkflow(ABC): + """ + Base class for any MONAI style workflow. + `run()` is designed to execute the train, evaluation or inference logic. + + """ + + @abstractmethod + def run(self): + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class Workflow(Engine): # type: ignore[valid-type, misc] # due to optional_import """ Workflow defines the core work process inheriting from Ignite engine. All trainer, validator and evaluator share this same workflow as base class, diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index 9b58be04e8..b136eb7f1f 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -19,6 +19,7 @@ def monai_mish(x, inplace: bool = False): return torch.nn.functional.mish(x, inplace=inplace) + else: def monai_mish(x, inplace: bool = False): @@ -30,6 +31,7 @@ def monai_mish(x, inplace: bool = False): def monai_swish(x, inplace: bool = False): return torch.nn.functional.silu(x, inplace=inplace) + else: def monai_swish(x, inplace: bool = False): From 0f8488857eadec13646bbdd32a4b4edeeb55abf4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 7 Dec 2021 00:32:14 +0800 Subject: [PATCH 2/4] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/engines/workflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 66bae19922..785478d4ec 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -25,6 +25,7 @@ from .utils import engine_apply_transform +IgniteEngine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") State, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "State") Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") @@ -49,7 +50,7 @@ def run(self): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -class Workflow(Engine): # type: ignore[valid-type, misc] # due to optional_import +class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import """ Workflow defines the core work process inheriting from Ignite engine. All trainer, validator and evaluator share this same workflow as base class, From bb211860d0b449649509fba8218ace82ad39466f Mon Sep 17 00:00:00 2001 From: monai-bot Date: Mon, 6 Dec 2021 16:38:11 +0000 Subject: [PATCH 3/4] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/networks/blocks/activation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index b136eb7f1f..9b58be04e8 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -19,7 +19,6 @@ def monai_mish(x, inplace: bool = False): return torch.nn.functional.mish(x, inplace=inplace) - else: def monai_mish(x, inplace: bool = False): @@ -31,7 +30,6 @@ def monai_mish(x, inplace: bool = False): def monai_swish(x, inplace: bool = False): return torch.nn.functional.silu(x, inplace=inplace) - else: def monai_swish(x, inplace: bool = False): From 0dd0308e6b35895f5e73af93879f5215b9eb6e7b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 7 Dec 2021 08:16:07 +0800 Subject: [PATCH 4/4] [DLMED] add *args, **kwargs Signed-off-by: Nic Ma --- monai/engines/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 785478d4ec..f6f0a6a059 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -46,7 +46,7 @@ class BaseWorkflow(ABC): """ @abstractmethod - def run(self): + def run(self, *args, **kwargs): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")