Skip to content

Commit

Permalink
Make SchemaSerializer a SchemaRegistry
Browse files Browse the repository at this point in the history
  • Loading branch information
MattToast committed Mar 21, 2024
1 parent 372e152 commit e1ef2b2
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 136 deletions.
19 changes: 9 additions & 10 deletions smartsim/_core/entrypoints/dragon.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
DragonBootstrapResponse,
DragonShutdownResponse,
)
from smartsim._core.schemas.dragonRequests import request_serializer
from smartsim._core.schemas.dragonResponses import response_serializer
from smartsim._core.schemas.dragonRequests import request_registry
from smartsim._core.schemas.dragonResponses import response_registry
from smartsim._core.utils.network import get_best_interface_and_address

# kill is not catchable
Expand Down Expand Up @@ -97,12 +97,12 @@ def run(dragon_head_address: str) -> None:

while not SHUTDOWN_INITIATED:
print(f"Listening to {dragon_head_address}")
req = dragon_head_socket.recv_json()
req = dragon_head_socket.recv_string()
print(f"Received request: {req}")
drg_req = request_serializer.deserialize_from_json(str(req))
drg_req = request_registry.from_string(req)
resp = dragon_backend.process_request(drg_req)
print(f"Sending response {resp}", flush=True)
dragon_head_socket.send_json(response_serializer.serialize_to_json(resp))
dragon_head_socket.send_string(response_registry.to_string(resp))
if isinstance(resp, DragonShutdownResponse):
SHUTDOWN_INITIATED = True

Expand All @@ -125,11 +125,10 @@ def main(args: argparse.Namespace) -> int:

response = (
_helpers.start_with(DragonBootstrapRequest(address=dragon_head_address))
.then(request_serializer.serialize_to_json)
.then(launcher_socket.send_json)
.then(lambda _: launcher_socket.recv_json())
.then(str)
.then(response_serializer.deserialize_from_json)
.then(request_registry.to_string)
.then(launcher_socket.send_string)
.then(lambda _: launcher_socket.recv_string())
.then(response_registry.from_string)
.get_result()
)

Expand Down
22 changes: 10 additions & 12 deletions smartsim/_core/launcher/dragon/dragonLauncher.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
import zmq

import smartsim._core.utils.helpers as _helpers
from smartsim._core.schemas.dragonRequests import request_serializer
from smartsim._core.schemas.dragonResponses import response_serializer
from smartsim._core.schemas.dragonRequests import request_registry
from smartsim._core.schemas.dragonResponses import response_registry
from smartsim._core.schemas.types import NonEmptyStr

from ....error import LauncherError
Expand Down Expand Up @@ -207,9 +207,8 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None:
if address is not None:
logger.debug(f"Listening to {socket_addr}")
request = (
_helpers.start_with(launcher_socket.recv_json())
.then(str)
.then(request_serializer.deserialize_from_json)
_helpers.start_with(launcher_socket.recv_string())
.then(request_registry.from_string)
.then(_assert_schema_type(DragonBootstrapRequest))
.get_result()
)
Expand All @@ -219,8 +218,8 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None:

(
_helpers.start_with(DragonBootstrapResponse())
.then(response_serializer.serialize_to_json)
.then(launcher_socket.send_json)
.then(response_registry.to_string)
.then(launcher_socket.send_string)
)

