Skip to content

Commit

Permalink
typehints batch 2
Browse files Browse the repository at this point in the history
  • Loading branch information
ankona committed Jun 7, 2023
1 parent be0f3c8 commit 2ce2bd7
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 67 deletions.
2 changes: 1 addition & 1 deletion smartsim/_core/launcher/slurm/slurmCommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def scontrol(args: t.List[str]) -> t.Tuple[str, str]:
return out, error


def scancel(args: t.List[str]) -> t.Tuple[str, str]:
def scancel(args: t.List[str]) -> t.Tuple[str, str, str]:
"""Calls slurm scancel with args.
returncode is also supplied in this function.
Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/launcher/stepMapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def get_task_id(self, step_id: int) -> t.Optional[int]:

def get_ids(
self, step_names: t.List[str], managed: bool = True
) -> t.Tuple[t.List[str], t.List[int]]:
) -> t.Tuple[t.List[str], t.List[t.Union[int, None]]]:
ids = []
names = []
for name in step_names:
Expand Down
31 changes: 16 additions & 15 deletions smartsim/_core/launcher/taskManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class TaskManager:
def __init__(self) -> None:
"""Initialize a task manager thread."""
self.actively_monitoring = False
self.task_history = dict()
self.task_history: t.Dict[str, t.Tuple[t.Optional[int], t.Optional[str], t.Optional[str]]] = {}
self.tasks: t.List[Task] = []
self._lock = RLock()

Expand Down Expand Up @@ -104,10 +104,10 @@ def start_task(
self,
cmd_list: t.List[str],
cwd: str,
env: t.Dict[str, str] = None,
out=PIPE,
err=PIPE,
) -> int:
env: t.Optional[t.Dict[str, str]] = None,
out: int = PIPE,
err: int = PIPE,
) -> str:
"""Start a task managed by the TaskManager
This is an "unmanaged" task, meaning it is NOT managed
Expand Down Expand Up @@ -211,7 +211,7 @@ def remove_task(self, task_id: str) -> None:
finally:
self._lock.release()

def get_task_update(self, task_id: str) -> t.Tuple[str, int, str, str]:
def get_task_update(self, task_id: str) -> t.Tuple[str, t.Optional[int], t.Optional[str], t.Optional[str]]:
"""Get the update of a task
:param task_id: task id
Expand Down Expand Up @@ -245,7 +245,7 @@ def get_task_update(self, task_id: str) -> t.Tuple[str, int, str, str]:
def add_task_history(
self,
task_id: str,
returncode: int,
returncode: t.Optional[int] = None,
out: t.Optional[str] = None,
err: t.Optional[str] = None,
) -> None:
Expand Down Expand Up @@ -283,7 +283,7 @@ def __len__(self) -> int:


class Task:
def __init__(self, process: psutil.Popen) -> None:
def __init__(self, process: t.Union[psutil.Popen, psutil.Process]) -> None:
"""Initialize a task
:param process: Popen object
Expand All @@ -292,27 +292,28 @@ def __init__(self, process: psutil.Popen) -> None:
self.process = process
self.pid = str(self.process.pid)

def check_status(self) -> int:
def check_status(self) -> t.Optional[int]:
"""Ping the job and return the returncode if finished
:return: returncode if finished otherwise None
:rtype: int | None
"""
if self.owned:
if self.owned and isinstance(self.process, psutil.Popen):
return self.process.poll()
# we can't manage Processes we don't own
# have to rely on .kill() to stop.
return self.returncode

def get_io(self) -> t.Tuple[str, str]:
def get_io(self) -> t.Tuple[t.Optional[str], t.Optional[str]]:
"""Get the IO from the subprocess
:return: output and error from the Popen
:rtype: str, str
"""
# Process class does not implement communicate
if not self.owned:
if not self.owned or not isinstance(self.process, psutil.Popen):
return None, None

output, error = self.process.communicate()
if output:
output = output.decode("utf-8")
Expand All @@ -323,7 +324,7 @@ def get_io(self) -> t.Tuple[str, str]:
def kill(self, timeout: int = 10) -> None:
"""Kill the subprocess and all children"""

