diff --git a/conftest.py b/conftest.py index 5427fddc6..3e721da6f 100644 --- a/conftest.py +++ b/conftest.py @@ -35,6 +35,7 @@ from smartsim import Experiment from smartsim.entity import Model from smartsim.database import Orchestrator +from smartsim.log import get_logger from smartsim.settings import ( SrunSettings, AprunSettings, @@ -50,10 +51,12 @@ from subprocess import run import sys import tempfile +import time import typing as t import uuid import warnings +logger = get_logger(__name__) # pylint: disable=redefined-outer-name,invalid-name,global-statement @@ -68,6 +71,7 @@ test_port = CONFIG.test_port test_account = CONFIG.test_account or "" test_batch_resources: t.Dict[t.Any,t.Any] = CONFIG.test_batch_resources +test_output_dirs = 0 # Fill this at runtime if needed test_hostlist = None @@ -119,6 +123,9 @@ def pytest_sessionstart( if os.path.isdir(test_output_root): shutil.rmtree(test_output_root) os.makedirs(test_output_root) + while not os.path.isdir(test_output_root): + time.sleep(0.1) + print_test_configuration() @@ -130,10 +137,20 @@ def pytest_sessionfinish( returning the exit status to the system. """ if exitstatus == 0: - shutil.rmtree(test_output_root) - else: - # kill all spawned processes in case of error - kill_all_test_spawned_processes() + cleanup_attempts = 5 + while cleanup_attempts > 0: + try: + shutil.rmtree(test_output_root) + except OSError as e: + cleanup_attempts -= 1 + time.sleep(1) + if not cleanup_attempts: + raise + else: + break + + # kill all spawned processes + kill_all_test_spawned_processes() def kill_all_test_spawned_processes() -> None: @@ -455,6 +472,13 @@ def environment_cleanup(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("SSKEYOUT", raising=False) +@pytest.fixture(scope="function", autouse=True) +def check_output_dir() -> None: + global test_output_dirs + assert os.path.isdir(test_output_root) + assert len(os.listdir(test_output_root)) >= test_output_dirs + test_output_dirs = len(os.listdir(test_output_root)) + @pytest.fixture def dbutils() -> t.Type[DBUtils]: return DBUtils diff --git a/smartsim/_core/config/config.py b/smartsim/_core/config/config.py index f496fe052..532e3c5b6 100644 --- a/smartsim/_core/config/config.py +++ b/smartsim/_core/config/config.py @@ -252,6 +252,14 @@ def telemetry_cooldown(self) -> int: def telemetry_subdir(self) -> str: return ".smartsim/telemetry" + @property + def dragon_default_subdir(self) -> str: + return ".smartsim/dragon" + + @property + def dragon_log_filename(self) -> str: + return "dragon_config.log" + @lru_cache(maxsize=128, typed=False) def get_config() -> Config: diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py index e1a3126ea..c7ab29904 100644 --- a/smartsim/_core/control/controller.py +++ b/smartsim/_core/control/controller.py @@ -131,7 +131,7 @@ def start( logger.warning(msg) self._launcher.connect_to_dragon(dragon_server_paths[0]) else: - dragon_path = osp.join(exp_path, ".smartsim", "dragon") + dragon_path = osp.join(exp_path, CONFIG.dragon_default_subdir) self._launcher.connect_to_dragon(dragon_path) if not self._launcher.is_connected: raise LauncherError("Could not connect to Dragon server") diff --git a/smartsim/_core/entrypoints/dragon.py b/smartsim/_core/entrypoints/dragon.py index a137da32b..f2314564b 100644 --- a/smartsim/_core/entrypoints/dragon.py +++ b/smartsim/_core/entrypoints/dragon.py @@ -28,6 +28,7 @@ import json import os import signal +import socket import textwrap import typing as t from types import FrameType @@ -36,7 +37,11 @@ import smartsim._core.utils.helpers as _helpers from smartsim._core.launcher.dragon.dragonBackend import DragonBackend -from smartsim._core.schemas import DragonBootstrapRequest, DragonBootstrapResponse +from smartsim._core.schemas import ( + DragonBootstrapRequest, + DragonBootstrapResponse, + DragonShutdownResponse, +) from smartsim._core.schemas.dragonRequests import request_serializer from smartsim._core.schemas.dragonResponses import response_serializer from smartsim._core.utils.network import get_best_interface_and_address @@ -44,6 +49,8 @@ # kill is not catchable SIGNALS = [signal.SIGINT, signal.SIGQUIT, signal.SIGTERM, signal.SIGABRT] +SHUTDOWN_INITIATED = False + def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None: if not signo: @@ -56,7 +63,7 @@ def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None: context = zmq.Context() """ -Redis/KeyDB entrypoint script +Dragon server entrypoint script """ DBPID: t.Optional[int] = None @@ -71,6 +78,7 @@ def print_summary(network_interface: str, ip_address: str) -> None: -------- Dragon Configuration -------- IPADDRESS: {ip_address} NETWORK: {network_interface} + HOSTNAME: {socket.gethostname()} DRAGON_SERVER_CONFIG: {json.dumps(zmq_config)} -------------------------------------- @@ -81,13 +89,13 @@ def print_summary(network_interface: str, ip_address: str) -> None: def run(dragon_head_address: str) -> None: + global SHUTDOWN_INITIATED # pylint: disable=global-statement print(f"Opening socket {dragon_head_address}") dragon_head_socket = context.socket(zmq.REP) dragon_head_socket.bind(dragon_head_address) - dragon_backend = DragonBackend() - while True: + while not SHUTDOWN_INITIATED: print(f"Listening to {dragon_head_address}") req = dragon_head_socket.recv_json() print(f"Received request: {req}") @@ -95,6 +103,8 @@ def run(dragon_head_address: str) -> None: resp = dragon_backend.process_request(drg_req) print(f"Sending response {resp}", flush=True) dragon_head_socket.send_json(response_serializer.serialize_to_json(resp)) + if isinstance(resp, DragonShutdownResponse): + SHUTDOWN_INITIATED = True def main(args: argparse.Namespace) -> int: @@ -132,11 +142,14 @@ def main(args: argparse.Namespace) -> int: run(dragon_head_address=dragon_head_address) + print("Shutting down! Bye bye!") return 0 def cleanup() -> None: + global SHUTDOWN_INITIATED # pylint: disable=global-statement print("Cleaning up", flush=True) + SHUTDOWN_INITIATED = True if __name__ == "__main__": @@ -156,8 +169,6 @@ def cleanup() -> None: ) args_ = parser.parse_args() - print(args_) - # make sure to register the cleanup before the start # the process so our signaller will be able to stop # the database process. diff --git a/smartsim/_core/generation/generator.py b/smartsim/_core/generation/generator.py index 79cea06b7..5633fcc36 100644 --- a/smartsim/_core/generation/generator.py +++ b/smartsim/_core/generation/generator.py @@ -148,7 +148,7 @@ def _gen_exp_dir(self) -> None: ) if not path.isdir(self.gen_path): # keep exists ok for race conditions on NFS - pathlib.Path(self.gen_path).mkdir(exist_ok=True) + pathlib.Path(self.gen_path).mkdir(exist_ok=True, parents=True) else: logger.log( level=self.log_level, msg="Working in previously created experiment" @@ -177,7 +177,7 @@ def _gen_orc_dir(self, orchestrator_list: t.List[Orchestrator]) -> None: # Always remove orchestrator files if present. if path.isdir(orc_path): shutil.rmtree(orc_path, ignore_errors=True) - pathlib.Path(orc_path).mkdir(exist_ok=self.overwrite) + pathlib.Path(orc_path).mkdir(exist_ok=self.overwrite, parents=True) def _gen_entity_list_dir(self, entity_lists: t.List[Ensemble]) -> None: """Generate directories for Ensemble instances diff --git a/smartsim/_core/launcher/dragon/dragonBackend.py b/smartsim/_core/launcher/dragon/dragonBackend.py index 1b4188a81..84290b0b3 100644 --- a/smartsim/_core/launcher/dragon/dragonBackend.py +++ b/smartsim/_core/launcher/dragon/dragonBackend.py @@ -39,6 +39,8 @@ DragonResponse, DragonRunRequest, DragonRunResponse, + DragonShutdownRequest, + DragonShutdownResponse, DragonStopRequest, DragonStopResponse, DragonUpdateStatusRequest, @@ -79,6 +81,7 @@ def process_request(self, request: DragonRequest) -> DragonResponse: @process_request.register def _(self, request: DragonRunRequest) -> DragonRunResponse: + proc = TemplateProcess( target=request.exe, args=request.exe_args, @@ -88,7 +91,7 @@ def _(self, request: DragonRunRequest) -> DragonRunResponse: # stderr=Popen.PIPE, ) - grp = ProcessGroup(restart=False, pmi_enabled=True) + grp = ProcessGroup(restart=False, pmi_enabled=request.pmi_enabled) grp.add_process(nproc=request.tasks, template=proc) step_id = self._get_new_id() grp.init() @@ -108,10 +111,13 @@ def _(self, request: DragonUpdateStatusRequest) -> DragonUpdateStatusResponse: updated_statuses[step_id] = (STATUS_RUNNING, return_codes) else: if all(proc_id is not None for proc_id in proc_group_tuple[1]): - return_codes = [ - Process(None, ident=puid).returncode - for puid in proc_group_tuple[1] - ] + try: + return_codes = [ + Process(None, ident=puid).returncode + for puid in proc_group_tuple[1] + ] + except (ValueError, TypeError): + return_codes = [-1 for _ in proc_group_tuple[1]] else: return_codes = [0] status = ( @@ -141,3 +147,9 @@ def _(self, request: DragonStopRequest) -> DragonStopResponse: # pylint: disable-next=no-self-use,unused-argument def _(self, request: DragonHandshakeRequest) -> DragonHandshakeResponse: return DragonHandshakeResponse() + + @process_request.register + # Deliberately suppressing errors so that overloads have the same signature + # pylint: disable-next=no-self-use,unused-argument + def _(self, request: DragonShutdownRequest) -> DragonShutdownResponse: + return DragonShutdownResponse() diff --git a/smartsim/_core/launcher/dragon/dragonLauncher.py b/smartsim/_core/launcher/dragon/dragonLauncher.py index 5b0c96579..1ae7b7c66 100644 --- a/smartsim/_core/launcher/dragon/dragonLauncher.py +++ b/smartsim/_core/launcher/dragon/dragonLauncher.py @@ -26,10 +26,12 @@ from __future__ import annotations +import atexit import fileinput import itertools import json import os +import signal import subprocess import sys import typing as t @@ -57,6 +59,7 @@ DragonResponse, DragonRunRequest, DragonRunResponse, + DragonShutdownRequest, DragonStopRequest, DragonStopResponse, DragonUpdateStatusRequest, @@ -71,6 +74,9 @@ _SchemaT = t.TypeVar("_SchemaT", bound=t.Union[DragonRequest, DragonResponse]) +DRG_LOCK = RLock() +DRG_CTX = zmq.Context() + class DragonLauncher(WLMLauncher): """This class encapsulates the functionality needed @@ -85,27 +91,26 @@ class DragonLauncher(WLMLauncher): def __init__(self) -> None: super().__init__() - self._context = zmq.Context() + self._context = DRG_CTX self._timeout = CONFIG.dragon_server_timeout self._reconnect_timeout = CONFIG.dragon_server_reconnect_timeout self._startup_timeout = CONFIG.dragon_server_startup_timeout self._context.setsockopt(zmq.SNDTIMEO, value=self._timeout) self._context.setsockopt(zmq.RCVTIMEO, value=self._timeout) self._dragon_head_socket: t.Optional[zmq.Socket[t.Any]] = None - self._dragon_head_process: t.Optional[subprocess.Popen[bytes]] - self._comm_lock = RLock() + self._dragon_head_process: t.Optional[subprocess.Popen[bytes]] = None @property def is_connected(self) -> bool: return self._dragon_head_socket is not None - def _handsake(self, address: str) -> None: + def _handshake(self, address: str) -> None: self._dragon_head_socket = self._context.socket(zmq.REQ) self._dragon_head_socket.connect(address) try: ( _helpers.start_with(DragonHandshakeRequest()) - .then(self._send_request_as_json) + .then(self._send_request) .then(_assert_schema_type(DragonHandshakeResponse)) ) logger.debug( @@ -123,13 +128,14 @@ def _set_timeout(self, timeout: int) -> None: self._context.setsockopt(zmq.SNDTIMEO, value=timeout) self._context.setsockopt(zmq.RCVTIMEO, value=timeout) + # pylint: disable-next=too-many-statements def connect_to_dragon(self, path: str) -> None: - with self._comm_lock: + with DRG_LOCK: # TODO use manager instead if self.is_connected: return - dragon_config_log = os.path.join(path, "dragon_config.log") + dragon_config_log = os.path.join(path, CONFIG.dragon_log_filename) if Path.is_file(Path(dragon_config_log)): dragon_confs = ( @@ -146,7 +152,7 @@ def connect_to_dragon(self, path: str) -> None: logger.debug(msg) try: self._set_timeout(self._reconnect_timeout) - self._handsake(dragon_conf["address"]) + self._handshake(dragon_conf["address"]) except LauncherError as e: logger.warning(e) finally: @@ -214,7 +220,22 @@ def connect_to_dragon(self, path: str) -> None: launcher_socket.close() self._set_timeout(self._timeout) - self._handsake(dragon_head_address) + self._handshake(dragon_head_address) + + # Only the launcher which started the server is + # responsible of it, that's why we register the + # cleanup in this code branch. + # The cleanup function should not have references + # to this object to avoid Garbage Collector lockup + server_socket = self._dragon_head_socket + server_process_pid = self._dragon_head_process.pid + + if server_socket is not None and server_process_pid: + atexit.register( + _dragon_cleanup, + server_socket=server_socket, + server_process_pid=server_process_pid, + ) else: # TODO parse output file raise LauncherError("Could not receive address of Dragon head process") @@ -259,7 +280,7 @@ def run(self, step: Step) -> t.Optional[str]: response = ( _helpers.start_with(req) - .then(self._send_request_as_json) + .then(self._send_request) .then(_assert_schema_type(DragonRunResponse)) .get_result() ) @@ -287,7 +308,7 @@ def stop(self, step_name: str) -> StepInfo: step_id = str(stepmap.step_id) ( _helpers.start_with(DragonStopRequest(step_id=step_id)) - .then(self._send_request_as_json) + .then(self._send_request) .then(_assert_schema_type(DragonStopResponse)) ) @@ -312,12 +333,12 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: response = ( _helpers.start_with(DragonUpdateStatusRequest(step_ids=step_ids)) - .then(self._send_request_as_json) + .then(self._send_request) .then(_assert_schema_type(DragonUpdateStatusResponse)) .get_result() ) - # create SlurmStepInfo objects to return + # create StepInfo objects to return updates: t.List[StepInfo] = [] # Order matters as we return an ordered list of StepInfo objects for step_id in step_ids: @@ -342,23 +363,11 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: updates.append(info) return updates - def _send_request_as_json( - self, request: DragonRequest, flags: int = 0 - ) -> DragonResponse: + def _send_request(self, request: DragonRequest, flags: int = 0) -> DragonResponse: if (socket := self._dragon_head_socket) is None: raise LauncherError("Launcher is not connected to Dragon") - with self._comm_lock: - logger.debug(f"Sending request: {request}") - return ( - _helpers.start_with(request) - .then(request_serializer.serialize_to_json) - .then(lambda req: socket.send_json(req, flags)) - .then(lambda _: socket.recv_json()) - .then(str) - .then(response_serializer.deserialize_from_json) - .get_result() - ) + return self.send_req_as_json(socket, request, flags) def __str__(self) -> str: return "Dragon" @@ -394,6 +403,22 @@ def _parse_launched_dragon_server_info_from_files( return dragon_envs + @staticmethod + def send_req_as_json( + socket: zmq.Socket[t.Any], request: DragonRequest, flags: int = 0 + ) -> DragonResponse: + with DRG_LOCK: + logger.debug(f"Sending {type(request).__name__}: {request}") + return ( + _helpers.start_with(request) + .then(request_serializer.serialize_to_json) + .then(lambda req: socket.send_json(req, flags)) + .then(lambda _: socket.recv_json()) + .then(str) + .then(response_serializer.deserialize_from_json) + .get_result() + ) + def _assert_schema_type(typ: t.Type[_SchemaT], /) -> t.Callable[[object], _SchemaT]: def _inner(obj: object) -> _SchemaT: @@ -402,3 +427,15 @@ def _inner(obj: object) -> _SchemaT: return obj return _inner + + +def _dragon_cleanup(server_socket: zmq.Socket[t.Any], server_process_pid: int) -> None: + try: + with DRG_LOCK: + DragonLauncher.send_req_as_json(server_socket, DragonShutdownRequest()) + except zmq.error.ZMQError as e: + logger.error( + f"Could not send shutdown request to dragon server, ZMQ error: {e}" + ) + finally: + os.kill(server_process_pid, signal.SIGINT) diff --git a/smartsim/_core/schemas/dragonRequests.py b/smartsim/_core/schemas/dragonRequests.py index 7db76459b..7ddfd28b5 100644 --- a/smartsim/_core/schemas/dragonRequests.py +++ b/smartsim/_core/schemas/dragonRequests.py @@ -47,6 +47,7 @@ class DragonRunRequestView(DragonRequest): path: NonEmptyStr nodes: PositiveInt = 1 tasks: PositiveInt = 1 + tasks_per_node: PositiveInt = 1 output_file: t.Optional[NonEmptyStr] = None error_file: t.Optional[NonEmptyStr] = None env: t.Dict[str, t.Optional[str]] = {} @@ -79,3 +80,7 @@ class DragonHandshakeRequest(DragonRequest): ... @request_serializer.register("bootstrap") class DragonBootstrapRequest(DragonRequest): address: NonEmptyStr + + +@request_serializer.register("shutdown") +class DragonShutdownRequest(DragonRequest): ... diff --git a/smartsim/_core/schemas/dragonResponses.py b/smartsim/_core/schemas/dragonResponses.py index 16f35850f..09562c80a 100644 --- a/smartsim/_core/schemas/dragonResponses.py +++ b/smartsim/_core/schemas/dragonResponses.py @@ -63,3 +63,7 @@ class DragonHandshakeResponse(DragonResponse): ... @response_serializer.register("bootstrap") class DragonBootstrapResponse(DragonResponse): ... + + +@response_serializer.register("shutdown") +class DragonShutdownResponse(DragonResponse): ... diff --git a/smartsim/log.py b/smartsim/log.py index eb4af0611..b72860bbc 100644 --- a/smartsim/log.py +++ b/smartsim/log.py @@ -39,7 +39,8 @@ # constants DEFAULT_DATE_FORMAT: t.Final[str] = "%H:%M:%S" DEFAULT_LOG_FORMAT: t.Final[str] = ( - "%(asctime)s %(hostname)s %(name)s[%(process)d] %(levelname)s %(message)s" + "%(asctime)s %(hostname)s %(name)s[%(process)d:%(threadName)s] " + "%(levelname)s %(message)s" ) EXPERIMENT_LOG_FORMAT = DEFAULT_LOG_FORMAT.replace("s[%", "s {%(exp_path)s} [%") diff --git a/smartsim/settings/slurmSettings.py b/smartsim/settings/slurmSettings.py index 09689366e..b020d23f6 100644 --- a/smartsim/settings/slurmSettings.py +++ b/smartsim/settings/slurmSettings.py @@ -338,7 +338,7 @@ def check_env_vars(self) -> None: "environment. If the job is running in an interactive " f"allocation, the value {v} will not be set. Please " "consider removing the variable from the environment " - "and re-run the experiment." + "and re-running the experiment." ) logger.warning(msg)