Skip to content

Commit

Permalink
rm msg delim attr
Browse files Browse the repository at this point in the history
  • Loading branch information
MattToast committed Mar 25, 2024
1 parent 2de8448 commit c010fca
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 59 deletions.
30 changes: 14 additions & 16 deletions smartsim/_core/schemas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,41 +37,39 @@
_SendT = t.TypeVar("_SendT", bound=pydantic.BaseModel)
_RecvT = t.TypeVar("_RecvT", bound=pydantic.BaseModel)

_DEFAULT_MSG_DELIM: t.Final[str] = "|"


@t.final
@pydantic.dataclasses.dataclass(frozen=True)
class _Message(t.Generic[_SchemaT]):
payload: _SchemaT
header: str = pydantic.Field(min_length=1)
delimiter: str = pydantic.Field(min_length=1)
delimiter: str = pydantic.Field(min_length=1, default=_DEFAULT_MSG_DELIM)

def __str__(self) -> str:
return self.delimiter.join((self.header, self.payload.json()))

@classmethod
def from_str(
cls, str_: str, delimiter: str, payload_type: t.Type[_SchemaT]
cls,
str_: str,
payload_type: t.Type[_SchemaT],
delimiter: str = _DEFAULT_MSG_DELIM,
) -> "_Message[_SchemaT]":
header, payload = str_.split(delimiter, 1)
return cls(payload_type.parse_raw(payload), header, delimiter)


class SchemaRegistry(t.Generic[_SchemaT]):
_DEFAULT_DELIMITER = "|"

def __init__(
self,
message_delimiter: str = _DEFAULT_DELIMITER,
init_map: t.Optional[t.Mapping[str, t.Type[_SchemaT]]] = None,
):
if not message_delimiter:
raise ValueError("Message delimiter cannot be an empty string")
self._msg_delim = message_delimiter
self, init_map: t.Optional[t.Mapping[str, t.Type[_SchemaT]]] = None
) -> None:
self._map = dict(init_map) if init_map else {}

def register(self, key: str) -> t.Callable[[t.Type[_SchemaT]], t.Type[_SchemaT]]:
if self._msg_delim in key:
_msg = f"Registry key cannot contain delimiter `{self._msg_delim}`"
if _DEFAULT_MSG_DELIM in key:
_msg = f"Registry key cannot contain delimiter `{_DEFAULT_MSG_DELIM}`"
raise ValueError(_msg)
if not key:
raise KeyError("Key cannot be the empty string")
Expand All @@ -93,19 +91,19 @@ def _to_message(self, schema: _SchemaT) -> _Message[_SchemaT]:
val = reverse_map[type(schema)]
except KeyError:
raise TypeError(f"Unregistered schema type: {type(schema)}") from None
return _Message(schema, val, self._msg_delim)
return _Message(schema, val, _DEFAULT_MSG_DELIM)

def from_string(self, str_: str) -> _SchemaT:
try:
type_, _ = str_.split(self._msg_delim, 1)
type_, _ = str_.split(_DEFAULT_MSG_DELIM, 1)
except ValueError:
_msg = f"Failed to determine schema type of the string {repr(str_)}"
raise ValueError(_msg) from None
try:
cls = self._map[type_]
except KeyError:
raise ValueError(f"No type of value `{type_}` is registered") from None
msg = _Message.from_str(str_, self._msg_delim, cls)
msg = _Message.from_str(str_, cls, _DEFAULT_MSG_DELIM)
return self._from_message(msg)

@staticmethod
Expand Down
59 changes: 16 additions & 43 deletions tests/test_schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import pytest

