From c06088da86c97fc49e622e3fb06edb61b4615575 Mon Sep 17 00:00:00 2001 From: Lin Guo Date: Thu, 19 Dec 2024 05:21:22 +0000 Subject: [PATCH] Move status checking logic to python (for analyze only) --- lib/ramble/ramble/application.py | 3 +- .../slurm/workflow_manager.py | 60 ++++++++++++++----- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/lib/ramble/ramble/application.py b/lib/ramble/ramble/application.py index dcb33b8dd..4dc7345da 100644 --- a/lib/ramble/ramble/application.py +++ b/lib/ramble/ramble/application.py @@ -65,7 +65,8 @@ from enum import Enum experiment_status = Enum( - "experiment_status", ["UNKNOWN", "SETUP", "RUNNING", "COMPLETE", "SUCCESS", "FAILED"] + "experiment_status", + ["UNKNOWN", "SETUP", "RUNNING", "COMPLETE", "SUCCESS", "FAILED", "CANCELLED"], ) _NULL_CONTEXT = "null" diff --git a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py index ef3fb90d9..afef4059c 100644 --- a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py +++ b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py @@ -12,7 +12,18 @@ from ramble.expander import ExpanderError from ramble.application import experiment_status -from spack.util.executable import Executable +from spack.util.executable import ProcessError + +# Mapping from squeue/sacct status to Ramble status +_STATUS_MAP = { + "PD": "SETUP", + "R": "RUNNING", + "CF": "SETUP", + "CG": "COMPLETE", + "COMPLETED": "COMPLETE", + "CANCELLED": "CANCELLED", + "CANCELLED+": "CANCELLED", +} _ensure_job_id_snippet = r""" job_id=$(< {experiment_run_dir}/.slurm_job) @@ -125,9 +136,11 @@ def get_status(self, workspace): if not os.path.isfile(job_id_file): logger.warn("job_id file is missing") return status + with open(job_id_file) as f: + job_id = f.read().strip() self.runner.set_dry_run(workspace.dry_run) - self.runner.set_run_dir(run_dir) - wm_status = self.runner.get_status() + wm_status_raw = self.runner.get_status(job_id) + wm_status = _STATUS_MAP.get(wm_status_raw) if wm_status is not None and hasattr(experiment_status, wm_status): status = getattr(experiment_status, wm_status) return status @@ -138,20 +151,25 @@ class SlurmRunner: def __init__(self, dry_run=False): self.dry_run = dry_run + self.squeue_runner = None + self.sacct_runner = None self.run_dir = None + def _ensure_runner(self, runner_name: str): + attr = f"{runner_name}_runner" + if getattr(self, attr) is None: + setattr( + self, + attr, + CommandRunner(name=runner_name, command=runner_name), + ) + def set_dry_run(self, dry_run=False): """ Set the dry_run state of this runner """ self.dry_run = dry_run - def set_run_dir(self, run_dir): - """ - Set the experiment_run_dir of this runner - """ - self.run_dir = run_dir - def generate_query_command(self, job_id): return rf""" status=$(squeue -h -o "%t" -j {job_id} 2>/dev/null) @@ -166,6 +184,7 @@ def generate_query_command(self, job_id): status_map["CF"]="SETUP" status_map["CG"]="COMPLETE" status_map["COMPLETED"]="COMPLETE" + status_map["CANCELLED+"]="CANCELLED" if [ -v status_map["$status"] ]; then status=${{status_map["$status"]}} fi @@ -179,9 +198,22 @@ def generate_cancel_command(self, job_id): def generate_hostfile_command(self): return "scontrol show hostnames" - def get_status(self): - if self.dry_run or self.run_dir is None: + def get_status(self, job_id): + if self.dry_run: return None - query_cmd = Executable(os.path.join(self.run_dir, "batch_query")) - out = query_cmd(output=str) - return out.split(":")[-1].strip() + self._ensure_runner("squeue") + squeue_args = ["-h", "-o", "%t", "-j", job_id] + try: + status_out = self.squeue_runner.command( + *squeue_args, output=str, error=os.devnull + ) + except ProcessError as e: + status_out = "" + logger.debug( + f"squeue returns error {e}. This is normal if the job has already been completed." + ) + if not status_out: + self._ensure_runner("sacct") + sacct_args = ["-o", "state", "-X", "-n", "-j", job_id] + status_out = self.sacct_runner.command(*sacct_args, output=str) + return status_out.strip()