Skip to content

Commit

Permalink
[Auto3DSeg] Add mlflow support in autorunner. (#7176)
Browse files Browse the repository at this point in the history
Add MLflow support in AutoRunner Class.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: dongy <dongy@nvidia.com>
  • Loading branch information
dongyang0122 authored Nov 2, 2023
1 parent cf886e7 commit c3f9914
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
5 changes: 5 additions & 0 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class AutoRunner:
zip url will be downloaded and extracted into the work_dir.
allow_skip: a switch passed to BundleGen process which determines if some Algo in the default templates
can be skipped based on the analysis on the dataset from Auto3DSeg DataAnalyzer.
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of the remote
tracking Server; MLflow runs will be recorded locally in algorithms' model folder if the value is None.
kwargs: image writing parameters for the ensemble inference. The kwargs format follows the SaveImage
transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage.
Expand Down Expand Up @@ -209,6 +211,7 @@ def __init__(
not_use_cache: bool = False,
templates_path_or_url: str | None = None,
allow_skip: bool = True,
mlflow_tracking_uri: str | None = None,
**kwargs: Any,
):
logger.info(f"AutoRunner using work directory {work_dir}")
Expand All @@ -220,6 +223,7 @@ def __init__(
self.algos = algos
self.templates_path_or_url = templates_path_or_url
self.allow_skip = allow_skip
self.mlflow_tracking_uri = mlflow_tracking_uri
self.kwargs = deepcopy(kwargs)

if input is None and os.path.isfile(self.data_src_cfg_name):
Expand Down Expand Up @@ -783,6 +787,7 @@ def run(self):
templates_path_or_url=self.templates_path_or_url,
data_stats_filename=self.datastats_filename,
data_src_cfg_name=self.data_src_cfg_name,
mlflow_tracking_uri=self.mlflow_tracking_uri,
)

if self.gpu_customization:
Expand Down
34 changes: 34 additions & 0 deletions monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(self, template_path: PathLike):
self.template_path = template_path
self.data_stats_files = ""
self.data_list_file = ""
self.mlflow_tracking_uri = None
self.output_path = ""
self.name = ""
self.best_metric = None
Expand Down Expand Up @@ -129,6 +130,17 @@ def set_data_source(self, data_src_cfg: str) -> None:
"""
self.data_list_file = data_src_cfg

def set_mlflow_tracking_uri(self, mlflow_tracking_uri: str | None) -> None:
"""
Set the tracking URI for MLflow server
Args:
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of
the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
the value is None.
"""
self.mlflow_tracking_uri = mlflow_tracking_uri # type: ignore

def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs: Any) -> dict:
"""
The configuration files defined when constructing this Algo instance might not have a complete training
Expand Down Expand Up @@ -432,6 +444,9 @@ class BundleGen(AlgoGen):
data_stats_filename: the path to the data stats file (generated by DataAnalyzer).
data_src_cfg_name: the path to the data source config YAML file. The config will be in a form of
{"modality": "ct", "datalist": "path_to_json_datalist", "dataroot": "path_dir_data"}.
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of
the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
the value is None.
.. code-block:: bash
python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/datastats.yaml"
Expand All @@ -444,6 +459,7 @@ def __init__(
templates_path_or_url: str | None = None,
data_stats_filename: str | None = None,
data_src_cfg_name: str | None = None,
mlflow_tracking_uri: str | None = None,
):
if algos is None or isinstance(algos, (list, tuple, str)):
if templates_path_or_url is None:
Expand Down Expand Up @@ -496,6 +512,7 @@ def __init__(

self.data_stats_filename = data_stats_filename
self.data_src_cfg_name = data_src_cfg_name
self.mlflow_tracking_uri = mlflow_tracking_uri
self.history: list[dict] = []

def set_data_stats(self, data_stats_filename: str) -> None:
Expand Down Expand Up @@ -524,6 +541,21 @@ def get_data_src(self):
"""Get the data source filename"""
return self.data_src_cfg_name

def set_mlflow_tracking_uri(self, mlflow_tracking_uri):
"""
Set the tracking URI for MLflow server
Args:
mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of
the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
the value is None.
"""
self.mlflow_tracking_uri = mlflow_tracking_uri

def get_mlflow_tracking_uri(self):
"""Get the tracking URI for MLflow server"""
return self.mlflow_tracking_uri

def get_history(self) -> list:
"""Get the history of the bundleAlgo object with their names/identifiers"""
return self.history
Expand Down Expand Up @@ -575,9 +607,11 @@ def generate(
for f_id in ensure_tuple(fold_idx):
data_stats = self.get_data_stats()
data_src_cfg = self.get_data_src()
mlflow_tracking_uri = self.get_mlflow_tracking_uri()
gen_algo = deepcopy(algo)
gen_algo.set_data_stats(data_stats)
gen_algo.set_data_source(data_src_cfg)
gen_algo.set_mlflow_tracking_uri(mlflow_tracking_uri)
name = f"{gen_algo.name}_{f_id}"

if allow_skip:
Expand Down

0 comments on commit c3f9914

Please sign in to comment.