diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py index 2105f051bd..23fb3eb250 100644 --- a/monai/apps/auto3dseg/auto_runner.py +++ b/monai/apps/auto3dseg/auto_runner.py @@ -407,7 +407,7 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int: def set_gpu_customization( self, gpu_customization: bool = False, gpu_customization_specs: dict[str, Any] | None = None - ) -> None: + ) -> AutoRunner: """ Set options for GPU-based parameter customization/optimization. @@ -442,7 +442,9 @@ def set_gpu_customization( if gpu_customization_specs is not None: self.gpu_customization_specs = gpu_customization_specs - def set_num_fold(self, num_fold: int = 5) -> None: + return self + + def set_num_fold(self, num_fold: int = 5) -> AutoRunner: """ Set the number of cross validation folds for all algos. @@ -454,7 +456,9 @@ def set_num_fold(self, num_fold: int = 5) -> None: raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}") self.num_fold = num_fold - def set_training_params(self, params: dict[str, Any] | None = None) -> None: + return self + + def set_training_params(self, params: dict[str, Any] | None = None) -> AutoRunner: """ Set the training params for all algos. @@ -474,13 +478,15 @@ def set_training_params(self, params: dict[str, Any] | None = None) -> None: DeprecationWarning, ) + return self + def set_device_info( self, cuda_visible_devices: list[int] | str | None = None, num_nodes: int | None = None, mn_start_method: str | None = None, cmd_prefix: str | None = None, - ) -> None: + ) -> AutoRunner: """ Set the device related info @@ -531,7 +537,9 @@ def set_device_info( if cmd_prefix is not None: logger.info(f"Using user defined command running prefix {cmd_prefix}, will override other settings") - def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> None: + return self + + def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> AutoRunner: """ Set the bundle ensemble method name and parameters for save image transform parameters. @@ -546,7 +554,9 @@ def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFol ) self.kwargs.update(kwargs) - def set_image_save_transform(self, **kwargs: Any) -> None: + return self + + def set_image_save_transform(self, **kwargs: Any) -> AutoRunner: """ Set the ensemble output transform. @@ -565,7 +575,9 @@ def set_image_save_transform(self, **kwargs: Any) -> None: "Check https://docs.monai.io/en/stable/transforms.html#saveimage for more information." ) - def set_prediction_params(self, params: dict[str, Any] | None = None) -> None: + return self + + def set_prediction_params(self, params: dict[str, Any] | None = None) -> AutoRunner: """ Set the prediction params for all algos. @@ -581,7 +593,9 @@ def set_prediction_params(self, params: dict[str, Any] | None = None) -> None: """ self.pred_params = deepcopy(params) if params is not None else {} - def set_analyze_params(self, params: dict[str, Any] | None = None) -> None: + return self + + def set_analyze_params(self, params: dict[str, Any] | None = None) -> AutoRunner: """ Set the data analysis extra params. @@ -595,7 +609,9 @@ def set_analyze_params(self, params: dict[str, Any] | None = None) -> None: else: self.analyze_params = deepcopy(params) - def set_hpo_params(self, params: dict[str, Any] | None = None) -> None: + return self + + def set_hpo_params(self, params: dict[str, Any] | None = None) -> AutoRunner: """ Set parameters for the HPO module and the algos before the training. It will attempt to (1) override bundle templates with the key-value pairs in ``params`` (2) change the config of the HPO module (e.g. NNI) if the @@ -621,7 +637,9 @@ def set_hpo_params(self, params: dict[str, Any] | None = None) -> None: """ self.hpo_params = self.train_params if params is None else params - def set_nni_search_space(self, search_space): + return self + + def set_nni_search_space(self, search_space: dict[str, Any]) -> AutoRunner: """ Set the search space for NNI parameter search. @@ -638,6 +656,8 @@ def set_nni_search_space(self, search_space): self.search_space = search_space self.hpo_tasks = value_combinations + return self + def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None: """ Train the Algos in a sequential scheme. The order of training is randomized.