from smartsim._core.schemas.utils import (
_DEFAULT_MSG_DELIM,
SchemaRegistry,
SocketSchemaTranslator,
_Message,
Expand All @@ -57,13 +58,13 @@ class Book(pydantic.BaseModel):

def test_equivalent_messages_are_equivalent():
book = Book(title="A Story", num_pages=250)
msg_1 = _Message(book, "header", "::")
msg_2 = _Message(book, "header", "::")
msg_1 = _Message(book, "header")
msg_2 = _Message(book, "header")

assert msg_1 is not msg_2
assert msg_1 == msg_2
assert str(msg_1) == str(msg_2)
assert msg_1 == _Message.from_str(str(msg_1), "::", Book)
assert msg_1 == _Message.from_str(str(msg_1), Book)


def test_schema_registrartion():
Expand All @@ -83,23 +84,14 @@ def test_cannot_register_a_schema_under_an_empty_str():
registry.register("")


@pytest.mark.parametrize(
"delim",
(
pytest.param(SchemaRegistry._DEFAULT_DELIMITER, id="default delimiter"),
pytest.param("::", id="custom delimiter"),
),
)
def test_schema_to_string(delim):
registry = SchemaRegistry(delim)
def test_schema_to_string():
registry = SchemaRegistry()
registry.register("person")(Person)
registry.register("book")(Book)
person = Person(name="Bob", age=36)
book = Book(title="The Greatest Story of All Time", num_pages=10_000)
assert registry.to_string(person) == str(
_Message(person, "person", registry._msg_delim)
)
assert registry.to_string(book) == str(_Message(book, "book", registry._msg_delim))
assert registry.to_string(person) == str(_Message(person, "person"))
assert registry.to_string(book) == str(_Message(book, "book"))


def test_schemas_with_same_shape_are_mapped_correctly():
Expand Down Expand Up @@ -128,15 +120,10 @@ def test_registry_errors_if_types_overloaded():
registry.register("schema")(Book)


def test_registry_errors_if_msg_delim_is_empty():
with pytest.raises(ValueError, match="empty string"):
SchemaRegistry("")


def test_registry_errors_if_msg_type_registered_with_delim_present():
registry = SchemaRegistry("::")
registry = SchemaRegistry()
with pytest.raises(ValueError, match="cannot contain delimiter"):
registry.register("new::type")
registry.register(f"some_key_with_the_{_DEFAULT_MSG_DELIM}_as_a_substring")


def test_registry_errors_on_unknown_schema():
Expand All @@ -147,40 +134,26 @@ def test_registry_errors_on_unknown_schema():
registry.to_string(Book(title="The Shortest Story of All Time", num_pages=1))


@pytest.mark.parametrize(
"delim",
(
pytest.param(SchemaRegistry._DEFAULT_DELIMITER, id="default delimiter"),
pytest.param("::", id="custom delimiter"),
),
)
def test_registry_correctly_maps_to_expected_type(delim):
registry = SchemaRegistry(delim)
def test_registry_correctly_maps_to_expected_type():
registry = SchemaRegistry()
registry.register("person")(Person)
registry.register("book")(Book)
person = Person(name="Bob", age=36)
book = Book(title="The Most Average Story of All Time", num_pages=500)
assert (
registry.from_string(str(_Message(person, "person", registry._msg_delim)))
== person
)
assert (
registry.from_string(str(_Message(book, "book", registry._msg_delim))) == book
)
assert registry.from_string(str(_Message(person, "person"))) == person
assert registry.from_string(str(_Message(book, "book"))) == book


def test_registery_errors_if_type_key_not_recognized():
registry = SchemaRegistry()
registry.register("person")(Person)

with pytest.raises(ValueError, match="^No type of value .* registered$"):
registry.from_string(
str(_Message(Person(name="Grunk", age=5_000), "alien", registry._msg_delim))
)
registry.from_string(str(_Message(Person(name="Grunk", age=5_000), "alien")))


def test_registry_errors_if_type_key_is_missing():
registry = SchemaRegistry("::")
registry = SchemaRegistry()
registry.register("person")(Person)

with pytest.raises(ValueError, match="Failed to determine schema type"):
Expand Down

0 comments on commit c010fca

Please sign in to comment.