diff --git a/assets/images/coverage.svg b/assets/images/coverage.svg index a9d4cc45..d199f255 100644 --- a/assets/images/coverage.svg +++ b/assets/images/coverage.svg @@ -15,7 +15,7 @@ coverage coverage - 32% - 32% + 33% + 33% diff --git a/lighter/utils/cli.py b/lighter/utils/cli.py index 4c387d6b..1b7227a9 100644 --- a/lighter/utils/cli.py +++ b/lighter/utils/cli.py @@ -7,96 +7,52 @@ import yaml from loguru import logger from monai.bundle.scripts import run -from monai.utils.misc import ensure_tuple from lighter.utils.dynamic_imports import import_module_from_path - -# Trainer methods calls in YAML format with support for command line arguments. -# Waiting for https://github.com/Project-MONAI/MONAI/pull/5854#issuecomment-1384800886. -trainer_methods = { - "fit": yaml.safe_load( - """ - fit: - _method_: > - $@trainer.fit(model=@fit#model, - ckpt_path=@fit#ckpt_path) - model: "@system" - ckpt_path: null""" - ), - "validate": yaml.safe_load( - """ - validate: - _method_: > - $@trainer.validate(model=@validate#model, - ckpt_path=@validate#ckpt_path, - verbose=@validate#verbose) - model: "@system" - ckpt_path: null - verbose: True""" - ), - "predict": yaml.safe_load( - """ - predict: - _method_: > - $@trainer.predict(model=@predict#model, - ckpt_path=@predict#ckpt_path) - model: "@system" - ckpt_path: null""" - ), - "test": yaml.safe_load( - """ - test: - _method_: > - $@trainer.test(model=@test#model, - ckpt_path=@test#ckpt_path, - verbose=@test#verbose) - model: "@system" - ckpt_path: null - verbose: True""" - ), - "tune": yaml.safe_load( - """ - tune: - _method_: > - $@trainer.tune(model=@tune#model, - ckpt_path=@tune#ckpt_path, - scale_batch_size_kwargs=@tune#scale_batch_size_kwargs, - lr_find_kwargs=@tune#lr_find_kwargs, - method=@tune#method) - model: "@system" - ckpt_path: null - scale_batch_size_kwargs: null - lr_find_kwargs: null - method: fit - """ - ), -} +from lighter.utils.misc import ensure_list def interface(): - fire.Fire({k: partial(run_trainer_method, v) for k, v in trainer_methods.items()}) + """Defines the command line interface for running Trainer's methods. The available methods are: + - fit + - validate + - test + - predict + - tune + """ + commands = {} + # All Trainer methods' names + for name in ["fit", "validate", "test", "predict", "tune"]: + # Creates the configuration for the Trainer method 'name' and sets 'system' as the model to run on. + config = {name: {"_target_": f"$@trainer.{name}", "model": "@system"}} + # Set the method name and config. Fire collects **kwargs from the CLI and passes them to the method. + commands[name] = partial(run_trainer_method, name, config) + fire.Fire(commands) -def run_trainer_method(trainer_method: Dict, **kwargs: Any): +def run_trainer_method(method_name, method_config: Dict, **kwargs: Any): """Call monai.bundle.run() on a Trainer method. If a project path is defined in the config file(s), import it. Args: - trainer_method (Dict): Trainer method to run, collection defined in `trainer_methods` - **kwargs (Any): keyword arguments for the monai.bundle.run function. + method_name: name of the Trainer method to run. ["fit", "validate", "test", "predict", "tune"]. + method_config: config definition of the Trainer method. + **kwargs (Any): keyword arguments passed to the `monai.bundle.run` function. """ + # Import the project as a module. if "config_file" in kwargs: project_imported = False - # Handle multiple configs - for config in ensure_tuple(kwargs["config_file"]): - config = yaml.safe_load(open(config)) - # Import the project as a module - if "project" in config: - # Only one config file can specify the project path - if project_imported: - logger.error("`project` must be specified in one config only. Exiting.") - sys.exit() - import_module_from_path("project", config["project"]) - project_imported = True - # Run the Trainer method - run(**trainer_method, **kwargs) + # Handle multiple configs. Start from the config file specified last as it overrides the previous ones. + for config in reversed(ensure_list(kwargs["config_file"])): + config = yaml.safe_load(open(config, "r")) + if "project" not in config: + continue + # Only one config file can specify the project path + if project_imported: + logger.error("`project` must be specified in one config only. Exiting.") + sys.exit() + # Import it as a module named 'project'. + import_module_from_path("project", config["project"]) + project_imported = True + # Run the Trainer method. + run(method_name, **method_config, **kwargs) diff --git a/tests/integration/test_cifar.py b/tests/integration/test_cifar.py index 354d42e3..317db9d3 100644 --- a/tests/integration/test_cifar.py +++ b/tests/integration/test_cifar.py @@ -1,19 +1,27 @@ """Tests for running CIFAR training to verify integrity of the pipeline""" import pytest -from lighter.utils.cli import run_trainer_method, trainer_methods +from lighter.utils.cli import run_trainer_method test_overrides = "./tests/integration/test_overrides.yaml" @pytest.mark.parametrize( - ("mode", "config_file"), - [("fit", "./projects/cifar10/experiments/monai_bundle_prototype.yaml")], + ("method_name", "method_config", "config_file"), + [ + ( # Method name + "fit", + # Method config + {"fit": {"_target_": "$@trainer.fit", "model": "@system"}}, + # Config fiile + "./projects/cifar10/experiments/monai_bundle_prototype.yaml", + ) + ], ) @pytest.mark.slow -def test_trainer_method(mode: str, config_file: str): - """Test trainer method for different mode configurations""" +def test_trainer_method(method_name: str, method_config: dict, config_file: str): + """ """ kwargs = {"config_file": config_file, "args_file": test_overrides} - func_return = run_trainer_method(trainer_methods[mode], **kwargs) - assert func_return == None + func_return = run_trainer_method(method_name, method_config, **kwargs) + assert func_return is None