launcher_socket.close()
Expand Down Expand Up @@ -436,11 +435,10 @@ def send_req_as_json(
logger.debug(f"Sending {type(request).__name__}: {request}")
return (
_helpers.start_with(request)
.then(request_serializer.serialize_to_json)
.then(lambda req: socket.send_json(req, flags))
.then(lambda _: socket.recv_json())
.then(str)
.then(response_serializer.deserialize_from_json)
.then(request_registry.to_string)
.then(lambda req: socket.send_string(req, flags))
.then(lambda _: socket.recv_string())
.then(response_registry.from_string)
.get_result()
)

Expand Down
17 changes: 8 additions & 9 deletions smartsim/_core/schemas/dragonRequests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@
# Black and Pylint disagree about where to put the `...`
# pylint: disable=multiple-statements


class DragonRequest(BaseModel): ...
request_registry = _utils.SchemaRegistry["DragonRequest"]()


request_serializer = _utils.SchemaSerializer[str, DragonRequest]("request_type")
class DragonRequest(BaseModel): ...


class DragonRunRequestView(DragonRequest):
Expand All @@ -55,32 +54,32 @@ class DragonRunRequestView(DragonRequest):
pmi_enabled: bool = True


@request_serializer.register("run")
@request_registry.register("run")
class DragonRunRequest(DragonRunRequestView):
current_env: t.Dict[str, t.Optional[str]] = {}

def __str__(self) -> str:
return str(DragonRunRequestView.parse_obj(self.dict(exclude={"current_env"})))


@request_serializer.register("update_status")
@request_registry.register("update_status")
class DragonUpdateStatusRequest(DragonRequest):
step_ids: t.List[NonEmptyStr]


@request_serializer.register("stop")
@request_registry.register("stop")
class DragonStopRequest(DragonRequest):
step_id: NonEmptyStr


@request_serializer.register("handshake")
@request_registry.register("handshake")
class DragonHandshakeRequest(DragonRequest): ...


@request_serializer.register("bootstrap")
@request_registry.register("bootstrap")
class DragonBootstrapRequest(DragonRequest):
address: NonEmptyStr


@request_serializer.register("shutdown")
@request_registry.register("shutdown")
class DragonShutdownRequest(DragonRequest): ...
17 changes: 8 additions & 9 deletions smartsim/_core/schemas/dragonResponses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,35 @@
# Black and Pylint disagree about where to put the `...`
# pylint: disable=multiple-statements

response_registry = _utils.SchemaRegistry["DragonResponse"]()


class DragonResponse(BaseModel):
error_message: t.Optional[str] = None


response_serializer = _utils.SchemaSerializer[str, DragonResponse]("response_type")


@response_serializer.register("run")
@response_registry.register("run")
class DragonRunResponse(DragonResponse):
step_id: NonEmptyStr


@response_serializer.register("status_update")
@response_registry.register("status_update")
class DragonUpdateStatusResponse(DragonResponse):
# status is a dict: {step_id: (is_alive, returncode)}
statuses: t.Mapping[NonEmptyStr, t.Tuple[NonEmptyStr, t.Optional[t.List[int]]]] = {}


@response_serializer.register("stop")
@response_registry.register("stop")
class DragonStopResponse(DragonResponse): ...


@response_serializer.register("handshake")
@response_registry.register("handshake")
class DragonHandshakeResponse(DragonResponse): ...


@response_serializer.register("bootstrap")
@response_registry.register("bootstrap")
class DragonBootstrapResponse(DragonResponse): ...


@response_serializer.register("shutdown")
@response_registry.register("shutdown")
class DragonShutdownResponse(DragonResponse): ...
72 changes: 48 additions & 24 deletions smartsim/_core/schemas/utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,46 @@
import json
import typing as t
from dataclasses import dataclass

import pydantic

_KeyT = t.TypeVar("_KeyT")
_SchemaT = t.TypeVar("_SchemaT", bound=pydantic.BaseModel)


class SchemaSerializer(t.Generic[_KeyT, _SchemaT]):
@t.final
@dataclass(frozen=True)
class _Message(t.Generic[_SchemaT]):
header: str
payload: _SchemaT
delimiter: str

def __str__(self) -> str:
return self.delimiter.join((self.header, self.payload.json()))

@classmethod
def from_str(
cls, str_: str, delimiter: str, payload_type: t.Type[_SchemaT]
) -> "_Message[_SchemaT]":
header, payload = str_.split(delimiter, 1)
return cls(header, payload_type.parse_raw(payload), delimiter)


class SchemaRegistry(t.Generic[_SchemaT]):
_DEFAULT_DELIMITER = "|"

def __init__(
self,
type_name: str,
init_map: t.Optional[t.Mapping[_KeyT, t.Type[_SchemaT]]] = None,
message_delimiter: str = _DEFAULT_DELIMITER,
init_map: t.Optional[t.Mapping[str, t.Type[_SchemaT]]] = None,
):
if not message_delimiter:
raise ValueError("Message delimiter cannot be an empty string")
self._msg_delim = message_delimiter
self._map = dict(init_map) if init_map else {}
self._type_name_key = f"__{type_name}__"

def register(self, key: _KeyT) -> t.Callable[[t.Type[_SchemaT]], t.Type[_SchemaT]]:
def register(self, key: str) -> t.Callable[[t.Type[_SchemaT]], t.Type[_SchemaT]]:
if self._msg_delim in key:
_msg = f"Registry key cannot contain delimiter `{self._msg_delim}`"
raise ValueError(_msg)
if key in self._map:
raise KeyError(f"Key `{key}` has already been registered for this parser")

Expand All @@ -26,30 +50,30 @@ def _register(cls: t.Type[_SchemaT]) -> t.Type[_SchemaT]:

return _register

def schema_to_dict(self, schema: _SchemaT) -> t.Dict[str, t.Any]:
def to_string(self, schema: _SchemaT) -> str:
return str(self._to_message(schema))

def _to_message(self, schema: _SchemaT) -> _Message[_SchemaT]:
reverse_map = dict((v, k) for k, v in self._map.items())
try:
val = reverse_map[type(schema)]
except KeyError:
raise TypeError(f"Unregistered schema type: {type(schema)}") from None
# TODO: This method is deprectated in pydantic >= 2
dict_ = schema.dict()
dict_[self._type_name_key] = val
return dict_
return _Message(val, schema, self._msg_delim)

def serialize_to_json(self, schema: _SchemaT) -> str:
return json.dumps(self.schema_to_dict(schema))

def mapping_to_schema(self, obj: t.Mapping[t.Any, t.Any]) -> _SchemaT:
def from_string(self, str_: str) -> _SchemaT:
try:
type_ = obj[self._type_name_key]
except KeyError:
raise ValueError(f"Could not parse object: {obj}") from None
header, _ = str_.split(self._msg_delim, 1)
except ValueError:
_msg = f"Failed to find message header in string {repr(str_)}"
raise ValueError(_msg) from None
try:
cls = self._map[type_]
cls = self._map[header]
except KeyError:
raise ValueError(f"No type of value `{type_}` is registered") from None
return cls.parse_obj(obj)
raise ValueError(f"No type of value `{header}` is registered") from None
msg = _Message.from_str(str_, self._msg_delim, cls)
return self._from_message(msg)

def deserialize_from_json(self, obj: str) -> _SchemaT:
return self.mapping_to_schema(json.loads(obj))
@staticmethod
def _from_message(msg: _Message[_SchemaT]) -> _SchemaT:
return msg.payload
Loading

0 comments on commit e1ef2b2

Please sign in to comment.