Skip to content

Commit

Permalink
Add support for Slurm heterogeneous jobs (#346)
Browse files Browse the repository at this point in the history
This PR adds basic support for heterogeneous jobs
on Slurm. Users can set the `--het-group` flag through
the `SrunSettings::set_het_group` method. Some checks
are added to make sure users don't run MPMD workloads
within heterogeneous jobs, as this is not allowed by Slurm.
In particular, orchestrators cannot be run with `single_cmd=True`
in a heterogeneous job.

[ committed by @al-rigazzi ]
[ reviewed by @MattToast @ashao ]
  • Loading branch information
al-rigazzi authored Aug 25, 2023
1 parent d7a6b60 commit 1d4d7a9
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 10 deletions.
33 changes: 25 additions & 8 deletions smartsim/_core/launcher/slurm/slurmLauncher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import time
import typing as t

from shutil import which

from ....error import LauncherError
Expand Down Expand Up @@ -188,14 +188,33 @@ def stop(self, step_name: str) -> StepInfo:
stepmap = self.step_mapping[step_name]
if stepmap.managed:
step_id = str(stepmap.step_id)
# Check if step_id is part of colon-separated run
# if that is the case, stop parent job step because
# Check if step_id is part of colon-separated run,
# this is reflected in a '+' in the step id,
# so that the format becomes 12345+1.0.
# If we find it it can mean two things:
# a MPMD srun command, or a heterogeneous job.
# If it is a MPMD srun, then stop parent step because
# sub-steps cannot be stopped singularly.
if "+" in step_id:
sub_step = "+" in step_id
het_job = os.getenv("SLURM_HET_SIZE") is not None
# If it is a het job, we can stop
# them like this. Slurm will throw an error, but
# will actually kill steps correctly.
if sub_step and not het_job:
step_id = step_id.split("+", maxsplit=1)[0]
scancel_rc, _, err = scancel([step_id])
if scancel_rc != 0:
logger.warning(f"Unable to cancel job step {step_name}\n {err}")
if het_job:
msg = (
"SmartSim received a non-zero exit code while canceling"
f" a heterogeneous job step {step_name}!\n"
"The following error might be internal to Slurm\n"
"and the heterogeneous job step could have been correctly"
" canceled.\nSmartSim will consider it canceled.\n"
)
else:
msg = f"Unable to cancel job step {step_name}\n{err}"
logger.warning(msg)
if stepmap.task_id:
self.task_manager.remove_task(str(stepmap.task_id))
else:
Expand Down Expand Up @@ -238,9 +257,7 @@ def _get_slurm_step_id(step: Step, interval: int = 2) -> str:
raise LauncherError("Could not find id of launched job step")
return step_id

def _get_managed_step_update(
self, step_ids: t.List[str]
) -> t.List[StepInfo]:
def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]:
"""Get step updates for WLM managed jobs
:param step_ids: list of job step ids
Expand Down
15 changes: 13 additions & 2 deletions smartsim/database/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import sys
import typing as t

from os import getcwd
from os import getcwd, getenv
from shlex import split as sh_split

from smartredis import Client
Expand Down Expand Up @@ -108,7 +108,18 @@ def _check_run_command(launcher: str, run_command: str) -> None:


def _get_single_command(run_command: str, batch: bool, single_cmd: bool) -> bool:
if not batch or not single_cmd:
if not single_cmd:
return single_cmd

if run_command == "srun" and getenv("SLURM_HET_SIZE") is not None:
msg = (
"srun can not launch an orchestrator with single_cmd=True in "
+ "a hetereogeneous job. Automatically switching to single_cmd=False."
)
logger.info(msg)
return False

if not batch:
return single_cmd

if run_command == "aprun":
Expand Down
35 changes: 35 additions & 0 deletions smartsim/settings/slurmSettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def make_mpmd(self, settings: RunSettings) -> None:
raise SSUnsupportedError(
"Containerized MPMD workloads are not yet supported."
)
if os.getenv("SLURM_HET_SIZE") is not None:
raise ValueError(
"Slurm does not support MPMD workloads in heterogeneous jobs."
)
self.mpmd.append(settings)

def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None:
Expand Down Expand Up @@ -270,6 +274,37 @@ def set_walltime(self, walltime: str) -> None:
"""
self.run_args["time"] = str(walltime)