def kill_callback(proc: psutil.Popen) -> None:
def kill_callback(proc: psutil.Process) -> None:
logger.debug(f"Process terminated with kill {proc.pid}")

children = self.process.children(recursive=True)
Expand All @@ -344,7 +345,7 @@ def terminate(self, timeout: int = 10) -> None:
:type timeout: int, optional
"""

def terminate_callback(proc: psutil.Popen) -> None:
def terminate_callback(proc: psutil.Process) -> None:
logger.debug(f"Cleanly terminated task {proc.pid}")

children = self.process.children(recursive=True)
Expand All @@ -370,7 +371,7 @@ def wait(self) -> None:

@property
def returncode(self) -> t.Optional[int]:
if self.owned:
if self.owned and isinstance(self.process, psutil.Popen):
return self.process.returncode
if self.is_alive:
return None
Expand Down
45 changes: 24 additions & 21 deletions smartsim/wlm/pbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@

import json
import os
import typing as t
from shutil import which

from smartsim.error.errors import LauncherError, SmartSimError

from .._core.launcher.pbs.pbsCommands import qstat


def get_hosts():
def get_hosts() -> t.List[str]:
"""Get the name of the hosts used in a PBS allocation.
:returns: Names of the host nodes
Expand All @@ -54,19 +55,19 @@ def get_hosts():
)


def get_queue():
def get_queue() -> str:
"""Get the name of queue in a PBS allocation.
:returns: The name of the queue
:rtype: str
:raises SmartSimError: ``PBS_QUEUE`` is not set
"""
if "PBS_QUEUE" in os.environ:
return os.environ.get("PBS_QUEUE")
return os.environ["PBS_QUEUE"]
raise SmartSimError("Could not parse queue from SLURM_JOB_PARTITION")


def get_tasks():
def get_tasks() -> int:
"""Get the number of processes (``ncpus``) in use by a PBS allocation.
.. note::
Expand All @@ -87,16 +88,17 @@ def get_tasks():
"PBS(qstat) at the call site"
)
)
job_id = os.environ.get("PBS_JOBID")
job_info_str, _ = qstat(["-f", "-F", "json", job_id])
job_info = json.loads(job_info_str)
return int(job_info["Jobs"][job_id]["resources_used"]["ncpus"])

if job_id := os.environ.get("PBS_JOBID"):
job_info_str, _ = qstat(["-f", "-F", "json", job_id])
job_info = json.loads(job_info_str)
return int(job_info["Jobs"][job_id]["resources_used"]["ncpus"])
raise SmartSimError(
"Could not parse number of requested tasks without an allocation"
)


def get_tasks_per_node():
def get_tasks_per_node() -> t.Dict[str, int]:
"""Get the number of processes on each chunk in a PBS allocation.
.. note::
Expand All @@ -117,16 +119,17 @@ def get_tasks_per_node():
"PBS(qstat) at the call site"
)
)
job_id = os.environ.get("PBS_JOBID")
job_info_str, _ = qstat(["-f", "-F", "json", job_id])
job_info = json.loads(job_info_str)
chunks_and_ncpus = job_info["Jobs"][job_id]["exec_vnode"] # type: str

chunk_cpu_map = {}
for cunck_and_ncpu in chunks_and_ncpus.split("+"):
chunk, ncpu = cunck_and_ncpu.strip("()").split(":")
ncpu = ncpu.lstrip("ncpus=")
chunk_cpu_map[chunk] = int(ncpu)

return chunk_cpu_map

if job_id := os.environ.get("PBS_JOBID"):
job_info_str, _ = qstat(["-f", "-F", "json", job_id])
job_info = json.loads(job_info_str)
chunks_and_ncpus = job_info["Jobs"][job_id]["exec_vnode"] # type: str

chunk_cpu_map = {}
for cunck_and_ncpu in chunks_and_ncpus.split("+"):
chunk, ncpu = cunck_and_ncpu.strip("()").split(":")
ncpu = ncpu.lstrip("ncpus=")
chunk_cpu_map[chunk] = int(ncpu)

return chunk_cpu_map
raise SmartSimError("Could not parse tasks per node without an allocation")
Loading

0 comments on commit 2ce2bd7

Please sign in to comment.