Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3942 3966 adds _requires_ and import support #3945

Merged
merged 12 commits into from
Mar 21, 2022
44 changes: 36 additions & 8 deletions monai/bundle/config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import inspect
import os
import sys
Expand All @@ -18,7 +19,7 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

from monai.bundle.utils import EXPR_KEY
from monai.utils import ensure_tuple, instantiate
from monai.utils import ensure_tuple, instantiate, optional_import

__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"]

Expand Down Expand Up @@ -164,17 +165,22 @@ class ConfigComponent(ConfigItem, Instantiable):
Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to
represent a component of `class` or `function` and supports instantiation.

Currently, two special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:
Currently, three special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:

- class or function identifier of the python module, specified by one of the two keys.
- ``"_target_"``: indicates build-in python classes or functions such as "LoadImageDict",
or full module name, such as "monai.transforms.LoadImageDict".
- class or function identifier of the python module, specified by ``"_target_"``,
indicating a build-in python class or function such as ``"LoadImageDict"``,
or a full module name, such as ``"monai.transforms.LoadImageDict"``.
- ``"_requires_"``: specifies reference IDs (string starts with ``"@"``) or ``ConfigExpression``
wyli marked this conversation as resolved.
Show resolved Hide resolved
of the dependencies for this ``ConfigComponent`` object. These dependencies will be
evaluated/instantiated before this object is instantiated.
- ``"_disabled_"``: a flag to indicate whether to skip the instantiation.

Other fields in the config content are input arguments to the python module.

.. code-block:: python

from monai.bundle import ComponentLocator, ConfigComponent

locator = ComponentLocator(excludes=["modules_to_exclude"])
config = {
"_target_": "LoadImaged",
Expand All @@ -195,7 +201,7 @@ class ConfigComponent(ConfigItem, Instantiable):

"""

non_arg_keys = {"_target_", "_disabled_"}
non_arg_keys = {"_target_", "_disabled_", "_requires_"}

def __init__(
self,
Expand Down Expand Up @@ -279,7 +285,7 @@ def instantiate(self, **kwargs) -> object: # type: ignore
class ConfigExpression(ConfigItem):
"""
Subclass of :py:class:`monai.bundle.ConfigItem`, the `ConfigItem` represents an executable expression
(execute based on ``eval()``).
(execute based on ``eval()``, or import the module to the `globals` if it's an import statement).

See also:

Expand Down Expand Up @@ -308,7 +314,26 @@ class ConfigExpression(ConfigItem):

def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> None:
super().__init__(config=config, id=id)
self.globals = globals
self.globals = globals if globals is not None else {}

def _parse_import_string(self, import_string: str):
# parse single import statement such as "from monai.transforms import Resize"
for n in ast.iter_child_nodes(ast.parse(import_string)):
if not isinstance(n, (ast.Import, ast.ImportFrom)):
return None
if len(n.names) < 1:
return None
if len(n.names) > 1:
warnings.warn(f"ignoring multiple import alias '{import_string}'.")
name, asname = f"{n.names[0].name}", n.names[0].asname
asname = name if asname is None else f"{asname}"
if isinstance(n, ast.ImportFrom):
self.globals[asname], _ = optional_import(f"{n.module}", name=f"{name}")
return self.globals[asname]
elif isinstance(n, ast.Import):
self.globals[asname], _ = optional_import(f"{name}")
return self.globals[asname]
return None

def evaluate(self, locals: Optional[Dict] = None):
"""
Expand All @@ -322,6 +347,9 @@ def evaluate(self, locals: Optional[Dict] = None):
value = self.get_config()
if not ConfigExpression.is_expression(value):
return None
optional_module = self._parse_import_string(value[len(self.prefix) :])
if optional_module is not None:
return optional_module
if not self.run_eval:
return f"{value[len(self.prefix) :]}"
return eval(value[len(self.prefix) :], self.globals, locals)
Expand Down
15 changes: 9 additions & 6 deletions monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import json
import re
from copy import deepcopy
Expand All @@ -26,6 +25,8 @@

__all__ = ["ConfigParser"]

_default_globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}


class ConfigParser:
"""
Expand Down Expand Up @@ -74,7 +75,7 @@ class ConfigParser:
so that expressions, for example, ``"$monai.data.list_data_collate"`` can use ``monai`` modules.
The current supported globals and alias names are
``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``.
These are MONAI's minimal dependencies.
These are MONAI's minimal dependencies. Additional packages could be included with `globals={"itk": "itk"}`.

See also:

Expand All @@ -96,10 +97,12 @@ def __init__(
):
self.config = None
self.globals: Dict[str, Any] = {}
globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"} if globals is None else globals
if globals is not None:
for k, v in globals.items():
self.globals[k] = importlib.import_module(v) if isinstance(v, str) else v
_globals = _default_globals.copy()
if isinstance(_globals, dict) and globals is not None:
_globals.update(globals)
if _globals is not None:
for k, v in _globals.items():
self.globals[k] = optional_import(v)[0] if isinstance(v, str) else v

self.locator = ComponentLocator(excludes=excludes)
self.ref_resolver = ReferenceResolver()
Expand Down
19 changes: 19 additions & 0 deletions tests/test_config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,25 @@ def test_lazy_instantiation(self):
self.assertTrue(isinstance(ret, DataLoader))
self.assertEqual(ret.batch_size, 4)

@parameterized.expand([("$import json", "json"), ("$import json as j", "j")])
def test_import(self, stmt, mod_name):
test_globals = {}
ConfigExpression(id="", config=stmt, globals=test_globals).evaluate()
self.assertTrue(callable(test_globals[mod_name].dump))

@parameterized.expand(
[
("$from json import dump", "dump"),
("$from json import dump, dumps", "dump"),
("$from json import dump as jd", "jd"),
("$from json import dump as jd, dumps as ds", "jd"),
]
)
def test_import_from(self, stmt, mod_name):
test_globals = {}
ConfigExpression(id="", config=stmt, globals=test_globals).evaluate()
self.assertTrue(callable(test_globals[mod_name]))


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions tests/testing_data/inference.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
{
"dataset_dir": "/workspace/data/Task09_Spleen",
"import_glob": "$import glob",
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"set_seed": "$monai.utils.set_determinism(0)",
"print_test_name": "$print('json_test')",
"print_glob_file": "$print(@import_glob.__file__)",
wyli marked this conversation as resolved.
Show resolved Hide resolved
"network_def": {
"_target_": "UNet",
"spatial_dims": 3,
Expand Down Expand Up @@ -94,6 +98,12 @@
},
"evaluator": {
"_target_": "SupervisedEvaluator",
"_requires_": [
"@set_seed",
"@print_test_name",
"@print_glob_file",
"$print('test_in_line_json')"
],
"device": "@device",
"val_data_loader": "@dataloader",
"network": "@network",
Expand Down
6 changes: 6 additions & 0 deletions tests/testing_data/inference.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
---
dataset_dir: "/workspace/data/Task09_Spleen"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
set_seed: "$monai.utils.set_determinism(0)"
print_test_name: "$print('yaml_test')"
network_def:
_target_: UNet
spatial_dims: 3
Expand Down Expand Up @@ -67,6 +69,10 @@ postprocessing:
output_dir: "@_meta_#output_dir"
evaluator:
_target_: SupervisedEvaluator
_requires_:
- "$print('test_in_line_yaml')"
- "@set_seed"
- "@print_test_name"
device: "@device"
val_data_loader: "@dataloader"
network: "@network"
Expand Down