diff --git a/smartsim/_core/entrypoints/dragon.py b/smartsim/_core/entrypoints/dragon.py index f2314564b..938656677 100644 --- a/smartsim/_core/entrypoints/dragon.py +++ b/smartsim/_core/entrypoints/dragon.py @@ -35,15 +35,13 @@ import zmq -import smartsim._core.utils.helpers as _helpers +from smartsim._core.launcher.dragon import dragonSockets from smartsim._core.launcher.dragon.dragonBackend import DragonBackend from smartsim._core.schemas import ( DragonBootstrapRequest, DragonBootstrapResponse, DragonShutdownResponse, ) -from smartsim._core.schemas.dragonRequests import request_serializer -from smartsim._core.schemas.dragonResponses import response_serializer from smartsim._core.utils.network import get_best_interface_and_address # kill is not catchable @@ -95,14 +93,15 @@ def run(dragon_head_address: str) -> None: dragon_head_socket.bind(dragon_head_address) dragon_backend = DragonBackend() + server = dragonSockets.as_server(dragon_head_socket) + while not SHUTDOWN_INITIATED: print(f"Listening to {dragon_head_address}") - req = dragon_head_socket.recv_json() + req = server.recv() print(f"Received request: {req}") - drg_req = request_serializer.deserialize_from_json(str(req)) - resp = dragon_backend.process_request(drg_req) + resp = dragon_backend.process_request(req) print(f"Sending response {resp}", flush=True) - dragon_head_socket.send_json(response_serializer.serialize_to_json(resp)) + server.send(resp) if isinstance(resp, DragonShutdownResponse): SHUTDOWN_INITIATED = True @@ -122,17 +121,10 @@ def main(args: argparse.Namespace) -> int: launcher_socket = context.socket(zmq.REQ) launcher_socket.connect(args.launching_address) + client = dragonSockets.as_client(launcher_socket) - response = ( - _helpers.start_with(DragonBootstrapRequest(address=dragon_head_address)) - .then(request_serializer.serialize_to_json) - .then(launcher_socket.send_json) - .then(lambda _: launcher_socket.recv_json()) - .then(str) - .then(response_serializer.deserialize_from_json) - .get_result() - ) - + client.send(DragonBootstrapRequest(address=dragon_head_address)) + response = client.recv() if not isinstance(response, DragonBootstrapResponse): raise ValueError( "Could not receive connection confirmation from launcher. Aborting." diff --git a/smartsim/_core/launcher/dragon/dragonLauncher.py b/smartsim/_core/launcher/dragon/dragonLauncher.py index 3eb900f31..8022ce01b 100644 --- a/smartsim/_core/launcher/dragon/dragonLauncher.py +++ b/smartsim/_core/launcher/dragon/dragonLauncher.py @@ -40,9 +40,7 @@ import zmq -import smartsim._core.utils.helpers as _helpers -from smartsim._core.schemas.dragonRequests import request_serializer -from smartsim._core.schemas.dragonResponses import response_serializer +from smartsim._core.launcher.dragon import dragonSockets from smartsim._core.schemas.types import NonEmptyStr from ....error import LauncherError @@ -108,10 +106,8 @@ def _handshake(self, address: str) -> None: self._dragon_head_socket = self._context.socket(zmq.REQ) self._dragon_head_socket.connect(address) try: - ( - _helpers.start_with(DragonHandshakeRequest()) - .then(self._send_request) - .then(_assert_schema_type(DragonHandshakeResponse)) + _assert_schema_type( + self._send_request(DragonHandshakeRequest()), DragonHandshakeResponse ) logger.debug( f"Successful handshake with Dragon server at address {address}" @@ -205,23 +201,13 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None: ) if address is not None: + server = dragonSockets.as_server(launcher_socket) logger.debug(f"Listening to {socket_addr}") - request = ( - _helpers.start_with(launcher_socket.recv_json()) - .then(str) - .then(request_serializer.deserialize_from_json) - .then(_assert_schema_type(DragonBootstrapRequest)) - .get_result() - ) + request = _assert_schema_type(server.recv(), DragonBootstrapRequest) dragon_head_address = request.address logger.debug(f"Connecting launcher to {dragon_head_address}") - - ( - _helpers.start_with(DragonBootstrapResponse()) - .then(response_serializer.serialize_to_json) - .then(launcher_socket.send_json) - ) + server.send(DragonBootstrapResponse()) launcher_socket.close() self._set_timeout(self._timeout) @@ -277,8 +263,8 @@ def run(self, step: Step) -> t.Optional[str]: run_args = step.run_settings.run_args env = step.run_settings.env_vars nodes = int(run_args.get("nodes", None) or 1) - response = ( - _helpers.start_with( + response = _assert_schema_type( + self._send_request( DragonRunRequest( exe=cmd[0], exe_args=cmd[1:], @@ -290,10 +276,8 @@ def run(self, step: Step) -> t.Optional[str]: output_file=out, error_file=err, ) - ) - .then(self._send_request) - .then(_assert_schema_type(DragonRunResponse)) - .get_result() + ), + DragonRunResponse, ) step_id = task_id = str(response.step_id) elif isinstance(step, LocalStep): @@ -328,10 +312,8 @@ def stop(self, step_name: str) -> StepInfo: stepmap = self.step_mapping[step_name] step_id = str(stepmap.step_id) - ( - _helpers.start_with(DragonStopRequest(step_id=step_id)) - .then(self._send_request) - .then(_assert_schema_type(DragonStopResponse)) + _assert_schema_type( + self._send_request(DragonStopRequest(step_id=step_id)), DragonStopResponse ) _, step_info = self.get_step_update([step_name])[0] @@ -353,11 +335,9 @@ def _get_managed_step_update(self, step_ids: t.List[str]) -> t.List[StepInfo]: if not self.is_connected: raise LauncherError("Launcher is not connected to Dragon.") - response = ( - _helpers.start_with(DragonUpdateStatusRequest(step_ids=step_ids)) - .then(self._send_request) - .then(_assert_schema_type(DragonUpdateStatusResponse)) - .get_result() + response = _assert_schema_type( + self._send_request(DragonUpdateStatusRequest(step_ids=step_ids)), + DragonUpdateStatusResponse, ) # create StepInfo objects to return @@ -390,7 +370,7 @@ def _send_request(self, request: DragonRequest, flags: int = 0) -> DragonRespons if (socket := self._dragon_head_socket) is None: raise LauncherError("Launcher is not connected to Dragon") - return self.send_req_as_json(socket, request, flags) + return self.send_req_with_socket(socket, request, flags) def __str__(self) -> str: return "Dragon" @@ -429,35 +409,26 @@ def _parse_launched_dragon_server_info_from_files( return dragon_envs @staticmethod - def send_req_as_json( + def send_req_with_socket( socket: zmq.Socket[t.Any], request: DragonRequest, flags: int = 0 ) -> DragonResponse: + client = dragonSockets.as_client(socket) with DRG_LOCK: logger.debug(f"Sending {type(request).__name__}: {request}") - return ( - _helpers.start_with(request) - .then(request_serializer.serialize_to_json) - .then(lambda req: socket.send_json(req, flags)) - .then(lambda _: socket.recv_json()) - .then(str) - .then(response_serializer.deserialize_from_json) - .get_result() - ) - + client.send(request, flags) + return client.recv() -def _assert_schema_type(typ: t.Type[_SchemaT], /) -> t.Callable[[object], _SchemaT]: - def _inner(obj: object) -> _SchemaT: - if not isinstance(obj, typ): - raise TypeError("Expected schema of type `{typ}`, but got {type(obj)}") - return obj - return _inner +def _assert_schema_type(obj: object, typ: t.Type[_SchemaT], /) -> _SchemaT: + if not isinstance(obj, typ): + raise TypeError("Expected schema of type `{typ}`, but got {type(obj)}") + return obj def _dragon_cleanup(server_socket: zmq.Socket[t.Any], server_process_pid: int) -> None: try: with DRG_LOCK: - DragonLauncher.send_req_as_json(server_socket, DragonShutdownRequest()) + DragonLauncher.send_req_with_socket(server_socket, DragonShutdownRequest()) except zmq.error.ZMQError as e: logger.error( f"Could not send shutdown request to dragon server, ZMQ error: {e}" diff --git a/smartsim/_core/launcher/dragon/dragonSockets.py b/smartsim/_core/launcher/dragon/dragonSockets.py new file mode 100644 index 000000000..a300208d0 --- /dev/null +++ b/smartsim/_core/launcher/dragon/dragonSockets.py @@ -0,0 +1,56 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2023, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +from smartsim._core.schemas import dragonRequests as _dragonRequests +from smartsim._core.schemas import dragonResponses as _dragonResponses +from smartsim._core.schemas import utils as _utils + +if t.TYPE_CHECKING: + from zmq.sugar.socket import Socket + + +def as_server( + socket: "Socket[t.Any]", +) -> _utils.SocketSchemaTranslator[ + _dragonResponses.DragonResponse, + _dragonRequests.DragonRequest, +]: + return _utils.SocketSchemaTranslator( + socket, _dragonResponses.response_registry, _dragonRequests.request_registry + ) + + +def as_client( + socket: "Socket[t.Any]", +) -> _utils.SocketSchemaTranslator[ + _dragonRequests.DragonRequest, + _dragonResponses.DragonResponse, +]: + return _utils.SocketSchemaTranslator( + socket, _dragonRequests.request_registry, _dragonResponses.response_registry + ) diff --git a/smartsim/_core/schemas/dragonRequests.py b/smartsim/_core/schemas/dragonRequests.py index 7ddfd28b5..0f05ae72f 100644 --- a/smartsim/_core/schemas/dragonRequests.py +++ b/smartsim/_core/schemas/dragonRequests.py @@ -34,11 +34,10 @@ # Black and Pylint disagree about where to put the `...` # pylint: disable=multiple-statements - -class DragonRequest(BaseModel): ... +request_registry = _utils.SchemaRegistry["DragonRequest"]() -request_serializer = _utils.SchemaSerializer[str, DragonRequest]("request_type") +class DragonRequest(BaseModel): ... class DragonRunRequestView(DragonRequest): @@ -55,7 +54,7 @@ class DragonRunRequestView(DragonRequest): pmi_enabled: bool = True -@request_serializer.register("run") +@request_registry.register("run") class DragonRunRequest(DragonRunRequestView): current_env: t.Dict[str, t.Optional[str]] = {} @@ -63,24 +62,24 @@ def __str__(self) -> str: return str(DragonRunRequestView.parse_obj(self.dict(exclude={"current_env"}))) -@request_serializer.register("update_status") +@request_registry.register("update_status") class DragonUpdateStatusRequest(DragonRequest): step_ids: t.List[NonEmptyStr] -@request_serializer.register("stop") +@request_registry.register("stop") class DragonStopRequest(DragonRequest): step_id: NonEmptyStr -@request_serializer.register("handshake") +@request_registry.register("handshake") class DragonHandshakeRequest(DragonRequest): ... -@request_serializer.register("bootstrap") +@request_registry.register("bootstrap") class DragonBootstrapRequest(DragonRequest): address: NonEmptyStr -@request_serializer.register("shutdown") +@request_registry.register("shutdown") class DragonShutdownRequest(DragonRequest): ... diff --git a/smartsim/_core/schemas/dragonResponses.py b/smartsim/_core/schemas/dragonResponses.py index 09562c80a..13a40bb4d 100644 --- a/smartsim/_core/schemas/dragonResponses.py +++ b/smartsim/_core/schemas/dragonResponses.py @@ -34,36 +34,35 @@ # Black and Pylint disagree about where to put the `...` # pylint: disable=multiple-statements +response_registry = _utils.SchemaRegistry["DragonResponse"]() + class DragonResponse(BaseModel): error_message: t.Optional[str] = None -response_serializer = _utils.SchemaSerializer[str, DragonResponse]("response_type") - - -@response_serializer.register("run") +@response_registry.register("run") class DragonRunResponse(DragonResponse): step_id: NonEmptyStr -@response_serializer.register("status_update") +@response_registry.register("status_update") class DragonUpdateStatusResponse(DragonResponse): # status is a dict: {step_id: (is_alive, returncode)} statuses: t.Mapping[NonEmptyStr, t.Tuple[NonEmptyStr, t.Optional[t.List[int]]]] = {} -@response_serializer.register("stop") +@response_registry.register("stop") class DragonStopResponse(DragonResponse): ... -@response_serializer.register("handshake") +@response_registry.register("handshake") class DragonHandshakeResponse(DragonResponse): ... -@response_serializer.register("bootstrap") +@response_registry.register("bootstrap") class DragonBootstrapResponse(DragonResponse): ... -@response_serializer.register("shutdown") +@response_registry.register("shutdown") class DragonShutdownResponse(DragonResponse): ... diff --git a/smartsim/_core/schemas/utils.py b/smartsim/_core/schemas/utils.py index 3944469a6..56f0d62b3 100644 --- a/smartsim/_core/schemas/utils.py +++ b/smartsim/_core/schemas/utils.py @@ -1,22 +1,78 @@ -import json +# BSD 2-Clause License +# +# Copyright (c) 2021-2023, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import dataclasses import typing as t import pydantic +import pydantic.dataclasses + +if t.TYPE_CHECKING: + from zmq.sugar.socket import Socket -_KeyT = t.TypeVar("_KeyT") _SchemaT = t.TypeVar("_SchemaT", bound=pydantic.BaseModel) +_SendT = t.TypeVar("_SendT", bound=pydantic.BaseModel) +_RecvT = t.TypeVar("_RecvT", bound=pydantic.BaseModel) + +_DEFAULT_MSG_DELIM: t.Final[str] = "|" + + +@t.final +@pydantic.dataclasses.dataclass(frozen=True) +class _Message(t.Generic[_SchemaT]): + payload: _SchemaT + header: str = pydantic.Field(min_length=1) + delimiter: str = pydantic.Field(min_length=1, default=_DEFAULT_MSG_DELIM) + + def __str__(self) -> str: + return self.delimiter.join((self.header, self.payload.json())) + + @classmethod + def from_str( + cls, + str_: str, + payload_type: t.Type[_SchemaT], + delimiter: str = _DEFAULT_MSG_DELIM, + ) -> "_Message[_SchemaT]": + header, payload = str_.split(delimiter, 1) + return cls(payload_type.parse_raw(payload), header, delimiter) -class SchemaSerializer(t.Generic[_KeyT, _SchemaT]): +class SchemaRegistry(t.Generic[_SchemaT]): def __init__( - self, - type_name: str, - init_map: t.Optional[t.Mapping[_KeyT, t.Type[_SchemaT]]] = None, - ): + self, init_map: t.Optional[t.Mapping[str, t.Type[_SchemaT]]] = None + ) -> None: self._map = dict(init_map) if init_map else {} - self._type_name_key = f"__{type_name}__" - def register(self, key: _KeyT) -> t.Callable[[t.Type[_SchemaT]], t.Type[_SchemaT]]: + def register(self, key: str) -> t.Callable[[t.Type[_SchemaT]], t.Type[_SchemaT]]: + if _DEFAULT_MSG_DELIM in key: + _msg = f"Registry key cannot contain delimiter `{_DEFAULT_MSG_DELIM}`" + raise ValueError(_msg) + if not key: + raise KeyError("Key cannot be the empty string") if key in self._map: raise KeyError(f"Key `{key}` has already been registered for this parser") @@ -26,30 +82,43 @@ def _register(cls: t.Type[_SchemaT]) -> t.Type[_SchemaT]: return _register - def schema_to_dict(self, schema: _SchemaT) -> t.Dict[str, t.Any]: + def to_string(self, schema: _SchemaT) -> str: + return str(self._to_message(schema)) + + def _to_message(self, schema: _SchemaT) -> _Message[_SchemaT]: reverse_map = dict((v, k) for k, v in self._map.items()) try: val = reverse_map[type(schema)] except KeyError: raise TypeError(f"Unregistered schema type: {type(schema)}") from None - # TODO: This method is deprectated in pydantic >= 2 - dict_ = schema.dict() - dict_[self._type_name_key] = val - return dict_ - - def serialize_to_json(self, schema: _SchemaT) -> str: - return json.dumps(self.schema_to_dict(schema)) + return _Message(schema, val, _DEFAULT_MSG_DELIM) - def mapping_to_schema(self, obj: t.Mapping[t.Any, t.Any]) -> _SchemaT: + def from_string(self, str_: str) -> _SchemaT: try: - type_ = obj[self._type_name_key] - except KeyError: - raise ValueError(f"Could not parse object: {obj}") from None + type_, _ = str_.split(_DEFAULT_MSG_DELIM, 1) + except ValueError: + _msg = f"Failed to determine schema type of the string {repr(str_)}" + raise ValueError(_msg) from None try: cls = self._map[type_] except KeyError: raise ValueError(f"No type of value `{type_}` is registered") from None - return cls.parse_obj(obj) + msg = _Message.from_str(str_, cls, _DEFAULT_MSG_DELIM) + return self._from_message(msg) + + @staticmethod + def _from_message(msg: _Message[_SchemaT]) -> _SchemaT: + return msg.payload + + +@dataclasses.dataclass(frozen=True) +class SocketSchemaTranslator(t.Generic[_SendT, _RecvT]): + socket: "Socket[t.Any]" + _send_registry: SchemaRegistry[_SendT] + _recv_registry: SchemaRegistry[_RecvT] + + def send(self, schema: _SendT, flags: int = 0) -> None: + self.socket.send_string(self._send_registry.to_string(schema), flags) - def deserialize_from_json(self, obj: str) -> _SchemaT: - return self.mapping_to_schema(json.loads(obj)) + def recv(self) -> _RecvT: + return self._recv_registry.from_string(self.socket.recv_string()) diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index f914f463b..e5061cf2a 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -31,7 +31,6 @@ import os import typing as t import uuid -from dataclasses import dataclass from datetime import datetime from functools import lru_cache from pathlib import Path @@ -39,9 +38,6 @@ from smartsim._core._install.builder import TRedisAIBackendStr as _TRedisAIBackendStr -_T = t.TypeVar("_T") -_U = t.TypeVar("_U") - def unpack_db_identifier(db_id: str, token: str) -> t.Tuple[str, str]: """Unpack the unformatted database identifier @@ -306,43 +302,3 @@ def decode_cmd(encoded_cmd: str) -> t.List[str]: cleaned_cmd = decoded_cmd.decode("ascii").split("|") return cleaned_cmd - - -@t.final -@dataclass(frozen=True) -class _Pipeline(t.Generic[_T]): - """Utility class to turn - - ..highlight:: python - ..code-block:: python - - result = calls(function(nested(deeply(very(some(data)))))) - - into - - ..highlight:: python - ..code-block:: python - - result = (_Pipeline(data) - .then(some) - .then(very) - .then(deeply) - .then(nested) - .then(function) - .then(calls) - .get_result()) - - without the need to introduce confusing temporary variable names - """ - - _val: _T - - def then(self, fn: t.Callable[[_T], _U], /) -> "_Pipeline[_U]": - return _Pipeline(fn(self._val)) - - def get_result(self) -> _T: - return self._val - - -def start_with(obj: _T) -> _Pipeline[_T]: - return _Pipeline(obj) diff --git a/tests/test_schema_utils.py b/tests/test_schema_utils.py new file mode 100644 index 000000000..238c5c9b5 --- /dev/null +++ b/tests/test_schema_utils.py @@ -0,0 +1,217 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2023, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import collections +import json + +import pydantic +import pytest + +from smartsim._core.schemas.utils import ( + _DEFAULT_MSG_DELIM, + SchemaRegistry, + SocketSchemaTranslator, + _Message, +) + +# The tests in this file belong to the group_b group +pytestmark = pytest.mark.group_b + + +class Person(pydantic.BaseModel): + name: str + age: int + + +class Dog(pydantic.BaseModel): + name: str + age: int + + +class Book(pydantic.BaseModel): + title: str + num_pages: int + + +def test_equivalent_messages_are_equivalent(): + book = Book(title="A Story", num_pages=250) + msg_1 = _Message(book, "header") + msg_2 = _Message(book, "header") + + assert msg_1 is not msg_2 + assert msg_1 == msg_2 + assert str(msg_1) == str(msg_2) + assert msg_1 == _Message.from_str(str(msg_1), Book) + + +def test_schema_registrartion(): + registry = SchemaRegistry() + assert registry._map == {} + + registry.register("person")(Person) + assert registry._map == {"person": Person} + + registry.register("book")(Book) + assert registry._map == {"person": Person, "book": Book} + + +def test_cannot_register_a_schema_under_an_empty_str(): + registry = SchemaRegistry() + with pytest.raises(KeyError, match="Key cannot be the empty string"): + registry.register("") + + +def test_schema_to_string(): + registry = SchemaRegistry() + registry.register("person")(Person) + registry.register("book")(Book) + person = Person(name="Bob", age=36) + book = Book(title="The Greatest Story of All Time", num_pages=10_000) + assert registry.to_string(person) == str(_Message(person, "person")) + assert registry.to_string(book) == str(_Message(book, "book")) + + +def test_schemas_with_same_shape_are_mapped_correctly(): + registry = SchemaRegistry() + registry.register("person")(Person) + registry.register("dog")(Dog) + + person = Person(name="Mark", age=34) + dog = Dog(name="Fido", age=5) + + parsed_person = registry.from_string(registry.to_string(person)) + parsed_dog = registry.from_string(registry.to_string(dog)) + + assert isinstance(parsed_person, Person) + assert isinstance(parsed_dog, Dog) + + assert parsed_person == person + assert parsed_dog == dog + + +def test_registry_errors_if_types_overloaded(): + registry = SchemaRegistry() + registry.register("schema")(Person) + + with pytest.raises(KeyError): + registry.register("schema")(Book) + + +def test_registry_errors_if_msg_type_registered_with_delim_present(): + registry = SchemaRegistry() + with pytest.raises(ValueError, match="cannot contain delimiter"): + registry.register(f"some_key_with_the_{_DEFAULT_MSG_DELIM}_as_a_substring") + + +def test_registry_errors_on_unknown_schema(): + registry = SchemaRegistry() + registry.register("person")(Person) + + with pytest.raises(TypeError): + registry.to_string(Book(title="The Shortest Story of All Time", num_pages=1)) + + +def test_registry_correctly_maps_to_expected_type(): + registry = SchemaRegistry() + registry.register("person")(Person) + registry.register("book")(Book) + person = Person(name="Bob", age=36) + book = Book(title="The Most Average Story of All Time", num_pages=500) + assert registry.from_string(str(_Message(person, "person"))) == person + assert registry.from_string(str(_Message(book, "book"))) == book + + +def test_registery_errors_if_type_key_not_recognized(): + registry = SchemaRegistry() + registry.register("person")(Person) + + with pytest.raises(ValueError, match="^No type of value .* registered$"): + registry.from_string(str(_Message(Person(name="Grunk", age=5_000), "alien"))) + + +def test_registry_errors_if_type_key_is_missing(): + registry = SchemaRegistry() + registry.register("person")(Person) + + with pytest.raises(ValueError, match="Failed to determine schema type"): + registry.from_string("This string does not contain a delimiter") + + +class MockSocket: + def __init__(self, send_queue, recv_queue): + self.send_queue = send_queue + self.recv_queue = recv_queue + + def send_string(self, str_, *_args, **_kwargs): + assert isinstance(str_, str) + self.send_queue.append(str_) + + def recv_string(self, *_args, **_kwargs): + str_ = self.recv_queue.popleft() + assert isinstance(str_, str) + return str_ + + +class Request(pydantic.BaseModel): ... + + +class Response(pydantic.BaseModel): ... + + +def test_socket_schema_translator_uses_schema_registries(): + server_to_client = collections.deque() + client_to_server = collections.deque() + + server_socket = MockSocket(server_to_client, client_to_server) + client_socket = MockSocket(client_to_server, server_to_client) + + req_reg = SchemaRegistry() + res_reg = SchemaRegistry() + + req_reg.register("message")(Request) + res_reg.register("message")(Response) + + server = SocketSchemaTranslator(server_socket, res_reg, req_reg) + client = SocketSchemaTranslator(client_socket, req_reg, res_reg) + + # Check sockets are able to communicate seamlessly with schemas only + client.send(Request()) + assert len(client_to_server) == 1 + req = server.recv() + assert len(client_to_server) == 0 + assert isinstance(req, Request) + + server.send(Response()) + assert len(server_to_client) == 1 + res = client.recv() + assert len(server_to_client) == 0 + assert isinstance(res, Response) + + # Ensure users cannot send unexpected schemas + with pytest.raises(TypeError, match="Unregistered schema"): + client.send(Response()) + with pytest.raises(TypeError, match="Unregistered schema"): + server.send(Request()) diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py deleted file mode 100644 index f73dfff59..000000000 --- a/tests/utils/test_helpers.py +++ /dev/null @@ -1,18 +0,0 @@ -import itertools - -import pytest - -from smartsim._core.utils import helpers - - -@pytest.mark.parametrize( - "func_1, func_2, func_3", - itertools.permutations((lambda x: x + 3, lambda x: x * 2, lambda x: x // 5)), -) -def test_pipline(func_1, func_2, func_3): - x = 30 - assert ( - func_3(func_2(func_1(x))) - == helpers._Pipeline(x).then(func_1).then(func_2).then(func_3).get_result() - == helpers.start_with(x).then(func_1).then(func_2).then(func_3).get_result() - ) diff --git a/tests/utils/test_schema_utils.py b/tests/utils/test_schema_utils.py deleted file mode 100644 index f1b0ef286..000000000 --- a/tests/utils/test_schema_utils.py +++ /dev/null @@ -1,103 +0,0 @@ -import json - -import pydantic -import pytest - -from smartsim._core.schemas.utils import SchemaSerializer - - -class Person(pydantic.BaseModel): - name: str - age: int - - -class Book(pydantic.BaseModel): - title: str - num_pages: int - - -def test_schema_registrartion(): - serializer = SchemaSerializer("test_type") - assert serializer._map == {} - - serializer.register("person")(Person) - assert serializer._map == {"person": Person} - - serializer.register("book")(Book) - assert serializer._map == {"person": Person, "book": Book} - - -def test_serialize_schema(): - serializer = SchemaSerializer("test_type") - serializer.register("person")(Person) - serializer.register("book")(Book) - assert json.loads(serializer.serialize_to_json(Person(name="Bob", age=36))) == { - "__test_type__": "person", - "name": "Bob", - "age": 36, - } - assert json.loads( - serializer.serialize_to_json( - Book(title="The Greatest Story of All Time", num_pages=10_000) - ) - ) == { - "__test_type__": "book", - "title": "The Greatest Story of All Time", - "num_pages": 10_000, - } - - -def test_serializer_errors_if_types_overloaded(): - serializer = SchemaSerializer("test_type") - serializer.register("schema")(Person) - - with pytest.raises(KeyError): - serializer.register("schema")(Book) - - -def test_serializer_errors_on_unknown_schema(): - serializer = SchemaSerializer("test_type") - serializer.register("person")(Person) - - with pytest.raises(TypeError): - serializer.serialize_to_json( - Book(title="The Shortest Story of All Time", num_pages=1) - ) - - -def test_deserialize_json(): - serializer = SchemaSerializer("test_type") - serializer.register("person")(Person) - serializer.register("book")(Book) - assert serializer.deserialize_from_json( - json.dumps({"__test_type__": "person", "name": "Bob", "age": 36}) - ) == Person(name="Bob", age=36) - assert serializer.deserialize_from_json( - json.dumps( - { - "__test_type__": "book", - "title": "The Most Average Story of All Time", - "num_pages": 500, - } - ) - ) == Book(title="The Most Average Story of All Time", num_pages=500) - - -def test_deserialize_error_if_type_key_not_recognized(): - serializer = SchemaSerializer("test_type") - serializer.register("person")(Person) - - with pytest.raises(ValueError): - serializer.deserialize_from_json( - json.dumps( - {"__test_type__": "alien", "name": "Bob the Alien", "age": 5_000} - ) - ) - - -def test_deserialize_error_if_type_key_is_missing(): - serializer = SchemaSerializer("test_type") - serializer.register("person")(Person) - - with pytest.raises(ValueError): - serializer.deserialize_from_json(json.dumps({"name": "Bob", "age": 36}))