def set_het_group(self, het_group: t.Union[int, t.Iterable[int]]) -> None:
    """Set the heterogeneous group for this job

    this sets `--het-group`

    :param het_group: heterogeneous group, or groups, to run on
    :type het_group: int or iterable of ints
    :raises ValueError: if ``SLURM_HET_SIZE`` is not set in the environment,
                        if MPMD settings are attached to these settings, or
                        if a requested group is not smaller than
                        ``SLURM_HET_SIZE``
    """
    het_size_env = os.getenv("SLURM_HET_SIZE")
    if het_size_env is None:
        msg = "Requested to set het group, but the allocation is not a het job"
        raise ValueError(msg)

    het_size = int(het_size_env)
    if self.mpmd:
        msg = "Slurm does not support MPMD workloads in heterogeneous jobs\n"
        raise ValueError(msg)

    # Materialize the argument exactly once: it is traversed several times
    # below, which would silently exhaust a generator or other one-shot
    # iterator. A bare int is accepted as a single group, as documented.
    if isinstance(het_group, int):
        groups = [het_group]
    else:
        groups = list(het_group)

    if any(group >= het_size for group in groups):
        msg = (
            f"Het group {max(groups)} requested, "
            f"but max het group in allocation is {het_size-1}"
        )
        raise ValueError(msg)

    # Only warn once the request has been validated
    logger.warning(
        "Support for heterogeneous groups is an experimental feature, "
        "please report any unexpected behavior to SmartSim developers "
        "by opening an issue on https://github.com/CrayLabs/SmartSim/issues"
    )
    self.run_args["het-group"] = ",".join(str(group) for group in groups)

def format_run_args(self) -> t.List[str]:
"""Return a list of slurm formatted run arguments
Expand Down
82 changes: 82 additions & 0 deletions tests/on_wlm/test_het_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# BSD 2-Clause License
#
# Copyright (c) 2021-2023, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import pytest

from smartsim import Experiment
from smartsim.settings import SrunSettings


# Skip every test in this module unless the test suite targets Slurm.
# `pytest.test_launcher` is retrieved from pytest fixtures.
# NOTE(review): it is not a stock pytest attribute; presumably it is set by
# the project's conftest -- confirm.
if pytest.test_launcher != "slurm":
    pytestmark = pytest.mark.skip(reason="Test is only for Slurm WLM systems")

def test_mpmd_errors(monkeypatch):
    """Test that MPMD workloads and heterogeneous jobs are mutually exclusive"""
    monkeypatch.setenv("SLURM_HET_SIZE", "1")
    exp_name = "test-het-job-errors"
    exp = Experiment(exp_name, launcher="slurm")
    rs: SrunSettings = exp.create_run_settings("sleep", "1", run_command="srun")
    rs2: SrunSettings = exp.create_run_settings("sleep", "1", run_command="srun")

    # Inside a het job, attaching MPMD settings must be refused
    with pytest.raises(ValueError):
        rs.make_mpmd(rs2)

    # Outside a het job, MPMD is allowed but a het group cannot be requested
    monkeypatch.delenv("SLURM_HET_SIZE")
    rs.make_mpmd(rs2)
    with pytest.raises(ValueError):
        rs.set_het_group(1)

    # Back inside a het job, settings that already carry MPMD commands must
    # be refused a het group; without restoring the variable the call above
    # only ever exercises the "not a het job" error.
    monkeypatch.setenv("SLURM_HET_SIZE", "1")
    with pytest.raises(ValueError, match="MPMD"):
        rs.set_het_group([0])


def test_set_het_groups(monkeypatch):
    """Check that one or several het groups can be set on srun settings"""
    monkeypatch.setenv("SLURM_HET_SIZE", "4")
    experiment = Experiment("test-set-het-group", launcher="slurm")
    settings: SrunSettings = experiment.create_run_settings(
        "sleep", "1", run_command="srun"
    )

    # Valid requests, in order, with the run argument each must produce
    valid_cases = (
        ([1], "1"),
        ([3, 2], "3,2"),
    )
    for requested, rendered in valid_cases:
        settings.set_het_group(requested)
        assert settings.run_args["het-group"] == rendered

    # Group index equal to SLURM_HET_SIZE must be refused
    with pytest.raises(ValueError):
        settings.set_het_group([4])


def test_orch_single_cmd(monkeypatch, wlmutils):
    """Test that single cmd is rejected in a heterogeneous job"""
    monkeypatch.setenv("SLURM_HET_SIZE", "1")
    exp_name = "test-orch-single-cmd"
    exp = Experiment(exp_name, launcher="slurm")
    orc = exp.create_database(
        wlmutils.get_test_port(),
        db_nodes=3,
        batch=False,
        interface=wlmutils.get_test_interface(),
        single_cmd=True,
        hosts=wlmutils.get_test_hostlist(),
    )

    # single_cmd=True was requested, but no node may end up as part of an
    # MPMD launch; assert on truthiness rather than comparing `== False`.
    for node in orc:
        assert not node.is_mpmd

0 comments on commit 1d4d7a9

Please sign in to comment.