diff --git a/pympipool/scheduler/__init__.py b/pympipool/scheduler/__init__.py index 3138bc2f..118e6b46 100644 --- a/pympipool/scheduler/__init__.py +++ b/pympipool/scheduler/__init__.py @@ -12,8 +12,8 @@ check_threads_per_core, check_oversubscribe, check_executor, - check_backend, check_init_function, + validate_backend, validate_number_of_cores, ) from pympipool.scheduler.slurm import ( @@ -87,8 +87,10 @@ def create_executor( """ max_cores = validate_number_of_cores(max_cores=max_cores, max_workers=max_workers) check_init_function(block_allocation=block_allocation, init_function=init_function) - check_backend(backend=backend) - if backend == "flux" or (backend == "auto" and flux_installed): + backend = validate_backend( + backend=backend, flux_installed=flux_installed, slurm_installed=slurm_installed + ) + if backend == "flux": check_oversubscribe(oversubscribe=oversubscribe) check_command_line_argument_lst( command_line_argument_lst=command_line_argument_lst @@ -114,7 +116,7 @@ def create_executor( executor=executor, hostname_localhost=hostname_localhost, ) - elif backend == "slurm" or (backend == "auto" and slurm_installed): + elif backend == "slurm": check_executor(executor=executor) if block_allocation: return PySlurmExecutor( diff --git a/pympipool/shared/inputcheck.py b/pympipool/shared/inputcheck.py index 11595dfe..036b8c20 100644 --- a/pympipool/shared/inputcheck.py +++ b/pympipool/shared/inputcheck.py @@ -67,7 +67,9 @@ def check_refresh_rate(refresh_rate: float): ) -def check_backend(backend: str): +def validate_backend( + backend: str, flux_installed: bool = False, slurm_installed: bool = False +) -> str: if backend not in ["auto", "mpi", "slurm", "flux"]: raise ValueError( 'The currently implemented backends are ["flux", "mpi", "slurm"]. ' @@ -75,6 +77,12 @@ def check_backend(backend: str): + backend + " is not a valid choice." ) + elif backend == "flux" or (backend == "auto" and flux_installed): + return "flux" + elif backend == "slurm" or (backend == "auto" and slurm_installed): + return "slurm" + else: + return "mpi" def check_init_function(block_allocation: bool, init_function: callable): diff --git a/tests/test_shared_input_check.py b/tests/test_shared_input_check.py index bf1076f2..ad322193 100644 --- a/tests/test_shared_input_check.py +++ b/tests/test_shared_input_check.py @@ -6,11 +6,11 @@ check_threads_per_core, check_oversubscribe, check_executor, - check_backend, check_init_function, check_refresh_rate, check_resource_dict, check_resource_dict_is_empty, + validate_backend, ) @@ -37,7 +37,9 @@ def test_check_executor(self): def test_check_backend(self): with self.assertRaises(ValueError): - check_backend(backend="test") + validate_backend( + backend="test", slurm_installed=False, flux_installed=False + ) def test_check_init_function(self): with self.assertRaises(ValueError):