Skip to content

Commit

Permalink
Turn SchemaSerializer into SchemaRegistry (#8)
Browse files Browse the repository at this point in the history
Rename `SchemaSerializer` to `SchemaRegistry`. Move message formatting into a
dedicated `_Message` class. Rename methods. Move schema coercion to and from
`_Message` strings into a `SocketSchemaTranslator` class. Better error
handling.

  [ committed by @MattToast ]
  [ reviewed by @al-rigazzi @ankona ]
  • Loading branch information
MattToast authored Mar 25, 2024
1 parent 372e152 commit 97292df
Show file tree
Hide file tree
Showing 10 changed files with 416 additions and 278 deletions.
26 changes: 9 additions & 17 deletions smartsim/_core/entrypoints/dragon.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,13 @@

import zmq

import smartsim._core.utils.helpers as _helpers
from smartsim._core.launcher.dragon import dragonSockets
from smartsim._core.launcher.dragon.dragonBackend import DragonBackend
from smartsim._core.schemas import (
DragonBootstrapRequest,
DragonBootstrapResponse,
DragonShutdownResponse,
)
from smartsim._core.schemas.dragonRequests import request_serializer
from smartsim._core.schemas.dragonResponses import response_serializer
from smartsim._core.utils.network import get_best_interface_and_address

# kill is not catchable
Expand Down Expand Up @@ -95,14 +93,15 @@ def run(dragon_head_address: str) -> None:
dragon_head_socket.bind(dragon_head_address)
dragon_backend = DragonBackend()

server = dragonSockets.as_server(dragon_head_socket)

while not SHUTDOWN_INITIATED:
print(f"Listening to {dragon_head_address}")
req = dragon_head_socket.recv_json()
req = server.recv()
print(f"Received request: {req}")
drg_req = request_serializer.deserialize_from_json(str(req))
resp = dragon_backend.process_request(drg_req)
resp = dragon_backend.process_request(req)
print(f"Sending response {resp}", flush=True)
dragon_head_socket.send_json(response_serializer.serialize_to_json(resp))
server.send(resp)
if isinstance(resp, DragonShutdownResponse):
SHUTDOWN_INITIATED = True

Expand All @@ -122,17 +121,10 @@ def main(args: argparse.Namespace) -> int:

launcher_socket = context.socket(zmq.REQ)
launcher_socket.connect(args.launching_address)
client = dragonSockets.as_client(launcher_socket)

response = (
_helpers.start_with(DragonBootstrapRequest(address=dragon_head_address))
.then(request_serializer.serialize_to_json)
.then(launcher_socket.send_json)
.then(lambda _: launcher_socket.recv_json())
.then(str)
.then(response_serializer.deserialize_from_json)
.get_result()
)

client.send(DragonBootstrapRequest(address=dragon_head_address))
response = client.recv()
if not isinstance(response, DragonBootstrapResponse):
raise ValueError(
"Could not receive connection confirmation from launcher. Aborting."
Expand Down
79 changes: 25 additions & 54 deletions smartsim/_core/launcher/dragon/dragonLauncher.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@

import zmq

import smartsim._core.utils.helpers as _helpers
from smartsim._core.schemas.dragonRequests import request_serializer
from smartsim._core.schemas.dragonResponses import response_serializer
from smartsim._core.launcher.dragon import dragonSockets
from smartsim._core.schemas.types import NonEmptyStr

from ....error import LauncherError
Expand Down Expand Up @@ -108,10 +106,8 @@ def _handshake(self, address: str) -> None:
self._dragon_head_socket = self._context.socket(zmq.REQ)
self._dragon_head_socket.connect(address)
try:
(
_helpers.start_with(DragonHandshakeRequest())
.then(self._send_request)
.then(_assert_schema_type(DragonHandshakeResponse))
_assert_schema_type(
self._send_request(DragonHandshakeRequest()), DragonHandshakeResponse
)
logger.debug(
f"Successful handshake with Dragon server at address {address}"
Expand Down Expand Up @@ -205,23 +201,13 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None:
)

if address is not None:
server = dragonSockets.as_server(launcher_socket)
logger.debug(f"Listening to {socket_addr}")
request = (
_helpers.start_with(launcher_socket.recv_json())
.then(str)
.then(request_serializer.deserialize_from_json)
.then(_assert_schema_type(DragonBootstrapRequest))
.get_result()
)
request = _assert_schema_type(server.recv(), DragonBootstrapRequest)

dragon_head_address = request.address
logger.debug(f"Connecting launcher to {dragon_head_address}")

(
_helpers.start_with(DragonBootstrapResponse())
.then(response_serializer.serialize_to_json)
.then(launcher_socket.send_json)
)
server.send(DragonBootstrapResponse())

launcher_socket.close()
self._set_timeout(self._timeout)
Expand Down Expand Up @@ -277,8 +263,8 @@ def run(self, step: Step) -> t.Optional[str]:
run_args = step.run_settings.run_args
env = step.run_settings.env_vars
nodes = int(run_args.get("nodes", None) or 1)
response = (
_helpers.start_with(
response = _assert_schema_type(
self._send_request(
DragonRunRequest(
exe=cmd[0],
exe_args=cmd[1:],
Expand All @@ -290,10 +276,8 @@ def run(self, step: Step) -> t.Optional[str]:
output_file=out,
error_file=err,
)
)
.then(self._send_request)
.then(_assert_schema_type(DragonRunResponse))
.get_result()
),
DragonRunResponse,
)
step_id = task_id = str(response.step_id)
elif isinstance(step, LocalStep):
Expand Down Expand Up @@ -328,10 +312,8 @@ def stop(self, step_name: str) -> StepInfo:

stepmap = self.step_mapping[step_name]
step_id = str(stepmap.step_id)
(
_helpers.start_with(DragonStopRequest(step_id=step_id))
.then(self._send_request)
.then(_assert_schema_type(DragonStopResponse))
_assert_schema_type(
self._send_request(DragonStopRequest(step_id=step_id)), DragonStopResponse
)

_, step_info = self.get_step_update([step_name])[0]
Expand All @@ -353,11 +335,9 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]:
if not self.is_connected:
raise LauncherError("Launcher is not connected to Dragon.")

response = (
_helpers.start_with(DragonUpdateStatusRequest(step_ids=step_ids))
.then(self._send_request)
.then(_assert_schema_type(DragonUpdateStatusResponse))
.get_result()
response = _assert_schema_type(
self._send_request(DragonUpdateStatusRequest(step_ids=step_ids)),
DragonUpdateStatusResponse,
)

# create StepInfo objects to return
Expand Down Expand Up @@ -390,7 +370,7 @@ def _send_request(self, request: DragonRequest, flags: int = 0) -> DragonRespons
if (socket := self._dragon_head_socket) is None:
raise LauncherError("Launcher is not connected to Dragon")

return self.send_req_as_json(socket, request, flags)
return self.send_req_with_socket(socket, request, flags)

def __str__(self) -> str:
return "Dragon"
Expand Down Expand Up @@ -429,35 +409,26 @@ def _parse_launched_dragon_server_info_from_files(
return dragon_envs

@staticmethod
def send_req_as_json(
def send_req_with_socket(
socket: zmq.Socket[t.Any], request: DragonRequest, flags: int = 0
) -> DragonResponse:
client = dragonSockets.as_client(socket)
with DRG_LOCK:
logger.debug(f"Sending {type(request).__name__}: {request}")
return (
_helpers.start_with(request)
.then(request_serializer.serialize_to_json)
.then(lambda req: socket.send_json(req, flags))
.then(lambda _: socket.recv_json())
.then(str)
.then(response_serializer.deserialize_from_json)
.get_result()
)

client.send(request, flags)
return client.recv()

def _assert_schema_type(typ: t.Type[_SchemaT], /) -> t.Callable[[object], _SchemaT]:
def _inner(obj: object) -> _SchemaT:
if not isinstance(obj, typ):
raise TypeError("Expected schema of type `{typ}`, but got {type(obj)}")
return obj

return _inner
def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT:
if not isinstance(obj, typ):
raise TypeError("Expected schema of type `{typ}`, but got {type(obj)}")
return obj


def _dragon_cleanup(server_socket: zmq.Socket[t.Any], server_process_pid: int) -> None:
try:
with DRG_LOCK:
DragonLauncher.send_req_as_json(server_socket, DragonShutdownRequest())
DragonLauncher.send_req_with_socket(server_socket, DragonShutdownRequest())
except zmq.error.ZMQError as e:
logger.error(
f"Could not send shutdown request to dragon server, ZMQ error: {e}"
Expand Down
56 changes: 56 additions & 0 deletions smartsim/_core/launcher/dragon/dragonSockets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# BSD 2-Clause License
#
# Copyright (c) 2021-2023, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import typing as t

from smartsim._core.schemas import dragonRequests as _dragonRequests
from smartsim._core.schemas import dragonResponses as _dragonResponses
from smartsim._core.schemas import utils as _utils

if t.TYPE_CHECKING:
from zmq.sugar.socket import Socket


def as_server(
    socket: "Socket[t.Any]",
) -> _utils.SocketSchemaTranslator[
    _dragonResponses.DragonResponse,
    _dragonRequests.DragonRequest,
]:
    """Wrap ``socket`` in a translator for the server end of a Dragon exchange.

    The translator is parameterized as ``[DragonResponse, DragonRequest]``
    and is handed the response registry first and the request registry
    second, i.e. the mirror image of ``as_client``.

    :param socket: ZMQ socket to wrap
    :returns: A ``SocketSchemaTranslator`` bound to ``socket``
    """
    send_registry = _dragonResponses.response_registry
    recv_registry = _dragonRequests.request_registry
    return _utils.SocketSchemaTranslator(socket, send_registry, recv_registry)


def as_client(
    socket: "Socket[t.Any]",
) -> _utils.SocketSchemaTranslator[
    _dragonRequests.DragonRequest,
    _dragonResponses.DragonResponse,
]:
    """Wrap ``socket`` in a translator for the client end of a Dragon exchange.

    The translator is parameterized as ``[DragonRequest, DragonResponse]``
    and is handed the request registry first and the response registry
    second, i.e. the mirror image of ``as_server``.

    :param socket: ZMQ socket to wrap
    :returns: A ``SocketSchemaTranslator`` bound to ``socket``
    """
    send_registry = _dragonRequests.request_registry
    recv_registry = _dragonResponses.response_registry
    return _utils.SocketSchemaTranslator(socket, send_registry, recv_registry)
17 changes: 8 additions & 9 deletions smartsim/_core/schemas/dragonRequests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@
# Black and Pylint disagree about where to put the `...`
# pylint: disable=multiple-statements


class DragonRequest(BaseModel): ...
request_registry = _utils.SchemaRegistry["DragonRequest"]()


request_serializer = _utils.SchemaSerializer[str, DragonRequest]("request_type")
class DragonRequest(BaseModel): ...


class DragonRunRequestView(DragonRequest):
Expand All @@ -55,32 +54,32 @@ class DragonRunRequestView(DragonRequest):
pmi_enabled: bool = True


@request_serializer.register("run")
@request_registry.register("run")
class DragonRunRequest(DragonRunRequestView):
current_env: t.Dict[str, t.Optional[str]] = {}

def __str__(self) -> str:
return str(DragonRunRequestView.parse_obj(self.dict(exclude={"current_env"})))


@request_serializer.register("update_status")
@request_registry.register("update_status")
class DragonUpdateStatusRequest(DragonRequest):
step_ids: t.List[NonEmptyStr]


@request_serializer.register("stop")
@request_registry.register("stop")
class DragonStopRequest(DragonRequest):
step_id: NonEmptyStr


@request_serializer.register("handshake")
@request_registry.register("handshake")
class DragonHandshakeRequest(DragonRequest): ...


@request_serializer.register("bootstrap")
@request_registry.register("bootstrap")
class DragonBootstrapRequest(DragonRequest):
address: NonEmptyStr


@request_serializer.register("shutdown")
@request_registry.register("shutdown")
class DragonShutdownRequest(DragonRequest): ...
17 changes: 8 additions & 9 deletions smartsim/_core/schemas/dragonResponses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,35 @@
# Black and Pylint disagree about where to put the `...`
# pylint: disable=multiple-statements

response_registry = _utils.SchemaRegistry["DragonResponse"]()


class DragonResponse(BaseModel):
error_message: t.Optional[str] = None


response_serializer = _utils.SchemaSerializer[str, DragonResponse]("response_type")


@response_serializer.register("run")
@response_registry.register("run")
class DragonRunResponse(DragonResponse):
step_id: NonEmptyStr


@response_serializer.register("status_update")
@response_registry.register("status_update")
class DragonUpdateStatusResponse(DragonResponse):
# status is a dict: {step_id: (is_alive, returncode)}
statuses: t.Mapping[NonEmptyStr, t.Tuple[NonEmptyStr, t.Optional[t.List[int]]]] = {}


@response_serializer.register("stop")
@response_registry.register("stop")
class DragonStopResponse(DragonResponse): ...


@response_serializer.register("handshake")
@response_registry.register("handshake")
class DragonHandshakeResponse(DragonResponse): ...


@response_serializer.register("bootstrap")
@response_registry.register("bootstrap")
class DragonBootstrapResponse(DragonResponse): ...


@response_serializer.register("shutdown")
@response_registry.register("shutdown")
class DragonShutdownResponse(DragonResponse): ...
Loading

0 comments on commit 97292df

Please sign in to comment.