diff --git a/src/skore/project.py b/src/skore/project.py index ae8fdde1e..9dfd6aa32 100644 --- a/src/skore/project.py +++ b/src/skore/project.py @@ -1,7 +1,9 @@ """Define a Project.""" +import logging +from functools import singledispatchmethod from pathlib import Path -from typing import Any +from typing import Any, Literal from skore.item import ( Item, @@ -16,9 +18,11 @@ from skore.layout import Layout, LayoutRepository from skore.persistence.disk_cache_storage import DirectoryDoesNotExist, DiskCacheStorage +logger = logging.getLogger(__name__) -class KeyTypeError(Exception): - """Key must be a string.""" + +class ProjectPutError(Exception): + """One more key-value pairs could not be saved in the Project.""" class Project: @@ -32,16 +36,77 @@ def __init__( self.item_repository = item_repository self.layout_repository = layout_repository - def put(self, key: str, value: Any): - """Add a value to the Project.""" - if not isinstance(key, str): - raise KeyTypeError( - f"Key must be a string; '{key}' is of type '{type(key)}'" + @singledispatchmethod + def put(self, key: str, value: Any, on_error: Literal["warn", "raise"] = "warn"): + """Add a value to the Project. + + If `on_error` is "raise", any error stops the execution. If `on_error` + is "warn" (or anything other than "raise"), a warning is shown instead. + + Parameters + ---------- + key : str + The key to associate with `value` in the Project. Must be a string. + value : Any + The value to associate with `key` in the Project. + on_error : "warn" or "raise", optional + Upon error (e.g. if the key is not a string), whether to raise an error or + to print a warning. Default is "warn". + + Raises + ------ + ProjectPutError + If the key-value pair cannot be saved properly, and `on_error` is "raise". + """ + try: + item = object_to_item(value) + self.put_item(key, item) + except (NotImplementedError, TypeError) as e: + if on_error == "raise": + raise ProjectPutError( + "Key-value pair could not be inserted in the Project" + ) from e + + logger.warning( + "Key-value pair could not be inserted in the Project " + f"due to the following error: {e}" ) - self.put_item(key, object_to_item(value)) + + @put.register + def put_several( + self, key_to_value: dict, on_error: Literal["warn", "raise"] = "warn" + ): + """Add several values to the Project. + + If `on_error` is "raise", the first error stops the execution (so the + later key-value pairs will not be inserted). If `on_error` is "warn" (or + anything other than "raise"), errors do not stop the execution, and are + shown as they come as warnings; all the valid key-value pairs are inserted. + + Parameters + ---------- + key_to_value : dict[str, Any] + The key-value pairs to put in the Project. Keys must be strings. + on_error : "warn" or "raise", optional + Upon error (e.g. if a key is not a string), whether to raise an error or + to print a warning. Default is "warn". + + Raises + ------ + ProjectPutError + If a key-value pair in `key_to_value` cannot be saved properly, + and `on_error` is "raise". + """ + for key, value in key_to_value.items(): + self.put(key, value, on_error=on_error) def put_item(self, key: str, item: Item): """Add an Item to the Project.""" + if not isinstance(key, str): + raise TypeError( + f"Key must be a string; key '{key}' is of type '{type(key)}'" + ) + self.item_repository.put_item(key, item) def get(self, key: str) -> Any: diff --git a/tests/unit/test_project.py b/tests/unit/test_project.py index 553e08130..94c390687 100644 --- a/tests/unit/test_project.py +++ b/tests/unit/test_project.py @@ -14,7 +14,7 @@ from skore.layout import LayoutRepository from skore.layout.layout import LayoutItem, LayoutItemSize from skore.persistence.in_memory_storage import InMemoryStorage -from skore.project import KeyTypeError, Project, ProjectLoadError, load +from skore.project import Project, ProjectLoadError, ProjectPutError, load @pytest.fixture @@ -137,9 +137,10 @@ def test_put_twice(project): assert project.get("key2") == 5 -def test_put_int_key(project): - with pytest.raises(KeyTypeError): - project.put(0, "hello") +def test_put_int_key(project, caplog): + # Warns that 0 is not a string, but doesn't raise + project.put(0, "hello") + assert len(caplog.record_tuples) == 1 assert project.list_keys() == [] @@ -175,3 +176,56 @@ def test_report_layout(project): project.put_report_layout(layout) assert project.get_report_layout() == layout + + +def test_put_several_happy_path(project): + project.put({"a": "foo", "b": "bar"}) + assert project.list_keys() == ["a", "b"] + + +def test_put_several_canonical(project): + """Use `put_several` instead of the `put` alias.""" + project.put_several({"a": "foo", "b": "bar"}) + assert project.list_keys() == ["a", "b"] + + +def test_put_several_some_errors(project, caplog): + project.put( + { + 0: "hello", + 1: "hello", + 2: "hello", + } + ) + assert len(caplog.record_tuples) == 3 + assert project.list_keys() == [] + + +def test_put_several_nested(project): + project.put({"a": {"b": "baz"}}) + assert project.list_keys() == ["a"] + assert project.get("a") == {"b": "baz"} + + +def test_put_several_error(project): + """If some key-value pairs are wrong, add all that are valid and print a warning.""" + project.put({"a": "foo", "b": (lambda: "unsupported object")}) + assert project.list_keys() == ["a"] + + +def test_put_key_is_a_tuple(project): + """If key is not a string, warn.""" + project.put(("a", "foo"), ("b", "bar")) + assert project.list_keys() == [] + + +def test_put_key_is_a_set(project): + """Cannot use an unhashable type as a key.""" + with pytest.raises(ProjectPutError): + project.put(set(), "hello", on_error="raise") + + +def test_put_wrong_key_and_value_raise(project): + """When `on_error` is "raise", raise the first error that occurs.""" + with pytest.raises(ProjectPutError): + project.put(0, (lambda: "unsupported object"), on_error="raise")