Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Turn SchemaSerializer into SchemaRegistry #8

Merged
merged 9 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 9 additions & 17 deletions smartsim/_core/entrypoints/dragon.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,13 @@

import zmq

import smartsim._core.utils.helpers as _helpers
from smartsim._core.launcher.dragon import dragonSockets
from smartsim._core.launcher.dragon.dragonBackend import DragonBackend
from smartsim._core.schemas import (
DragonBootstrapRequest,
DragonBootstrapResponse,
DragonShutdownResponse,
)
from smartsim._core.schemas.dragonRequests import request_serializer
from smartsim._core.schemas.dragonResponses import response_serializer
from smartsim._core.utils.network import get_best_interface_and_address

# kill is not catchable
Expand Down Expand Up @@ -95,14 +93,15 @@ def run(dragon_head_address: str) -> None:
dragon_head_socket.bind(dragon_head_address)
dragon_backend = DragonBackend()

server = dragonSockets.as_server(dragon_head_socket)

while not SHUTDOWN_INITIATED:
print(f"Listening to {dragon_head_address}")
req = dragon_head_socket.recv_json()
req = server.recv()
print(f"Received request: {req}")
drg_req = request_serializer.deserialize_from_json(str(req))
resp = dragon_backend.process_request(drg_req)
resp = dragon_backend.process_request(req)
print(f"Sending response {resp}", flush=True)
dragon_head_socket.send_json(response_serializer.serialize_to_json(resp))
server.send(resp)
if isinstance(resp, DragonShutdownResponse):
SHUTDOWN_INITIATED = True

Expand All @@ -122,17 +121,10 @@ def main(args: argparse.Namespace) -> int:

launcher_socket = context.socket(zmq.REQ)
launcher_socket.connect(args.launching_address)
client = dragonSockets.as_client(launcher_socket)

response = (
_helpers.start_with(DragonBootstrapRequest(address=dragon_head_address))
.then(request_serializer.serialize_to_json)
.then(launcher_socket.send_json)
.then(lambda _: launcher_socket.recv_json())
.then(str)
.then(response_serializer.deserialize_from_json)
.get_result()
)

