Skip to content

Commit

Permalink
decouple authenticator and socket creation
Browse files Browse the repository at this point in the history
update dragon lib
  • Loading branch information
Chris McBride authored and ankona committed Apr 15, 2024
1 parent 547e20a commit 2f52bb6
Show file tree
Hide file tree
Showing 13 changed files with 306 additions and 162 deletions.
27 changes: 21 additions & 6 deletions smartsim/_core/control/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def _map_standard_metadata(
entity_dict: t.Dict[str, t.Any],
entity: "JobEntity",
exp_dir: str,
raw_experiment: t.Dict[str, t.Any],
) -> None:
"""Map universal properties from a runtime manifest onto a `JobEntity`
Expand All @@ -147,13 +148,21 @@ def _map_standard_metadata(
:param entity: The entity instance to modify
:type entity: JobEntity
:param exp_dir: The path to the experiment working directory
:type exp_dir: str"""
:type exp_dir: str
:param raw_experiment: The raw experiment dictionary deserialized from
manifest JSON
:type raw_experiment: Dict[str, Any]"""
metadata = entity_dict["telemetry_metadata"]
status_dir = pathlib.Path(metadata.get("status_dir"))
is_dragon = raw_experiment["launcher"].lower() == "dragon"

# all entities contain shared properties that identify the task
entity.type = entity_type
entity.name = entity_dict["name"]
entity.name = (
entity_dict["name"]
if not is_dragon
else entity_dict["telemetry_metadata"]["step_id"]
)
entity.step_id = str(metadata.get("step_id") or "")
entity.task_id = str(metadata.get("task_id") or "")
entity.timestamp = int(entity_dict.get("timestamp", "0"))
Expand All @@ -162,19 +171,25 @@ def _map_standard_metadata(

@classmethod
def from_manifest(
cls, entity_type: str, entity_dict: t.Dict[str, t.Any], exp_dir: str
cls,
entity_type: str,
entity_dict: t.Dict[str, t.Any],
exp_dir: str,
raw_experiment: t.Dict[str, t.Any],
) -> "JobEntity":
"""Instantiate a `JobEntity` from the dictionary deserialized from manifest JSON
:param entity_type: The type of the associated `SmartSimEntity`
:type entity_type: str
:param entity_dict: The raw dictionary deserialized from manifest JSON
:type entity_dict: Dict[str, Any]
:param raw_experiment: raw experiment deserialized from manifest JSON
:type raw_experiment: Dict[str, Any]
:param exp_dir: The path to the experiment working directory
:type exp_dir: str"""
entity = JobEntity()

cls._map_standard_metadata(entity_type, entity_dict, entity, exp_dir)
cls._map_standard_metadata(
entity_type, entity_dict, entity, exp_dir, raw_experiment
)
cls._map_db_metadata(entity_dict, entity)

return entity
Expand Down
42 changes: 23 additions & 19 deletions smartsim/_core/entrypoints/dragon.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,14 @@ def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None:
if not signo:
logger.info("Received signal with no signo")
else:
logger.info(f"Received {signo}")
logger.info(f"Received signal {signo}")
cleanup()


"""
Dragon server entrypoint script
"""

DBPID: t.Optional[int] = None


def print_summary(network_interface: str, ip_address: str) -> None:
zmq_config = {"interface": network_interface, "address": ip_address}
Expand All @@ -86,10 +84,9 @@ def print_summary(network_interface: str, ip_address: str) -> None:


def run(
zmq_context: "zmq.Context[t.Any]",
dragon_head_address: str,
dragon_pid: int,
zmq_context: zmq.Context[t.Any],
zmq_authenticator: zmq.auth.thread.ThreadAuthenticator,
) -> None:
logger.debug(f"Opening socket {dragon_head_address}")

Expand All @@ -98,9 +95,7 @@ def run(
zmq_context.setsockopt(zmq.REQ_CORRELATE, 1)
zmq_context.setsockopt(zmq.REQ_RELAXED, 1)

dragon_head_socket, zmq_authenticator = dragonSockets.get_secure_socket(
context, zmq.REP, True, zmq_authenticator
)
dragon_head_socket = dragonSockets.get_secure_socket(zmq_context, zmq.REP, True)
dragon_head_socket.bind(dragon_head_address)
dragon_backend = DragonBackend(pid=dragon_pid)

Expand Down Expand Up @@ -133,7 +128,7 @@ def run(
break


def main(args: argparse.Namespace, zmq_context: zmq.Context[t.Any]) -> int:
def main(args: argparse.Namespace) -> int:
if_config = get_best_interface_and_address()
interface = if_config.interface
address = if_config.address
Expand All @@ -142,39 +137,50 @@ def main(args: argparse.Namespace, zmq_context: zmq.Context[t.Any]) -> int:
dragon_head_address = f"tcp://{address}"

if args.launching_address:
zmq_context = zmq.Context()

if str(args.launching_address).split(":", maxsplit=1)[0] == dragon_head_address:
address = "localhost"
dragon_head_address = "tcp://localhost:5555"
else:
dragon_head_address += ":5555"

launcher_socket, authenticator = dragonSockets.get_secure_socket(
context, zmq.REQ, False
)
zmq_authenticator = dragonSockets.get_authenticator(zmq_context)

logger.debug("Getting launcher socket")
launcher_socket = dragonSockets.get_secure_socket(zmq_context, zmq.REQ, False)

logger.debug(f"Connecting launcher socket to: {args.launching_address}")
launcher_socket.connect(args.launching_address)
client = dragonSockets.as_client(launcher_socket)

logger.debug(
f"Sending bootstrap request to launcher_socket with {dragon_head_address}"
)
client.send(DragonBootstrapRequest(address=dragon_head_address))
response = client.recv()

logger.debug(f"Received bootstrap response: {response}")
if not isinstance(response, DragonBootstrapResponse):
raise ValueError(
"Could not receive connection confirmation from launcher. Aborting."
)

print_summary(interface, dragon_head_address)

try:
logger.debug("Executing event loop")
run(
zmq_context=zmq_context,
dragon_head_address=dragon_head_address,
dragon_pid=response.dragon_pid,
zmq_context=zmq_context,
zmq_authenticator=authenticator,
)
except Exception as e:
logger.error(f"Dragon server failed with {e}", exc_info=True)
return os.EX_SOFTWARE
finally:
if authenticator.is_alive():
authenticator.stop()
if zmq_authenticator is not None and zmq_authenticator.is_alive():
zmq_authenticator.stop()

logger.info("Shutting down! Bye bye!")
return 0
Expand Down Expand Up @@ -210,6 +216,4 @@ def cleanup() -> None:
for sig in SIGNALS:
signal.signal(sig, handle_signal)

context = zmq.Context()

sys.exit(main(args_, context))
sys.exit(main(args_))
1 change: 1 addition & 0 deletions smartsim/_core/entrypoints/telemetrymonitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def configure_logger(logger_: logging.Logger, log_level_: int, exp_dir: str) ->
# Must register cleanup before the main loop is running
def cleanup_telemetry_monitor(_signo: int, _frame: t.Optional[FrameType]) -> None:
"""Create an enclosure on `manifest_observer` to avoid global variables"""
logger.info("Shutdown signal received by telemetry monitor entrypoint")
telemetry_monitor.cleanup()

register_signal_handlers(cleanup_telemetry_monitor)
Expand Down
11 changes: 4 additions & 7 deletions smartsim/_core/launcher/dragon/dragonBackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import collections
import functools
import typing as t
Expand All @@ -33,12 +32,10 @@
# pylint: disable=import-error
# isort: off
from dragon.infrastructure.policy import Policy
from dragon.native.process import Process, TemplateProcess
from dragon.native.process import Process, ProcessTemplate
from dragon.native.process_group import (
ProcessGroup,
DragonProcessGroupError,
Error,
Running,
)
from dragon.native.machine import System, Node

Expand All @@ -61,8 +58,8 @@
from smartsim._core.utils.helpers import create_short_id_str
from smartsim.status import TERMINAL_STATUSES, SmartSimStatus

DRG_ERROR_STATUS = str(Error())
DRG_RUNNING_STATUS = str(Running())
DRG_ERROR_STATUS = "Error"
DRG_RUNNING_STATUS = "Running"


@dataclass
Expand Down Expand Up @@ -206,7 +203,7 @@ def update(self) -> None:
local_policy = Policy(
placement=Policy.Placement.HOST_NAME, host_name=node_name
)
tmp_proc = TemplateProcess(
tmp_proc = ProcessTemplate(
target=request.exe,
args=request.exe_args,
cwd=request.path,
Expand Down
40 changes: 30 additions & 10 deletions smartsim/_core/launcher/dragon/dragonLauncher.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
_SchemaT = t.TypeVar("_SchemaT", bound=t.Union[DragonRequest, DragonResponse])

DRG_LOCK = RLock()
DRG_CTX = zmq.Context()


class DragonLauncher(WLMLauncher):
Expand All @@ -92,26 +91,26 @@ class DragonLauncher(WLMLauncher):

def __init__(self) -> None:
super().__init__()
self._context = DRG_CTX
self._context = zmq.Context()
self._timeout = CONFIG.dragon_server_timeout
self._reconnect_timeout = CONFIG.dragon_server_reconnect_timeout
self._startup_timeout = CONFIG.dragon_server_startup_timeout
self._context.setsockopt(zmq.SNDTIMEO, value=self._timeout)
self._context.setsockopt(zmq.RCVTIMEO, value=self._timeout)
self._dragon_head_socket: t.Optional[zmq.Socket[t.Any]] = None
self._dragon_head_process: t.Optional[subprocess.Popen[bytes]] = None
# Returned by dragon head, useful if shutdown is to be requested
# but process was started by another launcher
self._dragon_head_pid: t.Optional[int] = None
self._authenticator: t.Optional[zmq.auth.thread.ThreadAuthenticator] = None

self._set_timeout(self._timeout)

@property
def is_connected(self) -> bool:
return self._dragon_head_socket is not None

def _handshake(self, address: str) -> None:
self._dragon_head_socket, self._authenticator = dragonSockets.get_secure_socket(
self._context, zmq.REQ, False, self._authenticator
self._dragon_head_socket = dragonSockets.get_secure_socket(
self._context, zmq.REQ, False
)
self._dragon_head_socket.connect(address)
try:
Expand Down Expand Up @@ -149,6 +148,8 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None:
path = _resolve_dragon_path(path)
dragon_config_log = path / CONFIG.dragon_log_filename

self._authenticator = dragonSockets.get_authenticator(self._context)

if dragon_config_log.is_file():
dragon_confs = self._parse_launched_dragon_server_info_from_files(
[dragon_config_log]
Expand Down Expand Up @@ -180,14 +181,19 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None:
"smartsim._core.entrypoints.dragon",
]

# Optionally configure the dragon host for debug output
if CONFIG.log_level in ["debug", "developer"]:
cmd.insert(1, "DEBUG")
cmd.insert(1, "-l")

address = get_best_interface_and_address().address
socket_addr = ""
launcher_socket: t.Optional[zmq.Socket[t.Any]] = None
if address is not None:
self._set_timeout(self._startup_timeout)

launcher_socket, self._authenticator = dragonSockets.get_secure_socket(
self._context, zmq.REP, True, self._authenticator
launcher_socket = dragonSockets.get_secure_socket(
self._context, zmq.REP, True
)

# find first available port >= 5995
Expand All @@ -206,6 +212,8 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None:
) as dragon_err:
current_env = os.environ.copy()
current_env.update({"PYTHONUNBUFFERED": "1"})
logger.debug(f"Starting Dragon environment: {' '.join(cmd)}")

# pylint: disable-next=consider-using-with
self._dragon_head_process = subprocess.Popen(
args=cmd,
Expand Down Expand Up @@ -258,6 +266,7 @@ def log_dragon_outputs() -> None:
server_socket = self._dragon_head_socket
server_process_pid = self._dragon_head_process.pid

# avoid registering the same cleanup more than once
if server_socket is not None and self._dragon_head_process is not None:
atexit.register(
_dragon_cleanup,
Expand All @@ -271,11 +280,15 @@ def log_dragon_outputs() -> None:
raise LauncherError("Could not receive address of Dragon head process")

def cleanup(self) -> None:
logger.debug("Starting Dragon launcher cleanup")
_dragon_cleanup(
server_socket=self._dragon_head_socket,
server_process_pid=self._dragon_head_pid,
server_authenticator=self._authenticator,
)
self._dragon_head_socket = None
self._dragon_head_pid = 0
self._authenticator = None

# RunSettings types supported by this launcher
@property
Expand Down Expand Up @@ -492,27 +505,34 @@ def _dragon_cleanup(

try:
if server_socket is not None:
print("Sending shutdown request to dragon environment")
DragonLauncher.send_req_with_socket(server_socket, DragonShutdownRequest())
except zmq.error.ZMQError as e:
# Can't use the logger as I/O file may be closed
print("Could not send shutdown request to dragon server")
print(f"ZMQ error: {e}", flush=True)
finally:
print("Sending shutdown request is complete")
time.sleep(1)

try:
if server_process_pid and psutil.pid_exists(server_process_pid):
print("Sending SIGINT to dragon server")
os.kill(server_process_pid, signal.SIGINT)
print("Sent SIGINT to dragon server")
except ProcessLookupError:
# Can't use the logger as I/O file may be closed
print("Dragon server is not running.", flush=True)
finally:
print("Dragon server process shutdown is complete")

try:
if server_authenticator is not None:
if server_authenticator is not None and server_authenticator.is_alive():
print("Shutting down ZMQ authenticator")
server_authenticator.stop()
except Exception:
print("Authenticator shutdown error")
finally:
print("Authenticator shutdown is complete")


def _resolve_dragon_path(fallback: t.Union[str, "os.PathLike[str]"]) -> Path:
Expand Down
Loading

0 comments on commit 2f52bb6

Please sign in to comment.