Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify CLI #25

Merged
merged 6 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions assets/images/coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
114 changes: 35 additions & 79 deletions lighter/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,96 +7,52 @@
import yaml
from loguru import logger
from monai.bundle.scripts import run
from monai.utils.misc import ensure_tuple

from lighter.utils.dynamic_imports import import_module_from_path

# Trainer methods calls in YAML format with support for command line arguments.
# Waiting for https://github.com/Project-MONAI/MONAI/pull/5854#issuecomment-1384800886.
trainer_methods = {
"fit": yaml.safe_load(
"""
fit:
_method_: >
$@trainer.fit(model=@fit#model,
ckpt_path=@fit#ckpt_path)
model: "@system"
ckpt_path: null"""
),
"validate": yaml.safe_load(
"""
validate:
_method_: >
$@trainer.validate(model=@validate#model,
ckpt_path=@validate#ckpt_path,
verbose=@validate#verbose)
model: "@system"
ckpt_path: null
verbose: True"""
),
"predict": yaml.safe_load(
"""
predict:
_method_: >
$@trainer.predict(model=@predict#model,
ckpt_path=@predict#ckpt_path)
model: "@system"
ckpt_path: null"""
),
"test": yaml.safe_load(
"""
test:
_method_: >
$@trainer.test(model=@test#model,
ckpt_path=@test#ckpt_path,
verbose=@test#verbose)
model: "@system"
ckpt_path: null
verbose: True"""
),
"tune": yaml.safe_load(
"""
tune:
_method_: >
$@trainer.tune(model=@tune#model,
ckpt_path=@tune#ckpt_path,
scale_batch_size_kwargs=@tune#scale_batch_size_kwargs,
lr_find_kwargs=@tune#lr_find_kwargs,
method=@tune#method)
model: "@system"
ckpt_path: null
scale_batch_size_kwargs: null
lr_find_kwargs: null
method: fit
"""
),
}
from lighter.utils.misc import ensure_list


def interface():
fire.Fire({k: partial(run_trainer_method, v) for k, v in trainer_methods.items()})
"""Defines the command line interface for running Trainer's methods. The available methods are:
- fit
- validate
- test
- predict
- tune
"""
commands = {}
# All Trainer methods' names
for name in ["fit", "validate", "test", "predict", "tune"]:
# Creates the configuration for the Trainer method 'name' and sets 'system' as the model to run on.
config = {name: {"_target_": f"$@trainer.{name}", "model": "@system"}}
# Set the method name and config. Fire collects **kwargs from the CLI and passes them to the method.
commands[name] = partial(run_trainer_method, name, config)
fire.Fire(commands)


def run_trainer_method(trainer_method: Dict, **kwargs: Any):
def run_trainer_method(method_name, method_config: Dict, **kwargs: Any):
"""Call monai.bundle.run() on a Trainer method. If a project path
is defined in the config file(s), import it.

Args:
trainer_method (Dict): Trainer method to run, collection defined in `trainer_methods`
**kwargs (Any): keyword arguments for the monai.bundle.run function.
method_name: name of the Trainer method to run. ["fit", "validate", "test", "predict", "tune"].
method_config: config definition of the Trainer method.
**kwargs (Any): keyword arguments passed to the `monai.bundle.run` function.
"""
# Import the project as a module.
if "config_file" in kwargs:
project_imported = False
# Handle multiple configs
for config in ensure_tuple(kwargs["config_file"]):
config = yaml.safe_load(open(config))
# Import the project as a module
if "project" in config:
# Only one config file can specify the project path
if project_imported:
logger.error("`project` must be specified in one config only. Exiting.")
sys.exit()
import_module_from_path("project", config["project"])
project_imported = True
# Run the Trainer method
run(**trainer_method, **kwargs)
# Handle multiple configs. Start from the config file specified last as it overrides the previous ones.
for config in reversed(ensure_list(kwargs["config_file"])):
config = yaml.safe_load(open(config, "r"))
if "project" not in config:
continue
# Only one config file can specify the project path
if project_imported:
logger.error("`project` must be specified in one config only. Exiting.")
sys.exit()
# Import it as a module named 'project'.
import_module_from_path("project", config["project"])
project_imported = True
# Run the Trainer method.
run(method_name, **method_config, **kwargs)
22 changes: 15 additions & 7 deletions tests/integration/test_cifar.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
"""Tests for running CIFAR training to verify integrity of the pipeline"""
import pytest

from lighter.utils.cli import run_trainer_method, trainer_methods
from lighter.utils.cli import run_trainer_method

test_overrides = "./tests/integration/test_overrides.yaml"


@pytest.mark.parametrize(
("mode", "config_file"),
[("fit", "./projects/cifar10/experiments/monai_bundle_prototype.yaml")],
("method_name", "method_config", "config_file"),
[
( # Method name
"fit",
# Method config
{"fit": {"_target_": "$@trainer.fit", "model": "@system"}},
# Config fiile
"./projects/cifar10/experiments/monai_bundle_prototype.yaml",
)
],
)
@pytest.mark.slow
def test_trainer_method(mode: str, config_file: str):
"""Test trainer method for different mode configurations"""
def test_trainer_method(method_name: str, method_config: dict, config_file: str):
""" """
kwargs = {"config_file": config_file, "args_file": test_overrides}

func_return = run_trainer_method(trainer_methods[mode], **kwargs)
assert func_return == None
func_return = run_trainer_method(method_name, method_config, **kwargs)
assert func_return is None