client.send(DragonBootstrapRequest(address=dragon_head_address))
response = client.recv()
if not isinstance(response, DragonBootstrapResponse):
raise ValueError(
"Could not receive connection confirmation from launcher. Aborting."
Expand Down
79 changes: 25 additions & 54 deletions smartsim/_core/launcher/dragon/dragonLauncher.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@

import zmq

import smartsim._core.utils.helpers as _helpers
from smartsim._core.schemas.dragonRequests import request_serializer
from smartsim._core.schemas.dragonResponses import response_serializer
from smartsim._core.launcher.dragon import dragonSockets
from smartsim._core.schemas.types import NonEmptyStr

from ....error import LauncherError
Expand Down Expand Up @@ -108,10 +106,8 @@ def _handshake(self, address: str) -> None:
self._dragon_head_socket = self._context.socket(zmq.REQ)
self._dragon_head_socket.connect(address)
try:
(
_helpers.start_with(DragonHandshakeRequest())
.then(self._send_request)
.then(_assert_schema_type(DragonHandshakeResponse))
_assert_schema_type(
self._send_request(DragonHandshakeRequest()), DragonHandshakeResponse
)
logger.debug(
f"Successful handshake with Dragon server at address {address}"
Expand Down Expand Up @@ -205,23 +201,13 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None:
)

if address is not None:
server = dragonSockets.as_server(launcher_socket)
logger.debug(f"Listening to {socket_addr}")
request = (
_helpers.start_with(launcher_socket.recv_json())
.then(str)
.then(request_serializer.deserialize_from_json)
.then(_assert_schema_type(DragonBootstrapRequest))
.get_result()
)
request = _assert_schema_type(server.recv(), DragonBootstrapRequest)

dragon_head_address = request.address
logger.debug(f"Connecting launcher to {dragon_head_address}")

(
_helpers.start_with(DragonBootstrapResponse())
.then(response_serializer.serialize_to_json)
.then(launcher_socket.send_json)
)
server.send(DragonBootstrapResponse())

launcher_socket.close()
self._set_timeout(self._timeout)
Expand Down Expand Up @@ -277,8 +263,8 @@ def run(self, step: Step) -> t.Optional[str]:
run_args = step.run_settings.run_args
env = step.run_settings.env_vars
nodes = int(run_args.get("nodes", None) or 1)
response = (
_helpers.start_with(
response = _assert_schema_type(
self._send_request(
DragonRunRequest(
exe=cmd[0],
exe_args=cmd[1:],
Expand All @@ -290,10 +276,8 @@ def run(self, step: Step) -> t.Optional[str]:
output_file=out,
error_file=err,
)
)
.then(self._send_request)
.then(_assert_schema_type(DragonRunResponse))
.get_result()
),
DragonRunResponse,
)
step_id = task_id = str(response.step_id)
elif isinstance(step, LocalStep):
Expand Down Expand Up @@ -328,10 +312,8 @@ def stop(self, step_name: str) -> StepInfo:

stepmap = self.step_mapping[step_name]
step_id = str(stepmap.step_id)
(
_helpers.start_with(DragonStopRequest(step_id=step_id))
.then(self._send_request)
.then(_assert_schema_type(DragonStopResponse))
_assert_schema_type(
self._send_request(DragonStopRequest(step_id=step_id)), DragonStopResponse
)

_, step_info = self.get_step_update([step_name])[0]
Expand All @@ -353,11 +335,9 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]:
if not self.is_connected:
raise LauncherError("Launcher is not connected to Dragon.")

response = (
_helpers.start_with(DragonUpdateStatusRequest(step_ids=step_ids))
.then(self._send_request)
.then(_assert_schema_type(DragonUpdateStatusResponse))
.get_result()
response = _assert_schema_type(
self._send_request(DragonUpdateStatusRequest(step_ids=step_ids)),
DragonUpdateStatusResponse,
)

# create StepInfo objects to return
Expand Down Expand Up @@ -390,7 +370,7 @@ def _send_request(self, request: DragonRequest, flags: int = 0) -> DragonRespons
if (socket := self._dragon_head_socket) is None:
raise LauncherError("Launcher is not connected to Dragon")

return self.send_req_as_json(socket, request, flags)
return self.send_req_with_socket(socket, request, flags)

def __str__(self) -> str:
return "Dragon"
Expand Down Expand Up @@ -429,35 +409,26 @@ def _parse_launched_dragon_server_info_from_files(
return dragon_envs

@staticmethod
def send_req_as_json(
def send_req_with_socket(
socket: zmq.Socket[t.Any], request: DragonRequest, flags: int = 0
) -> DragonResponse:
client = dragonSockets.as_client(socket)
with DRG_LOCK:
logger.debug(f"Sending {type(request).__name__}: {request}")
return (
_helpers.start_with(request)
.then(request_serializer.serialize_to_json)
.then(lambda req: socket.send_json(req, flags))
.then(lambda _: socket.recv_json())
.then(str)
.then(response_serializer.deserialize_from_json)
.get_result()
)

client.send(request, flags)
return client.recv()

def _assert_schema_type(typ: t.Type[_SchemaT], /) -> t.Callable[[object], _SchemaT]:
def _inner(obj: object) -> _SchemaT:
if not isinstance(obj, typ):
raise TypeError("Expected schema of type `{typ}`, but got {type(obj)}")
return obj

return _inner
def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT:
if not isinstance(obj, typ):
raise TypeError("Expected schema of type `{typ}`, but got {type(obj)}")
return obj


def _dragon_cleanup(server_socket: zmq.Socket[t.Any], server_process_pid: int) -> None:
try:
with DRG_LOCK:
DragonLauncher.send_req_as_json(server_socket, DragonShutdownRequest())
DragonLauncher.send_req_with_socket(server_socket, DragonShutdownRequest())
except zmq.error.ZMQError as e:
logger.error(
f"Could not send shutdown request to dragon server, ZMQ error: {e}"
Expand Down
56 changes: 56 additions & 0 deletions smartsim/_core/launcher/dragon/dragonSockets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# BSD 2-Clause License
#
# Copyright (c) 2021-2023, Hewlett Packard Enterprise
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import typing as t

from smartsim._core.schemas import dragonRequests as _dragonRequests
from smartsim._core.schemas import dragonResponses as _dragonResponses
from smartsim._core.schemas import utils as _utils

if t.TYPE_CHECKING:
from zmq.sugar.socket import Socket


def as_server(
socket: "Socket[t.Any]",
) -> _utils.SocketSchemaTranslator[
_dragonResponses.DragonResponse,
_dragonRequests.DragonRequest,
]:
return _utils.SocketSchemaTranslator(
socket, _dragonResponses.response_registry, _dragonRequests.request_registry
)


def as_client(
socket: "Socket[t.Any]",
) -> _utils.SocketSchemaTranslator[
_dragonRequests.DragonRequest,
_dragonResponses.DragonResponse,
]:
return _utils.SocketSchemaTranslator(
socket, _dragonRequests.request_registry, _dragonResponses.response_registry
)
17 changes: 8 additions & 9 deletions smartsim/_core/schemas/dragonRequests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@
# Black and Pylint disagree about where to put the `...`
# pylint: disable=multiple-statements


class DragonRequest(BaseModel): ...
request_registry = _utils.SchemaRegistry["DragonRequest"]()


request_serializer = _utils.SchemaSerializer[str, DragonRequest]("request_type")
class DragonRequest(BaseModel): ...


class DragonRunRequestView(DragonRequest):
Expand All @@ -55,32 +54,32 @@ class DragonRunRequestView(DragonRequest):
pmi_enabled: bool = True


@request_serializer.register("run")
@request_registry.register("run")
class DragonRunRequest(DragonRunRequestView):
current_env: t.Dict[str, t.Optional[str]] = {}

def __str__(self) -> str:
return str(DragonRunRequestView.parse_obj(self.dict(exclude={"current_env"})))


@request_serializer.register("update_status")
@request_registry.register("update_status")
class DragonUpdateStatusRequest(DragonRequest):
step_ids: t.List[NonEmptyStr]


@request_serializer.register("stop")
@request_registry.register("stop")
class DragonStopRequest(DragonRequest):
step_id: NonEmptyStr


@request_serializer.register("handshake")
@request_registry.register("handshake")
class DragonHandshakeRequest(DragonRequest): ...


@request_serializer.register("bootstrap")
@request_registry.register("bootstrap")
class DragonBootstrapRequest(DragonRequest):
address: NonEmptyStr


@request_serializer.register("shutdown")
@request_registry.register("shutdown")
class DragonShutdownRequest(DragonRequest): ...
17 changes: 8 additions & 9 deletions smartsim/_core/schemas/dragonResponses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,35 @@
# Black and Pylint disagree about where to put the `...`
# pylint: disable=multiple-statements

response_registry = _utils.SchemaRegistry["DragonResponse"]()


class DragonResponse(BaseModel):
error_message: t.Optional[str] = None


response_serializer = _utils.SchemaSerializer[str, DragonResponse]("response_type")


@response_serializer.register("run")
@response_registry.register("run")
class DragonRunResponse(DragonResponse):
step_id: NonEmptyStr


@response_serializer.register("status_update")
@response_registry.register("status_update")
class DragonUpdateStatusResponse(DragonResponse):
# status is a dict: {step_id: (is_alive, returncode)}
statuses: t.Mapping[NonEmptyStr, t.Tuple[NonEmptyStr, t.Optional[t.List[int]]]] = {}


@response_serializer.register("stop")
@response_registry.register("stop")
class DragonStopResponse(DragonResponse): ...


@response_serializer.register("handshake")
@response_registry.register("handshake")
class DragonHandshakeResponse(DragonResponse): ...


@response_serializer.register("bootstrap")
@response_registry.register("bootstrap")
class DragonBootstrapResponse(DragonResponse): ...


@response_serializer.register("shutdown")
@response_registry.register("shutdown")
class DragonShutdownResponse(DragonResponse): ...
Loading
Loading