Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Slurm heterogeneous jobs #346

Merged
merged 20 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions smartsim/_core/launcher/slurm/slurmLauncher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import time
import typing as t

from shutil import which

from ....error import LauncherError
Expand Down Expand Up @@ -188,14 +188,33 @@ def stop(self, step_name: str) -> StepInfo:
stepmap = self.step_mapping[step_name]
if stepmap.managed:
step_id = str(stepmap.step_id)
# Check if step_id is part of colon-separated run
# if that is the case, stop parent job step because
# Check if step_id is part of colon-separated run,
# this is reflected in a '+' in the step id,
# so that the format becomes 12345+1.0.
# If we find it it can mean two things:
# a MPMD srun command, or a heterogeneous job.
# If it is a MPMD srun, then stop parent step because
# sub-steps cannot be stopped singularly.
if "+" in step_id:
sub_step = "+" in step_id
het_job = os.getenv("SLURM_HET_SIZE") is not None
# If it is a het job, we can stop
# them like this. Slurm will throw an error, but
# will actually kill steps correctly.
MattToast marked this conversation as resolved.
Show resolved Hide resolved
if sub_step and not het_job:
step_id = step_id.split("+", maxsplit=1)[0]
scancel_rc, _, err = scancel([step_id])
if scancel_rc != 0:
logger.warning(f"Unable to cancel job step {step_name}\n {err}")
if het_job:
msg = (
"SmartSim received a non-zero exit code while canceling"
f" a heterogeneous job step {step_name}!\n"
"The following error might be internal to Slurm\n"
"and the heterogeneous job step could have been correctly"
" canceled.\nSmartSim will consider it canceled.\n"
)
else:
msg = f"Unable to cancel job step {step_name}\n{err}"
logger.warning(msg)
if stepmap.task_id:
self.task_manager.remove_task(str(stepmap.task_id))
else:
Expand Down Expand Up @@ -238,9 +257,7 @@ def _get_slurm_step_id(step: Step, interval: int = 2) -> str:
raise LauncherError("Could not find id of launched job step")
return step_id

def _get_managed_step_update(
self, step_ids: t.List[str]
) -> t.List[StepInfo]:
def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]:
"""Get step updates for WLM managed jobs

:param step_ids: list of job step ids
Expand Down
15 changes: 13 additions & 2 deletions smartsim/database/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import sys
import typing as t

from os import getcwd
from os import getcwd, getenv
from shlex import split as sh_split

from smartredis import Client
Expand Down Expand Up @@ -108,7 +108,18 @@ def _check_run_command(launcher: str, run_command: str) -> None:


def _get_single_command(run_command: str, batch: bool, single_cmd: bool) -> bool:
if not batch or not single_cmd:
if not single_cmd:
return single_cmd

if run_command == "srun" and getenv("SLURM_HET_SIZE") is not None:
msg = (
"srun can not launch an orchestrator with single_cmd=True in "
+ "a hetereogeneous job. Automatically switching to single_cmd=False."
)
logger.info(msg)
return False

if not batch:
return single_cmd

if run_command == "aprun":
Expand Down
35 changes: 35 additions & 0 deletions smartsim/settings/slurmSettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def make_mpmd(self, settings: RunSettings) -> None:
raise SSUnsupportedError(
"Containerized MPMD workloads are not yet supported."
)
if os.getenv("SLURM_HET_SIZE") is not None:
raise ValueError(
"Slurm does not support MPMD workloads in heterogeneous jobs."
)
self.mpmd.append(settings)

def set_hostlist(self, host_list: t.Union[str, t.List[str]]) -> None:
Expand Down Expand Up @@ -270,6 +274,37 @@ def set_walltime(self, walltime: str) -> None:
"""
self.run_args["time"] = str(walltime)

def set_het_group(self, het_group: t.Iterable[int]) -> None:
"""Set the heterogeneous group for this job

this sets `--het-group`

:param het_group: list of heterogeneous groups
:type het_group: int or iterable of ints
"""
het_size_env = os.getenv("SLURM_HET_SIZE")
if het_size_env is None:
msg = "Requested to set het group, but the allocation is not a het job"
raise ValueError(msg)

het_size = int(het_size_env)
if self.mpmd:
msg = "Slurm does not support MPMD workloads in heterogeneous jobs\n"
raise ValueError(msg)
msg = (
"Support for heterogeneous groups is an experimental feature, "
"please report any unexpected behavior to SmartSim developers "
"by opening an issue on https://github.com/CrayLabs/SmartSim/issues"
)
if any(group >= het_size for group in het_group):
msg = (
f"Het group {max(het_group)} requested, "
f"but max het group in allocation is {het_size-1}"
)
raise ValueError(msg)
logger.warning(msg)
self.run_args["het-group"] = ",".join(str(group) for group in het_group)

def format_run_args(self) -> t.List[str]:
"""Return a list of slurm formatted run arguments

Expand Down
82 changes: 82 additions & 0 deletions tests/on_wlm/test_het_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# BSD 2-Clause License
#
# Copyright (c) 2021-2023, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import pytest

from smartsim import Experiment
from smartsim.settings import SrunSettings


# retrieved from pytest fixtures
if pytest.test_launcher != "slurm":
pytestmark = pytest.mark.skip(reason="Test is only for Slurm WLM systems")

def test_mpmd_errors(monkeypatch):
monkeypatch.setenv("SLURM_HET_SIZE", "1")
exp_name = "test-het-job-errors"
exp = Experiment(exp_name, launcher="slurm")
rs: SrunSettings = exp.create_run_settings("sleep", "1", run_command="srun")
rs2: SrunSettings = exp.create_run_settings("sleep", "1", run_command="srun")
with pytest.raises(ValueError):
rs.make_mpmd(rs2)

monkeypatch.delenv("SLURM_HET_SIZE")
rs.make_mpmd(rs2)
with pytest.raises(ValueError):
rs.set_het_group(1)


def test_set_het_groups(monkeypatch):
"""Test ability to set one or more het groups to run setting
"""
monkeypatch.setenv("SLURM_HET_SIZE", "4")
exp_name = "test-set-het-group"
exp = Experiment(exp_name, launcher="slurm")
rs: SrunSettings = exp.create_run_settings("sleep", "1", run_command="srun")
rs.set_het_group([1])
assert rs.run_args["het-group"] == "1"
rs.set_het_group([3,2])
assert rs.run_args["het-group"] == "3,2"
with pytest.raises(ValueError):
rs.set_het_group([4])


def test_orch_single_cmd(monkeypatch, wlmutils):
"""Test that single cmd is rejected in a heterogeneous job"""
monkeypatch.setenv("SLURM_HET_SIZE", "1")
exp_name = "test-orch-single-cmd"
exp = Experiment(exp_name, launcher="slurm")
orc = exp.create_database(
wlmutils.get_test_port(),
db_nodes=3,
batch=False,
interface=wlmutils.get_test_interface(),
single_cmd=True,
hosts=wlmutils.get_test_hostlist(),
)

for node in orc:
assert node.is_mpmd == False
Loading