Skip to content

Commit

Permalink
Make it possible to insert several items at once (#363)
Browse files Browse the repository at this point in the history
  • Loading branch information
augustebaum authored Sep 24, 2024
1 parent 19926ff commit bac112d
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 13 deletions.
83 changes: 74 additions & 9 deletions src/skore/project.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Define a Project."""

import logging
from functools import singledispatchmethod
from pathlib import Path
from typing import Any
from typing import Any, Literal

from skore.item import (
Item,
Expand All @@ -16,9 +18,11 @@
from skore.layout import Layout, LayoutRepository
from skore.persistence.disk_cache_storage import DirectoryDoesNotExist, DiskCacheStorage

logger = logging.getLogger(__name__)

class KeyTypeError(Exception):
"""Key must be a string."""

class ProjectPutError(Exception):
"""One more key-value pairs could not be saved in the Project."""


class Project:
Expand All @@ -32,16 +36,77 @@ def __init__(
self.item_repository = item_repository
self.layout_repository = layout_repository

def put(self, key: str, value: Any):
"""Add a value to the Project."""
if not isinstance(key, str):
raise KeyTypeError(
f"Key must be a string; '{key}' is of type '{type(key)}'"
@singledispatchmethod
def put(self, key: str, value: Any, on_error: Literal["warn", "raise"] = "warn"):
"""Add a value to the Project.
If `on_error` is "raise", any error stops the execution. If `on_error`
is "warn" (or anything other than "raise"), a warning is shown instead.
Parameters
----------
key : str
The key to associate with `value` in the Project. Must be a string.
value : Any
The value to associate with `key` in the Project.
on_error : "warn" or "raise", optional
Upon error (e.g. if the key is not a string), whether to raise an error or
to print a warning. Default is "warn".
Raises
------
ProjectPutError
If the key-value pair cannot be saved properly, and `on_error` is "raise".
"""
try:
item = object_to_item(value)
self.put_item(key, item)
except (NotImplementedError, TypeError) as e:
if on_error == "raise":
raise ProjectPutError(
"Key-value pair could not be inserted in the Project"
) from e

logger.warning(
"Key-value pair could not be inserted in the Project "
f"due to the following error: {e}"
)
self.put_item(key, object_to_item(value))

@put.register
def put_several(
self, key_to_value: dict, on_error: Literal["warn", "raise"] = "warn"
):
"""Add several values to the Project.
If `on_error` is "raise", the first error stops the execution (so the
later key-value pairs will not be inserted). If `on_error` is "warn" (or
anything other than "raise"), errors do not stop the execution, and are
shown as they come as warnings; all the valid key-value pairs are inserted.
Parameters
----------
key_to_value : dict[str, Any]
The key-value pairs to put in the Project. Keys must be strings.
on_error : "warn" or "raise", optional
Upon error (e.g. if a key is not a string), whether to raise an error or
to print a warning. Default is "warn".
Raises
------
ProjectPutError
If a key-value pair in `key_to_value` cannot be saved properly,
and `on_error` is "raise".
"""
for key, value in key_to_value.items():
self.put(key, value, on_error=on_error)

def put_item(self, key: str, item: Item):
"""Add an Item to the Project."""
if not isinstance(key, str):
raise TypeError(
f"Key must be a string; key '{key}' is of type '{type(key)}'"
)

self.item_repository.put_item(key, item)

def get(self, key: str) -> Any:
Expand Down
62 changes: 58 additions & 4 deletions tests/unit/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from skore.layout import LayoutRepository
from skore.layout.layout import LayoutItem, LayoutItemSize
from skore.persistence.in_memory_storage import InMemoryStorage
from skore.project import KeyTypeError, Project, ProjectLoadError, load
from skore.project import Project, ProjectLoadError, ProjectPutError, load


@pytest.fixture
Expand Down Expand Up @@ -137,9 +137,10 @@ def test_put_twice(project):
assert project.get("key2") == 5


def test_put_int_key(project):
with pytest.raises(KeyTypeError):
project.put(0, "hello")
def test_put_int_key(project, caplog):
# Warns that 0 is not a string, but doesn't raise
project.put(0, "hello")
assert len(caplog.record_tuples) == 1
assert project.list_keys() == []


Expand Down Expand Up @@ -175,3 +176,56 @@ def test_report_layout(project):

project.put_report_layout(layout)
assert project.get_report_layout() == layout


def test_put_several_happy_path(project):
project.put({"a": "foo", "b": "bar"})
assert project.list_keys() == ["a", "b"]


def test_put_several_canonical(project):
"""Use `put_several` instead of the `put` alias."""
project.put_several({"a": "foo", "b": "bar"})
assert project.list_keys() == ["a", "b"]


def test_put_several_some_errors(project, caplog):
project.put(
{
0: "hello",
1: "hello",
2: "hello",
}
)
assert len(caplog.record_tuples) == 3
assert project.list_keys() == []


def test_put_several_nested(project):
project.put({"a": {"b": "baz"}})
assert project.list_keys() == ["a"]
assert project.get("a") == {"b": "baz"}


def test_put_several_error(project):
"""If some key-value pairs are wrong, add all that are valid and print a warning."""
project.put({"a": "foo", "b": (lambda: "unsupported object")})
assert project.list_keys() == ["a"]


def test_put_key_is_a_tuple(project):
"""If key is not a string, warn."""
project.put(("a", "foo"), ("b", "bar"))
assert project.list_keys() == []


def test_put_key_is_a_set(project):
"""Cannot use an unhashable type as a key."""
with pytest.raises(ProjectPutError):
project.put(set(), "hello", on_error="raise")


def test_put_wrong_key_and_value_raise(project):
"""When `on_error` is "raise", raise the first error that occurs."""
with pytest.raises(ProjectPutError):
project.put(0, (lambda: "unsupported object"), on_error="raise")

0 comments on commit bac112d

Please sign in to comment.