diff --git a/smartsim/_core/schemas/utils.py b/smartsim/_core/schemas/utils.py index 7f2f45ced..56f0d62b3 100644 --- a/smartsim/_core/schemas/utils.py +++ b/smartsim/_core/schemas/utils.py @@ -37,41 +37,39 @@ _SendT = t.TypeVar("_SendT", bound=pydantic.BaseModel) _RecvT = t.TypeVar("_RecvT", bound=pydantic.BaseModel) +_DEFAULT_MSG_DELIM: t.Final[str] = "|" + @t.final @pydantic.dataclasses.dataclass(frozen=True) class _Message(t.Generic[_SchemaT]): payload: _SchemaT header: str = pydantic.Field(min_length=1) - delimiter: str = pydantic.Field(min_length=1) + delimiter: str = pydantic.Field(min_length=1, default=_DEFAULT_MSG_DELIM) def __str__(self) -> str: return self.delimiter.join((self.header, self.payload.json())) @classmethod def from_str( - cls, str_: str, delimiter: str, payload_type: t.Type[_SchemaT] + cls, + str_: str, + payload_type: t.Type[_SchemaT], + delimiter: str = _DEFAULT_MSG_DELIM, ) -> "_Message[_SchemaT]": header, payload = str_.split(delimiter, 1) return cls(payload_type.parse_raw(payload), header, delimiter) class SchemaRegistry(t.Generic[_SchemaT]): - _DEFAULT_DELIMITER = "|" - def __init__( - self, - message_delimiter: str = _DEFAULT_DELIMITER, - init_map: t.Optional[t.Mapping[str, t.Type[_SchemaT]]] = None, - ): - if not message_delimiter: - raise ValueError("Message delimiter cannot be an empty string") - self._msg_delim = message_delimiter + self, init_map: t.Optional[t.Mapping[str, t.Type[_SchemaT]]] = None + ) -> None: self._map = dict(init_map) if init_map else {} def register(self, key: str) -> t.Callable[[t.Type[_SchemaT]], t.Type[_SchemaT]]: - if self._msg_delim in key: - _msg = f"Registry key cannot contain delimiter `{self._msg_delim}`" + if _DEFAULT_MSG_DELIM in key: + _msg = f"Registry key cannot contain delimiter `{_DEFAULT_MSG_DELIM}`" raise ValueError(_msg) if not key: raise KeyError("Key cannot be the empty string") @@ -93,11 +91,11 @@ def _to_message(self, schema: _SchemaT) -> _Message[_SchemaT]: val = reverse_map[type(schema)] except KeyError: raise TypeError(f"Unregistered schema type: {type(schema)}") from None - return _Message(schema, val, self._msg_delim) + return _Message(schema, val, _DEFAULT_MSG_DELIM) def from_string(self, str_: str) -> _SchemaT: try: - type_, _ = str_.split(self._msg_delim, 1) + type_, _ = str_.split(_DEFAULT_MSG_DELIM, 1) except ValueError: _msg = f"Failed to determine schema type of the string {repr(str_)}" raise ValueError(_msg) from None @@ -105,7 +103,7 @@ def from_string(self, str_: str) -> _SchemaT: cls = self._map[type_] except KeyError: raise ValueError(f"No type of value `{type_}` is registered") from None - msg = _Message.from_str(str_, self._msg_delim, cls) + msg = _Message.from_str(str_, cls, _DEFAULT_MSG_DELIM) return self._from_message(msg) @staticmethod diff --git a/tests/test_schema_utils.py b/tests/test_schema_utils.py index 7db8f6070..238c5c9b5 100644 --- a/tests/test_schema_utils.py +++ b/tests/test_schema_utils.py @@ -31,6 +31,7 @@ import pytest from smartsim._core.schemas.utils import ( + _DEFAULT_MSG_DELIM, SchemaRegistry, SocketSchemaTranslator, _Message, @@ -57,13 +58,13 @@ class Book(pydantic.BaseModel): def test_equivalent_messages_are_equivalent(): book = Book(title="A Story", num_pages=250) - msg_1 = _Message(book, "header", "::") - msg_2 = _Message(book, "header", "::") + msg_1 = _Message(book, "header") + msg_2 = _Message(book, "header") assert msg_1 is not msg_2 assert msg_1 == msg_2 assert str(msg_1) == str(msg_2) - assert msg_1 == _Message.from_str(str(msg_1), "::", Book) + assert msg_1 == _Message.from_str(str(msg_1), Book) def test_schema_registrartion(): @@ -83,23 +84,14 @@ def test_cannot_register_a_schema_under_an_empty_str(): registry.register("") -@pytest.mark.parametrize( - "delim", - ( - pytest.param(SchemaRegistry._DEFAULT_DELIMITER, id="default delimiter"), - pytest.param("::", id="custom delimiter"), - ), -) -def test_schema_to_string(delim): - registry = SchemaRegistry(delim) +def test_schema_to_string(): + registry = SchemaRegistry() registry.register("person")(Person) registry.register("book")(Book) person = Person(name="Bob", age=36) book = Book(title="The Greatest Story of All Time", num_pages=10_000) - assert registry.to_string(person) == str( - _Message(person, "person", registry._msg_delim) - ) - assert registry.to_string(book) == str(_Message(book, "book", registry._msg_delim)) + assert registry.to_string(person) == str(_Message(person, "person")) + assert registry.to_string(book) == str(_Message(book, "book")) def test_schemas_with_same_shape_are_mapped_correctly(): @@ -128,15 +120,10 @@ def test_registry_errors_if_types_overloaded(): registry.register("schema")(Book) -def test_registry_errors_if_msg_delim_is_empty(): - with pytest.raises(ValueError, match="empty string"): - SchemaRegistry("") - - def test_registry_errors_if_msg_type_registered_with_delim_present(): - registry = SchemaRegistry("::") + registry = SchemaRegistry() with pytest.raises(ValueError, match="cannot contain delimiter"): - registry.register("new::type") + registry.register(f"some_key_with_the_{_DEFAULT_MSG_DELIM}_as_a_substring") def test_registry_errors_on_unknown_schema(): @@ -147,26 +134,14 @@ def test_registry_errors_on_unknown_schema(): registry.to_string(Book(title="The Shortest Story of All Time", num_pages=1)) -@pytest.mark.parametrize( - "delim", - ( - pytest.param(SchemaRegistry._DEFAULT_DELIMITER, id="default delimiter"), - pytest.param("::", id="custom delimiter"), - ), -) -def test_registry_correctly_maps_to_expected_type(delim): - registry = SchemaRegistry(delim) +def test_registry_correctly_maps_to_expected_type(): + registry = SchemaRegistry() registry.register("person")(Person) registry.register("book")(Book) person = Person(name="Bob", age=36) book = Book(title="The Most Average Story of All Time", num_pages=500) - assert ( - registry.from_string(str(_Message(person, "person", registry._msg_delim))) - == person - ) - assert ( - registry.from_string(str(_Message(book, "book", registry._msg_delim))) == book - ) + assert registry.from_string(str(_Message(person, "person"))) == person + assert registry.from_string(str(_Message(book, "book"))) == book def test_registery_errors_if_type_key_not_recognized(): @@ -174,13 +149,11 @@ def test_registery_errors_if_type_key_not_recognized(): registry.register("person")(Person) with pytest.raises(ValueError, match="^No type of value .* registered$"): - registry.from_string( - str(_Message(Person(name="Grunk", age=5_000), "alien", registry._msg_delim)) - ) + registry.from_string(str(_Message(Person(name="Grunk", age=5_000), "alien"))) def test_registry_errors_if_type_key_is_missing(): - registry = SchemaRegistry("::") + registry = SchemaRegistry() registry.register("person")(Person) with pytest.raises(ValueError, match="Failed to determine schema type"):