From 666f0946f16489de11c747fe3c7f296b035021a7 Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Tue, 7 Jan 2025 15:43:54 +0100
Subject: [PATCH 01/22] [skip ci] Move `item/` to `persistence/item/`
---
skore/src/skore/{ => persistence}/item/__init__.py | 0
skore/src/skore/{ => persistence}/item/cross_validation_item.py | 0
skore/src/skore/{ => persistence}/item/item.py | 0
skore/src/skore/{ => persistence}/item/media_item.py | 0
skore/src/skore/{ => persistence}/item/numpy_array_item.py | 0
skore/src/skore/{ => persistence}/item/pandas_dataframe_item.py | 0
skore/src/skore/{ => persistence}/item/pandas_series_item.py | 0
skore/src/skore/{ => persistence}/item/polars_dataframe_item.py | 0
skore/src/skore/{ => persistence}/item/polars_series_item.py | 0
skore/src/skore/{ => persistence}/item/primitive_item.py | 0
.../skore/{ => persistence}/item/sklearn_base_estimator_item.py | 0
skore/src/skore/{ => persistence}/item/skrub_table_report_item.py | 0
.../src/skore/{ => persistence}/item/standalone_widget.html.jinja | 0
skore/src/skore/persistence/repository/__init__.py | 0
.../src/skore/{item => persistence/repository}/item_repository.py | 0
skore/src/skore/persistence/{ => storage}/abstract_storage.py | 0
skore/src/skore/persistence/{ => storage}/disk_cache_storage.py | 0
skore/src/skore/persistence/{ => storage}/in_memory_storage.py | 0
18 files changed, 0 insertions(+), 0 deletions(-)
rename skore/src/skore/{ => persistence}/item/__init__.py (100%)
rename skore/src/skore/{ => persistence}/item/cross_validation_item.py (100%)
rename skore/src/skore/{ => persistence}/item/item.py (100%)
rename skore/src/skore/{ => persistence}/item/media_item.py (100%)
rename skore/src/skore/{ => persistence}/item/numpy_array_item.py (100%)
rename skore/src/skore/{ => persistence}/item/pandas_dataframe_item.py (100%)
rename skore/src/skore/{ => persistence}/item/pandas_series_item.py (100%)
rename skore/src/skore/{ => persistence}/item/polars_dataframe_item.py (100%)
rename skore/src/skore/{ => persistence}/item/polars_series_item.py (100%)
rename skore/src/skore/{ => persistence}/item/primitive_item.py (100%)
rename skore/src/skore/{ => persistence}/item/sklearn_base_estimator_item.py (100%)
rename skore/src/skore/{ => persistence}/item/skrub_table_report_item.py (100%)
rename skore/src/skore/{ => persistence}/item/standalone_widget.html.jinja (100%)
create mode 100644 skore/src/skore/persistence/repository/__init__.py
rename skore/src/skore/{item => persistence/repository}/item_repository.py (100%)
rename skore/src/skore/persistence/{ => storage}/abstract_storage.py (100%)
rename skore/src/skore/persistence/{ => storage}/disk_cache_storage.py (100%)
rename skore/src/skore/persistence/{ => storage}/in_memory_storage.py (100%)
diff --git a/skore/src/skore/item/__init__.py b/skore/src/skore/persistence/item/__init__.py
similarity index 100%
rename from skore/src/skore/item/__init__.py
rename to skore/src/skore/persistence/item/__init__.py
diff --git a/skore/src/skore/item/cross_validation_item.py b/skore/src/skore/persistence/item/cross_validation_item.py
similarity index 100%
rename from skore/src/skore/item/cross_validation_item.py
rename to skore/src/skore/persistence/item/cross_validation_item.py
diff --git a/skore/src/skore/item/item.py b/skore/src/skore/persistence/item/item.py
similarity index 100%
rename from skore/src/skore/item/item.py
rename to skore/src/skore/persistence/item/item.py
diff --git a/skore/src/skore/item/media_item.py b/skore/src/skore/persistence/item/media_item.py
similarity index 100%
rename from skore/src/skore/item/media_item.py
rename to skore/src/skore/persistence/item/media_item.py
diff --git a/skore/src/skore/item/numpy_array_item.py b/skore/src/skore/persistence/item/numpy_array_item.py
similarity index 100%
rename from skore/src/skore/item/numpy_array_item.py
rename to skore/src/skore/persistence/item/numpy_array_item.py
diff --git a/skore/src/skore/item/pandas_dataframe_item.py b/skore/src/skore/persistence/item/pandas_dataframe_item.py
similarity index 100%
rename from skore/src/skore/item/pandas_dataframe_item.py
rename to skore/src/skore/persistence/item/pandas_dataframe_item.py
diff --git a/skore/src/skore/item/pandas_series_item.py b/skore/src/skore/persistence/item/pandas_series_item.py
similarity index 100%
rename from skore/src/skore/item/pandas_series_item.py
rename to skore/src/skore/persistence/item/pandas_series_item.py
diff --git a/skore/src/skore/item/polars_dataframe_item.py b/skore/src/skore/persistence/item/polars_dataframe_item.py
similarity index 100%
rename from skore/src/skore/item/polars_dataframe_item.py
rename to skore/src/skore/persistence/item/polars_dataframe_item.py
diff --git a/skore/src/skore/item/polars_series_item.py b/skore/src/skore/persistence/item/polars_series_item.py
similarity index 100%
rename from skore/src/skore/item/polars_series_item.py
rename to skore/src/skore/persistence/item/polars_series_item.py
diff --git a/skore/src/skore/item/primitive_item.py b/skore/src/skore/persistence/item/primitive_item.py
similarity index 100%
rename from skore/src/skore/item/primitive_item.py
rename to skore/src/skore/persistence/item/primitive_item.py
diff --git a/skore/src/skore/item/sklearn_base_estimator_item.py b/skore/src/skore/persistence/item/sklearn_base_estimator_item.py
similarity index 100%
rename from skore/src/skore/item/sklearn_base_estimator_item.py
rename to skore/src/skore/persistence/item/sklearn_base_estimator_item.py
diff --git a/skore/src/skore/item/skrub_table_report_item.py b/skore/src/skore/persistence/item/skrub_table_report_item.py
similarity index 100%
rename from skore/src/skore/item/skrub_table_report_item.py
rename to skore/src/skore/persistence/item/skrub_table_report_item.py
diff --git a/skore/src/skore/item/standalone_widget.html.jinja b/skore/src/skore/persistence/item/standalone_widget.html.jinja
similarity index 100%
rename from skore/src/skore/item/standalone_widget.html.jinja
rename to skore/src/skore/persistence/item/standalone_widget.html.jinja
diff --git a/skore/src/skore/persistence/repository/__init__.py b/skore/src/skore/persistence/repository/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/skore/src/skore/item/item_repository.py b/skore/src/skore/persistence/repository/item_repository.py
similarity index 100%
rename from skore/src/skore/item/item_repository.py
rename to skore/src/skore/persistence/repository/item_repository.py
diff --git a/skore/src/skore/persistence/abstract_storage.py b/skore/src/skore/persistence/storage/abstract_storage.py
similarity index 100%
rename from skore/src/skore/persistence/abstract_storage.py
rename to skore/src/skore/persistence/storage/abstract_storage.py
diff --git a/skore/src/skore/persistence/disk_cache_storage.py b/skore/src/skore/persistence/storage/disk_cache_storage.py
similarity index 100%
rename from skore/src/skore/persistence/disk_cache_storage.py
rename to skore/src/skore/persistence/storage/disk_cache_storage.py
diff --git a/skore/src/skore/persistence/in_memory_storage.py b/skore/src/skore/persistence/storage/in_memory_storage.py
similarity index 100%
rename from skore/src/skore/persistence/in_memory_storage.py
rename to skore/src/skore/persistence/storage/in_memory_storage.py
From 6811b6917ee306e1f9b48b4c89752f37050f81df Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Tue, 7 Jan 2025 15:55:35 +0100
Subject: [PATCH 02/22] [skip ci] Move `view/` to `persistence/view/` and `persistence/repository/`
---
.../src/skore/{view => persistence/repository}/view_repository.py | 0
skore/src/skore/{ => persistence}/view/__init__.py | 0
skore/src/skore/{ => persistence}/view/view.py | 0
3 files changed, 0 insertions(+), 0 deletions(-)
rename skore/src/skore/{view => persistence/repository}/view_repository.py (100%)
rename skore/src/skore/{ => persistence}/view/__init__.py (100%)
rename skore/src/skore/{ => persistence}/view/view.py (100%)
diff --git a/skore/src/skore/view/view_repository.py b/skore/src/skore/persistence/repository/view_repository.py
similarity index 100%
rename from skore/src/skore/view/view_repository.py
rename to skore/src/skore/persistence/repository/view_repository.py
diff --git a/skore/src/skore/view/__init__.py b/skore/src/skore/persistence/view/__init__.py
similarity index 100%
rename from skore/src/skore/view/__init__.py
rename to skore/src/skore/persistence/view/__init__.py
diff --git a/skore/src/skore/view/view.py b/skore/src/skore/persistence/view/view.py
similarity index 100%
rename from skore/src/skore/view/view.py
rename to skore/src/skore/persistence/view/view.py
From c08f284fe97449731600ba9bf5901c0edc0b2d6b Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Tue, 7 Jan 2025 16:22:43 +0100
Subject: [PATCH 03/22] [skip ci] Remove `put_item`/`get_item`
---
skore/src/skore/persistence/item/__init__.py | 27 +++++-
skore/src/skore/project/project.py | 92 ++++----------------
2 files changed, 43 insertions(+), 76 deletions(-)
diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py
index b7c89a040..0735647f5 100644
--- a/skore/src/skore/persistence/item/__init__.py
+++ b/skore/src/skore/persistence/item/__init__.py
@@ -43,7 +43,32 @@ def object_to_item(object: Any) -> Item:
# correct type. If not, they throw a ItemTypeError exception.
return cls.factory(object)
- raise NotImplementedError(f"Type '{object.__class__}' is not supported.")
+ return PickleItem(object)
+
+
+def item_to_object(item: Item) -> Any:
+ if isinstance(item, PrimitiveItem):
+ return item.primitive
+ elif isinstance(item, NumpyArrayItem):
+ return item.array
+ elif isinstance(item, PandasDataFrameItem):
+ return item.dataframe
+ elif isinstance(item, PandasSeriesItem):
+ return item.series
+ elif isinstance(item, PolarsDataFrameItem):
+ return item.dataframe
+ elif isinstance(item, PolarsSeriesItem):
+ return item.series
+ elif isinstance(item, SklearnBaseEstimatorItem):
+ return item.estimator
+ elif isinstance(item, CrossValidationItem):
+ return item.cv_results_serialized
+ elif isinstance(item, MediaItem):
+ return item.media_bytes
+ elif isinstance(item, PickleItem):
+ return repr(item.pickle_bytes)
+ else:
+ raise ValueError(f"Item {item} is not a known item type.")
__all__ = [
diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py
index 289b04d10..299409425 100644
--- a/skore/src/skore/project/project.py
+++ b/skore/src/skore/project/project.py
@@ -1,28 +1,18 @@
"""Define a Project."""
-import logging
-from typing import Any, Optional, Union
-
-from skore.item import (
- CrossValidationItem,
- Item,
- ItemRepository,
- MediaItem,
- NumpyArrayItem,
- PandasDataFrameItem,
- PandasSeriesItem,
- PolarsDataFrameItem,
- PolarsSeriesItem,
- PrimitiveItem,
- SklearnBaseEstimatorItem,
- object_to_item,
-)
-from skore.view.view import View
-from skore.view.view_repository import ViewRepository
-
-logger = logging.getLogger(__name__)
-logger.addHandler(logging.NullHandler()) # Default to no output
-logger.setLevel(logging.INFO)
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from skore.persistence import item_to_object, object_to_item
+
+if TYPE_CHECKING:
+ from skore.persistence import (
+ Item,
+ ItemRepository,
+ View,
+ ViewRepository,
+ )
MISSING = object()
@@ -128,28 +118,21 @@ def put(self, key: Union[str, dict[str, Any]], value: Optional[Any] = MISSING):
If the value type is not supported.
"""
if value is not MISSING:
- key_to_item = {key: value}
+ key_to_value = {key: value}
elif isinstance(key, dict):
- key_to_item = key
+ key_to_value = key
else:
raise TypeError(
f"Bad parameters. "
f"When value is not specified, key must be a dict (found '{type(key)}')"
)
- for key, value in key_to_item.items():
+ for key, value in key_to_value.items():
if not isinstance(key, str):
raise TypeError(f"Key must be a string (found '{type(key)}')")
self.item_repository.put_item(key, object_to_item(value))
- def put_item(self, key: str, item: Item):
- """Add an Item to the Project."""
- if not isinstance(key, str):
- raise TypeError(f"Key must be a string (found '{type(key)}')")
-
- self.item_repository.put_item(key, item)
-
def get(self, key: str) -> Any:
"""Get the value corresponding to ``key`` from the Project.
@@ -163,48 +146,7 @@ def get(self, key: str) -> Any:
KeyError
If the key does not correspond to any item.
"""
- item = self.get_item(key)
-
- if isinstance(item, PrimitiveItem):
- return item.primitive
- elif isinstance(item, NumpyArrayItem):
- return item.array
- elif isinstance(item, PandasDataFrameItem):
- return item.dataframe
- elif isinstance(item, PandasSeriesItem):
- return item.series
- elif isinstance(item, PolarsDataFrameItem):
- return item.dataframe
- elif isinstance(item, PolarsSeriesItem):
- return item.series
- elif isinstance(item, SklearnBaseEstimatorItem):
- return item.estimator
- elif isinstance(item, CrossValidationItem):
- return item.cv_results_serialized
- elif isinstance(item, MediaItem):
- return item.media_bytes
- else:
- raise ValueError(f"Item {item} is not a known item type.")
-
- def get_item(self, key: str) -> Item:
- """Get the item corresponding to ``key`` from the Project.
-
- Parameters
- ----------
- key : str
- The key corresponding to the item to get.
-
- Returns
- -------
- item : Item
- The Item corresponding to ``key``.
-
- Raises
- ------
- KeyError
- If the key does not correspond to any item.
- """
- return self.item_repository.get_item(key)
+ return item_to_object(self.item_repository.get_item(key))
def get_item_versions(self, key: str) -> list[Item]:
"""
From 0277edb1f87125c48e4cb33b24f98d0274a29fba Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Tue, 7 Jan 2025 16:40:14 +0100
Subject: [PATCH 04/22] [skip ci] Add `PickleItem`
---
skore/src/skore/persistence/item/__init__.py | 37 +++++++++----------
.../src/skore/persistence/item/pickle_item.py | 25 +++++++++++++
2 files changed, 43 insertions(+), 19 deletions(-)
create mode 100644 skore/src/skore/persistence/item/pickle_item.py
diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py
index 0735647f5..57e3b4c10 100644
--- a/skore/src/skore/persistence/item/__init__.py
+++ b/skore/src/skore/persistence/item/__init__.py
@@ -5,18 +5,19 @@
from contextlib import suppress
from typing import Any
-from skore.item import skrub_table_report_item as SkrubTableReportItem
-from skore.item.cross_validation_item import CrossValidationItem
-from skore.item.item import Item, ItemTypeError
-from skore.item.item_repository import ItemRepository
-from skore.item.media_item import MediaItem
-from skore.item.numpy_array_item import NumpyArrayItem
-from skore.item.pandas_dataframe_item import PandasDataFrameItem
-from skore.item.pandas_series_item import PandasSeriesItem
-from skore.item.polars_dataframe_item import PolarsDataFrameItem
-from skore.item.polars_series_item import PolarsSeriesItem
-from skore.item.primitive_item import PrimitiveItem
-from skore.item.sklearn_base_estimator_item import SklearnBaseEstimatorItem
+from .cross_validation_item import CrossValidationItem
+from .item import Item, ItemTypeError
+from .item_repository import ItemRepository
+from .media_item import MediaItem
+from .numpy_array_item import NumpyArrayItem
+from .pandas_dataframe_item import PandasDataFrameItem
+from .pandas_series_item import PandasSeriesItem
+from .pickle_item import PickleItem
+from .polars_dataframe_item import PolarsDataFrameItem
+from .polars_series_item import PolarsSeriesItem
+from .primitive_item import PrimitiveItem
+from .sklearn_base_estimator_item import SklearnBaseEstimatorItem
+from .skrub_table_report_item import SkrubTableReportItem
def object_to_item(object: Any) -> Item:
@@ -51,13 +52,9 @@ def item_to_object(item: Item) -> Any:
return item.primitive
elif isinstance(item, NumpyArrayItem):
return item.array
- elif isinstance(item, PandasDataFrameItem):
+ elif isinstance(item, PandasDataFrameItem) or isinstance(item, PolarsDataFrameItem):
return item.dataframe
- elif isinstance(item, PandasSeriesItem):
- return item.series
- elif isinstance(item, PolarsDataFrameItem):
- return item.dataframe
- elif isinstance(item, PolarsSeriesItem):
+ elif isinstance(item, PandasSeriesItem) or isinstance(item, PolarsSeriesItem):
return item.series
elif isinstance(item, SklearnBaseEstimatorItem):
return item.estimator
@@ -66,7 +63,7 @@ def item_to_object(item: Item) -> Any:
elif isinstance(item, MediaItem):
return item.media_bytes
elif isinstance(item, PickleItem):
- return repr(item.pickle_bytes)
+ return item.object
else:
raise ValueError(f"Item {item} is not a known item type.")
@@ -79,10 +76,12 @@ def item_to_object(item: Item) -> Any:
"NumpyArrayItem",
"PandasDataFrameItem",
"PandasSeriesItem",
+ "PickleItem",
"PolarsDataFrameItem",
"PolarsSeriesItem",
"PrimitiveItem",
"SklearnBaseEstimatorItem",
"SkrubTableReportItem",
+ "item_to_object",
"object_to_item",
]
diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py
new file mode 100644
index 000000000..ce9019488
--- /dev/null
+++ b/skore/src/skore/persistence/item/pickle_item.py
@@ -0,0 +1,25 @@
+from functools import cached_property
+from pickle import dumps, loads
+from typing import Any
+
+from skore.item.item import Item
+
+
+class PickleItem(Item):
+ def __init__(
+ self,
+ pickle_bytes: bytes,
+ created_at: str | None = None,
+ updated_at: str | None = None,
+ ):
+ super().__init__(created_at, updated_at)
+
+ self.pickle_bytes = pickle_bytes
+
+ @cached_property
+ def object(self) -> Any:
+ return loads(self.pickle_bytes)
+
+ @classmethod
+ def factory(cls, object: Any) -> PickleItem:
+ return cls(dumps(object))
From f4552e97f40e02af2adec7e3ec0a69eedd3c67c3 Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Tue, 7 Jan 2025 16:52:33 +0100
Subject: [PATCH 05/22] [skip ci] Fix import of `SkrubTableReportItem`
---
skore/src/skore/persistence/item/__init__.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py
index 57e3b4c10..770ce5979 100644
--- a/skore/src/skore/persistence/item/__init__.py
+++ b/skore/src/skore/persistence/item/__init__.py
@@ -5,6 +5,8 @@
from contextlib import suppress
from typing import Any
+import skrub_table_report_item as SkrubTableReportItem
+
from .cross_validation_item import CrossValidationItem
from .item import Item, ItemTypeError
from .item_repository import ItemRepository
@@ -17,7 +19,6 @@
from .polars_series_item import PolarsSeriesItem
from .primitive_item import PrimitiveItem
from .sklearn_base_estimator_item import SklearnBaseEstimatorItem
-from .skrub_table_report_item import SkrubTableReportItem
def object_to_item(object: Any) -> Item:
From aff2a5f0c8db2edfa9cfcd58f7f7cb4d286ef425 Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Wed, 8 Jan 2025 12:05:51 +0100
Subject: [PATCH 06/22] Fix imports and tests
Co-authored-by: Auguste Baum
commit 86281ca475f59c2fa43760b243c366c23c38b787
Author: Auguste Baum
Date: Wed Jan 8 12:00:46 2025 +0100
mob next [ci-skip] [ci skip] [skip ci]
lastFile:skore/tests/integration/ui/test_ui.py
commit 9f6f1931584de22ab08a9e4a9779a5f17dee4501
Author: Thomas S
Date: Wed Jan 8 11:52:12 2025 +0100
mob next [ci-skip] [ci skip] [skip ci]
lastFile:skore/tests/integration/ui/test_ui.py
commit cd385f66a805390893a6e415761ab1c0ea65550e
Author: Auguste Baum
Date: Wed Jan 8 11:41:10 2025 +0100
mob next [ci-skip] [ci skip] [skip ci]
lastFile:skore/src/skore/persistence/item/sklearn_base_estimator_item.py
commit 41c53559266f139c785211fc97a43e1e2b281203
Author: Thomas S
Date: Wed Jan 8 11:27:16 2025 +0100
mob next [ci-skip] [ci skip] [skip ci]
lastFile:skore/tests/integration/sklearn/test_cross_validate.py
commit a7d956edfeab6ef56020aaae8752e59f7f4e3491
Author: Auguste Baum
Date: Wed Jan 8 11:14:59 2025 +0100
mob next [ci-skip] [ci skip] [skip ci]
lastFile:skore/tests/unit/view/test_view_repository.py
commit 38446cf8499b2d3bcf79d371b00e374a916dd77a
Author: Thomas S
Date: Wed Jan 8 11:01:43 2025 +0100
mob next [ci-skip] [ci skip] [skip ci]
lastFile:skore/src/skore/persistence/item/pickle_item.py
commit a6e2a7914ada1e003c5295ccf233c73238972dfc
Author: Thomas S
Date: Wed Jan 8 10:51:07 2025 +0100
mob start [ci-skip] [ci skip] [skip ci]
Co-authored-by: Auguste Baum
---
skore/src/skore/persistence/item/__init__.py | 5 +-
.../persistence/item/cross_validation_item.py | 3 +-
.../src/skore/persistence/item/media_item.py | 2 +-
.../persistence/item/numpy_array_item.py | 2 +-
.../persistence/item/pandas_dataframe_item.py | 2 +-
.../persistence/item/pandas_series_item.py | 2 +-
.../src/skore/persistence/item/pickle_item.py | 4 +-
.../persistence/item/polars_dataframe_item.py | 2 +-
.../persistence/item/polars_series_item.py | 2 +-
.../skore/persistence/item/primitive_item.py | 2 +-
.../item/sklearn_base_estimator_item.py | 7 ++-
.../item/skrub_table_report_item.py | 4 +-
.../skore/persistence/repository/__init__.py | 7 +++
.../persistence/repository/item_repository.py | 27 +++++-----
.../src/skore/persistence/storage/__init__.py | 9 ++++
skore/src/skore/project/create.py | 2 +-
skore/src/skore/project/load.py | 8 +--
skore/src/skore/project/project.py | 7 ++-
skore/src/skore/ui/project_routes.py | 4 +-
skore/tests/conftest.py | 5 +-
.../sklearn/test_cross_validate.py | 4 +-
skore/tests/integration/ui/test_ui.py | 51 ++++++++++++-------
.../unit/item/test_cross_validation_item.py | 8 +--
skore/tests/unit/item/test_item_repository.py | 3 +-
skore/tests/unit/item/test_media_item.py | 4 +-
.../tests/unit/item/test_numpy_array_item.py | 4 +-
.../unit/item/test_pandas_dataframe_item.py | 4 +-
.../unit/item/test_pandas_series_item.py | 4 +-
.../unit/item/test_polars_dataframe_item.py | 6 +--
.../unit/item/test_polars_series_item.py | 4 +-
skore/tests/unit/item/test_primitive_item.py | 6 +--
.../item/test_sklearn_base_estimator_item.py | 4 +-
.../unit/item/test_skrub_table_report_item.py | 4 +-
skore/tests/unit/persistence/test_disk.py | 2 +-
skore/tests/unit/persistence/test_memory.py | 2 +-
skore/tests/unit/test_project.py | 20 ++++----
skore/tests/unit/view/test_view_repository.py | 6 +--
37 files changed, 142 insertions(+), 100 deletions(-)
create mode 100644 skore/src/skore/persistence/storage/__init__.py
diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py
index 770ce5979..7b2d2f002 100644
--- a/skore/src/skore/persistence/item/__init__.py
+++ b/skore/src/skore/persistence/item/__init__.py
@@ -5,11 +5,9 @@
from contextlib import suppress
from typing import Any
-import skrub_table_report_item as SkrubTableReportItem
-
+from . import skrub_table_report_item as SkrubTableReportItem
from .cross_validation_item import CrossValidationItem
from .item import Item, ItemTypeError
-from .item_repository import ItemRepository
from .media_item import MediaItem
from .numpy_array_item import NumpyArrayItem
from .pandas_dataframe_item import PandasDataFrameItem
@@ -72,7 +70,6 @@ def item_to_object(item: Item) -> Any:
__all__ = [
"CrossValidationItem",
"Item",
- "ItemRepository",
"MediaItem",
"NumpyArrayItem",
"PandasDataFrameItem",
diff --git a/skore/src/skore/persistence/item/cross_validation_item.py b/skore/src/skore/persistence/item/cross_validation_item.py
index e4e4c4a2a..1121117be 100644
--- a/skore/src/skore/persistence/item/cross_validation_item.py
+++ b/skore/src/skore/persistence/item/cross_validation_item.py
@@ -20,9 +20,10 @@
import plotly.graph_objects
import plotly.io
-from skore.item.item import Item, ItemTypeError
from skore.sklearn.cross_validation import CrossValidationReporter
+from .item import Item, ItemTypeError
+
if TYPE_CHECKING:
import sklearn.base
diff --git a/skore/src/skore/persistence/item/media_item.py b/skore/src/skore/persistence/item/media_item.py
index b83de93e4..b0c957bd6 100644
--- a/skore/src/skore/persistence/item/media_item.py
+++ b/skore/src/skore/persistence/item/media_item.py
@@ -9,7 +9,7 @@
from io import BytesIO
from typing import TYPE_CHECKING, Any
-from skore.item.item import Item, ItemTypeError
+from .item import Item, ItemTypeError
if TYPE_CHECKING:
from altair.vegalite.v5.schema.core import TopLevelSpec as Altair
diff --git a/skore/src/skore/persistence/item/numpy_array_item.py b/skore/src/skore/persistence/item/numpy_array_item.py
index 59d80bf66..5ebe0c82c 100644
--- a/skore/src/skore/persistence/item/numpy_array_item.py
+++ b/skore/src/skore/persistence/item/numpy_array_item.py
@@ -9,7 +9,7 @@
from json import dumps, loads
from typing import TYPE_CHECKING
-from skore.item.item import Item, ItemTypeError
+from .item import Item, ItemTypeError
if TYPE_CHECKING:
import numpy
diff --git a/skore/src/skore/persistence/item/pandas_dataframe_item.py b/skore/src/skore/persistence/item/pandas_dataframe_item.py
index f61fef26c..3d124677e 100644
--- a/skore/src/skore/persistence/item/pandas_dataframe_item.py
+++ b/skore/src/skore/persistence/item/pandas_dataframe_item.py
@@ -9,7 +9,7 @@
from functools import cached_property
from typing import TYPE_CHECKING
-from skore.item.item import Item, ItemTypeError
+from .item import Item, ItemTypeError
if TYPE_CHECKING:
import pandas
diff --git a/skore/src/skore/persistence/item/pandas_series_item.py b/skore/src/skore/persistence/item/pandas_series_item.py
index cd27ebea3..bfc52457c 100644
--- a/skore/src/skore/persistence/item/pandas_series_item.py
+++ b/skore/src/skore/persistence/item/pandas_series_item.py
@@ -9,7 +9,7 @@
from functools import cached_property
from typing import TYPE_CHECKING
-from skore.item.item import Item, ItemTypeError
+from .item import Item, ItemTypeError
if TYPE_CHECKING:
import pandas
diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py
index ce9019488..42a0ace5d 100644
--- a/skore/src/skore/persistence/item/pickle_item.py
+++ b/skore/src/skore/persistence/item/pickle_item.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
from functools import cached_property
from pickle import dumps, loads
from typing import Any
-from skore.item.item import Item
+from .item import Item
class PickleItem(Item):
diff --git a/skore/src/skore/persistence/item/polars_dataframe_item.py b/skore/src/skore/persistence/item/polars_dataframe_item.py
index d42292dbe..dc6b05533 100644
--- a/skore/src/skore/persistence/item/polars_dataframe_item.py
+++ b/skore/src/skore/persistence/item/polars_dataframe_item.py
@@ -9,7 +9,7 @@
from functools import cached_property
from typing import TYPE_CHECKING
-from skore.item.item import Item, ItemTypeError
+from .item import Item, ItemTypeError
if TYPE_CHECKING:
import polars
diff --git a/skore/src/skore/persistence/item/polars_series_item.py b/skore/src/skore/persistence/item/polars_series_item.py
index 7d846f9a9..ba0dd3f8c 100644
--- a/skore/src/skore/persistence/item/polars_series_item.py
+++ b/skore/src/skore/persistence/item/polars_series_item.py
@@ -9,7 +9,7 @@
from functools import cached_property
from typing import TYPE_CHECKING
-from skore.item.item import Item, ItemTypeError
+from .item import Item, ItemTypeError
if TYPE_CHECKING:
import polars
diff --git a/skore/src/skore/persistence/item/primitive_item.py b/skore/src/skore/persistence/item/primitive_item.py
index 3e7d1e1a5..039032312 100644
--- a/skore/src/skore/persistence/item/primitive_item.py
+++ b/skore/src/skore/persistence/item/primitive_item.py
@@ -7,7 +7,7 @@
from typing import TYPE_CHECKING
-from skore.item.item import Item, ItemTypeError
+from .item import Item, ItemTypeError
if TYPE_CHECKING:
from typing import Union
diff --git a/skore/src/skore/persistence/item/sklearn_base_estimator_item.py b/skore/src/skore/persistence/item/sklearn_base_estimator_item.py
index d40b61bba..40321970e 100644
--- a/skore/src/skore/persistence/item/sklearn_base_estimator_item.py
+++ b/skore/src/skore/persistence/item/sklearn_base_estimator_item.py
@@ -9,7 +9,7 @@
from functools import cached_property
from typing import TYPE_CHECKING
-from skore.item.item import Item, ItemTypeError
+from .item import Item, ItemTypeError
if TYPE_CHECKING:
import sklearn.base
@@ -101,11 +101,14 @@ def factory(cls, estimator: sklearn.base.BaseEstimator) -> SklearnBaseEstimatorI
"""
import sklearn.base
import sklearn.utils
- import skops.io
if not isinstance(estimator, sklearn.base.BaseEstimator):
raise ItemTypeError(f"Type '{estimator.__class__}' is not supported.")
+ # Import `skops.io` only once `estimator` is known to have the right type,
+ # i.e. after the type check above
+ import skops.io
+
estimator_html_repr = sklearn.utils.estimator_html_repr(estimator)
estimator_skops = skops.io.dumps(estimator)
estimator_skops_untrusted_types = skops.io.get_untrusted_types(
diff --git a/skore/src/skore/persistence/item/skrub_table_report_item.py b/skore/src/skore/persistence/item/skrub_table_report_item.py
index b411256f2..07eae5bbd 100644
--- a/skore/src/skore/persistence/item/skrub_table_report_item.py
+++ b/skore/src/skore/persistence/item/skrub_table_report_item.py
@@ -4,8 +4,8 @@
from typing import TYPE_CHECKING
-from skore.item.item import ItemTypeError
-from skore.item.media_item import MediaItem
+from .item import ItemTypeError
+from .media_item import MediaItem
if TYPE_CHECKING:
from skrub import TableReport
diff --git a/skore/src/skore/persistence/repository/__init__.py b/skore/src/skore/persistence/repository/__init__.py
index e69de29bb..c088fdacf 100644
--- a/skore/src/skore/persistence/repository/__init__.py
+++ b/skore/src/skore/persistence/repository/__init__.py
@@ -0,0 +1,7 @@
+from .item_repository import ItemRepository
+from .view_repository import ViewRepository
+
+__all__ = [
+ "ItemRepository",
+ "ViewRepository",
+]
diff --git a/skore/src/skore/persistence/repository/item_repository.py b/skore/src/skore/persistence/repository/item_repository.py
index 74210964f..aea849f80 100644
--- a/skore/src/skore/persistence/repository/item_repository.py
+++ b/skore/src/skore/persistence/repository/item_repository.py
@@ -8,20 +8,21 @@
from typing import TYPE_CHECKING
+from skore.persistence.item import (
+ CrossValidationItem,
+ MediaItem,
+ NumpyArrayItem,
+ PandasDataFrameItem,
+ PandasSeriesItem,
+ PolarsDataFrameItem,
+ PolarsSeriesItem,
+ PrimitiveItem,
+ SklearnBaseEstimatorItem,
+)
+
if TYPE_CHECKING:
- from skore.item.item import Item
- from skore.persistence.abstract_storage import AbstractStorage
-
-
-from skore.item.cross_validation_item import CrossValidationItem
-from skore.item.media_item import MediaItem
-from skore.item.numpy_array_item import NumpyArrayItem
-from skore.item.pandas_dataframe_item import PandasDataFrameItem
-from skore.item.pandas_series_item import PandasSeriesItem
-from skore.item.polars_dataframe_item import PolarsDataFrameItem
-from skore.item.polars_series_item import PolarsSeriesItem
-from skore.item.primitive_item import PrimitiveItem
-from skore.item.sklearn_base_estimator_item import SklearnBaseEstimatorItem
+ from skore.persistence.item import Item
+ from skore.persistence.storage import AbstractStorage
class ItemRepository:
diff --git a/skore/src/skore/persistence/storage/__init__.py b/skore/src/skore/persistence/storage/__init__.py
new file mode 100644
index 000000000..21b6a3b1a
--- /dev/null
+++ b/skore/src/skore/persistence/storage/__init__.py
@@ -0,0 +1,9 @@
+from .abstract_storage import AbstractStorage
+from .disk_cache_storage import DiskCacheStorage
+from .in_memory_storage import InMemoryStorage
+
+__all__ = [
+ "AbstractStorage",
+ "DiskCacheStorage",
+ "InMemoryStorage",
+]
diff --git a/skore/src/skore/project/create.py b/skore/src/skore/project/create.py
index bee8449d5..fcffce9d0 100644
--- a/skore/src/skore/project/create.py
+++ b/skore/src/skore/project/create.py
@@ -11,10 +11,10 @@
ProjectCreationError,
ProjectPermissionError,
)
+from skore.persistence.view.view import View
from skore.project.load import load
from skore.project.project import Project, logger
from skore.utils._logger import logger_context
-from skore.view.view import View
def _validate_project_name(project_name: str) -> tuple[bool, Optional[Exception]]:
diff --git a/skore/src/skore/project/load.py b/skore/src/skore/project/load.py
index c78948690..0e8e7f064 100644
--- a/skore/src/skore/project/load.py
+++ b/skore/src/skore/project/load.py
@@ -3,10 +3,12 @@
from pathlib import Path
from typing import Union
-from skore.item import ItemRepository
-from skore.persistence.disk_cache_storage import DirectoryDoesNotExist, DiskCacheStorage
+from skore.persistence.repository import ItemRepository, ViewRepository
+from skore.persistence.storage.disk_cache_storage import (
+ DirectoryDoesNotExist,
+ DiskCacheStorage,
+)
from skore.project.project import Project
-from skore.view.view_repository import ViewRepository
class ProjectLoadError(Exception):
diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py
index 299409425..3cca337f2 100644
--- a/skore/src/skore/project/project.py
+++ b/skore/src/skore/project/project.py
@@ -2,9 +2,10 @@
from __future__ import annotations
+import logging
from typing import TYPE_CHECKING, Any, Optional, Union
-from skore.persistence import item_to_object, object_to_item
+from skore.persistence.item import item_to_object, object_to_item
if TYPE_CHECKING:
from skore.persistence import (
@@ -15,6 +16,10 @@
)
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler()) # Default to no output
+logger.setLevel(logging.INFO)
+
MISSING = object()
diff --git a/skore/src/skore/ui/project_routes.py b/skore/src/skore/ui/project_routes.py
index 040e5dbca..79eff5ce6 100644
--- a/skore/src/skore/ui/project_routes.py
+++ b/skore/src/skore/ui/project_routes.py
@@ -9,9 +9,9 @@
from fastapi import APIRouter, HTTPException, Request, status
-from skore.item import Item
+from skore.persistence.item import Item
+from skore.persistence.view.view import Layout, View
from skore.project import Project
-from skore.view.view import Layout, View
router = APIRouter(prefix="/project")
diff --git a/skore/tests/conftest.py b/skore/tests/conftest.py
index 6b2e3e0ad..61642fb5b 100644
--- a/skore/tests/conftest.py
+++ b/skore/tests/conftest.py
@@ -1,10 +1,9 @@
from datetime import datetime, timezone
import pytest
-from skore.item.item_repository import ItemRepository
-from skore.persistence.in_memory_storage import InMemoryStorage
+from skore.persistence.repository import ItemRepository, ViewRepository
+from skore.persistence.storage import InMemoryStorage
from skore.project import Project
-from skore.view.view_repository import ViewRepository
@pytest.fixture
diff --git a/skore/tests/integration/sklearn/test_cross_validate.py b/skore/tests/integration/sklearn/test_cross_validate.py
index 8c8992728..e39f6ba58 100644
--- a/skore/tests/integration/sklearn/test_cross_validate.py
+++ b/skore/tests/integration/sklearn/test_cross_validate.py
@@ -10,7 +10,7 @@
from sklearn.multiclass import OneVsOneClassifier
from sklearn.svm import SVC
from skore import CrossValidationReporter
-from skore.item.cross_validation_item import CrossValidationItem
+from skore.persistence.item.cross_validation_item import CrossValidationItem
from skore.sklearn.cross_validation.cross_validation_helpers import _get_scorers_to_add
@@ -200,7 +200,7 @@ def test_cross_validation_reporter(in_memory_project, fixture_name, request):
in_memory_project.put("cross-validation", reporter)
- retrieved_item = in_memory_project.get_item("cross-validation")
+ retrieved_item = in_memory_project.item_repository.get_item("cross-validation")
assert isinstance(retrieved_item, CrossValidationItem)
diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py
index f467ff699..f5f984e26 100644
--- a/skore/tests/integration/ui/test_ui.py
+++ b/skore/tests/integration/ui/test_ui.py
@@ -1,7 +1,9 @@
import datetime
+import json
import numpy
import pandas
+import plotly
import polars
import pytest
from fastapi.testclient import TestClient
@@ -9,9 +11,8 @@
from sklearn.linear_model import Lasso
from sklearn.model_selection import KFold
from skore import CrossValidationReporter
-from skore.item.media_item import MediaItem
+from skore.persistence.view.view import View
from skore.ui.app import create_app
-from skore.view.view import View
@pytest.fixture
@@ -136,16 +137,12 @@ def test_serialize_media_item(client, in_memory_project):
html = "éપUœALDXIWDŸΩΩ
"
in_memory_project.put("html", html)
- in_memory_project.put_item(
- "media html", MediaItem.factory_str(html, media_type="text/html")
- )
response = client.get("/api/project/items")
assert response.status_code == 200
project = response.json()
assert "image" in project["items"]["img"][0]["media_type"]
assert project["items"]["html"][0]["value"] == html
- assert project["items"]["media html"][0]["value"] == html
@pytest.fixture
@@ -178,9 +175,21 @@ def test_serialize_cross_validation_item(
mock_nowstr,
fake_cross_validate,
):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
+ monkeypatch.setattr(
+ "skore.persistence.item.cross_validation_item.CrossValidationItem.plots", {}
+ )
+ monkeypatch.setattr(
+ "skore.sklearn.cross_validation.cross_validation_reporter.plot_cross_validation_compare_scores",
+ lambda _: {},
+ )
+ monkeypatch.setattr(
+ "skore.sklearn.cross_validation.cross_validation_reporter.plot_cross_validation_timing_normalized",
+ lambda _: {},
+ )
monkeypatch.setattr(
- "skore.item.cross_validation_item.CrossValidationItem.plots", {}
+ "skore.sklearn.cross_validation.cross_validation_reporter.plot_cross_validation_timing",
+ lambda _: {},
)
def prepare_cv():
@@ -196,17 +205,12 @@ def prepare_cv():
reporter = CrossValidationReporter(model, X, y, cv=KFold(3))
in_memory_project.put("cv", reporter)
- # Mock the item to make the plot empty
- item = in_memory_project.get_item("cv")
- item.plots_bytes = {"compare_scores": b"{}"}
- in_memory_project.put_item("cv_mocked", item)
-
response = client.get("/api/project/items")
assert response.status_code == 200
project = response.json()
expected = {
- "name": "cv_mocked",
+ "name": "cv",
"media_type": "application/vnd.skore.cross_validation+json",
"value": {
"scalar_results": [
@@ -227,7 +231,20 @@ def prepare_cv():
],
}
],
- "plots": [{"name": "compare_scores", "value": {}}],
+ "plots": [
+ {
+ "name": "Scores",
+ "value": json.loads(plotly.io.to_json({}, engine="json")),
+ },
+ {
+ "name": "Timings",
+ "value": json.loads(plotly.io.to_json({}, engine="json")),
+ },
+ {
+ "name": "Normalized timings",
+ "value": json.loads(plotly.io.to_json({}, engine="json")),
+ },
+ ],
"sections": [
{
"title": "Model",
@@ -266,7 +283,7 @@ def prepare_cv():
"updated_at": mock_nowstr,
"created_at": mock_nowstr,
}
- actual = project["items"]["cv_mocked"][0]
+ actual = project["items"]["cv"][0]
assert expected == actual
@@ -282,7 +299,7 @@ def now(*args, **kwargs):
MockDatetime.NOW += MockDatetime.TIMEDELTA
return MockDatetime.NOW
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
for i in range(5):
in_memory_project.put(str(i), i)
diff --git a/skore/tests/unit/item/test_cross_validation_item.py b/skore/tests/unit/item/test_cross_validation_item.py
index 6ee82a3a7..af691fc06 100644
--- a/skore/tests/unit/item/test_cross_validation_item.py
+++ b/skore/tests/unit/item/test_cross_validation_item.py
@@ -4,9 +4,9 @@
import plotly.graph_objects
import pytest
from sklearn.model_selection import StratifiedKFold
-from skore.item.cross_validation_item import (
+from skore.persistence.item import ItemTypeError
+from skore.persistence.item.cross_validation_item import (
CrossValidationItem,
- ItemTypeError,
_hash_numpy,
)
from skore.sklearn.cross_validation import CrossValidationReporter
@@ -67,7 +67,7 @@ class FakeCrossValidationReporterNoGetParams(CrossValidationReporter):
class TestCrossValidationItem:
@pytest.fixture(autouse=True)
def monkeypatch_datetime(self, monkeypatch, MockDatetime):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
def test_factory_exception(self):
with pytest.raises(ItemTypeError):
@@ -111,7 +111,7 @@ def test_factory(self, mock_nowstr, reporter):
def test_get_serializable_dict(self, monkeypatch, mock_nowstr):
monkeypatch.setattr(
- "skore.item.cross_validation_item.CrossValidationReporter",
+ "skore.persistence.item.cross_validation_item.CrossValidationReporter",
FakeCrossValidationReporter,
)
diff --git a/skore/tests/unit/item/test_item_repository.py b/skore/tests/unit/item/test_item_repository.py
index 79b93548d..b2672bd30 100644
--- a/skore/tests/unit/item/test_item_repository.py
+++ b/skore/tests/unit/item/test_item_repository.py
@@ -1,7 +1,8 @@
from datetime import datetime, timezone
import pytest
-from skore.item import ItemRepository, MediaItem
+from skore.persistence.item import MediaItem
+from skore.persistence.repository import ItemRepository
class TestItemRepository:
diff --git a/skore/tests/unit/item/test_media_item.py b/skore/tests/unit/item/test_media_item.py
index 97d4b579e..024f2fb5a 100644
--- a/skore/tests/unit/item/test_media_item.py
+++ b/skore/tests/unit/item/test_media_item.py
@@ -5,13 +5,13 @@
import PIL as pillow
import plotly.graph_objects as go
import pytest
-from skore.item import ItemTypeError, MediaItem
+from skore.persistence.item import ItemTypeError, MediaItem
class TestMediaItem:
@pytest.fixture(autouse=True)
def monkeypatch_datetime(self, monkeypatch, MockDatetime):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
def test_factory_exception(self):
with pytest.raises(ItemTypeError):
diff --git a/skore/tests/unit/item/test_numpy_array_item.py b/skore/tests/unit/item/test_numpy_array_item.py
index a793cfed1..ed1c29317 100644
--- a/skore/tests/unit/item/test_numpy_array_item.py
+++ b/skore/tests/unit/item/test_numpy_array_item.py
@@ -2,13 +2,13 @@
import numpy
import pytest
-from skore.item import ItemTypeError, NumpyArrayItem
+from skore.persistence.item import ItemTypeError, NumpyArrayItem
class TestNumpyArrayItem:
@pytest.fixture(autouse=True)
def monkeypatch_datetime(self, monkeypatch, MockDatetime):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
def test_factory_exception(self):
with pytest.raises(ItemTypeError):
diff --git a/skore/tests/unit/item/test_pandas_dataframe_item.py b/skore/tests/unit/item/test_pandas_dataframe_item.py
index da9b8f80a..2b6b0edf9 100644
--- a/skore/tests/unit/item/test_pandas_dataframe_item.py
+++ b/skore/tests/unit/item/test_pandas_dataframe_item.py
@@ -2,13 +2,13 @@
import pytest
from pandas import DataFrame, Index, MultiIndex
from pandas.testing import assert_frame_equal
-from skore.item import ItemTypeError, PandasDataFrameItem
+from skore.persistence.item import ItemTypeError, PandasDataFrameItem
class TestPandasDataFrameItem:
@pytest.fixture(autouse=True)
def monkeypatch_datetime(self, monkeypatch, MockDatetime):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
def test_factory_exception(self):
with pytest.raises(ItemTypeError):
diff --git a/skore/tests/unit/item/test_pandas_series_item.py b/skore/tests/unit/item/test_pandas_series_item.py
index eb160f765..4a2e396ac 100644
--- a/skore/tests/unit/item/test_pandas_series_item.py
+++ b/skore/tests/unit/item/test_pandas_series_item.py
@@ -2,13 +2,13 @@
import pytest
from pandas import Index, MultiIndex, Series
from pandas.testing import assert_series_equal
-from skore.item import ItemTypeError, PandasSeriesItem
+from skore.persistence.item import ItemTypeError, PandasSeriesItem
class TestPandasSeriesItem:
@pytest.fixture(autouse=True)
def monkeypatch_datetime(self, monkeypatch, MockDatetime):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
def test_factory_exception(self):
with pytest.raises(ItemTypeError):
diff --git a/skore/tests/unit/item/test_polars_dataframe_item.py b/skore/tests/unit/item/test_polars_dataframe_item.py
index 8be5250fc..0335e7d06 100644
--- a/skore/tests/unit/item/test_polars_dataframe_item.py
+++ b/skore/tests/unit/item/test_polars_dataframe_item.py
@@ -2,14 +2,14 @@
import pytest
from polars import DataFrame
from polars.testing import assert_frame_equal
-from skore.item import ItemTypeError, PolarsDataFrameItem
-from skore.item.polars_dataframe_item import PolarsToJSONError
+from skore.persistence.item import ItemTypeError, PolarsDataFrameItem
+from skore.persistence.item.polars_dataframe_item import PolarsToJSONError
class TestPolarsDataFrameItem:
@pytest.fixture(autouse=True)
def monkeypatch_datetime(self, monkeypatch, MockDatetime):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
def test_factory_exception(self):
with pytest.raises(ItemTypeError):
diff --git a/skore/tests/unit/item/test_polars_series_item.py b/skore/tests/unit/item/test_polars_series_item.py
index 8ca235c41..1e40a0c3f 100644
--- a/skore/tests/unit/item/test_polars_series_item.py
+++ b/skore/tests/unit/item/test_polars_series_item.py
@@ -2,13 +2,13 @@
import pytest
from polars import Series
from polars.testing import assert_series_equal
-from skore.item import ItemTypeError, PolarsSeriesItem
+from skore.persistence.item import ItemTypeError, PolarsSeriesItem
class TestPolarsSeriesItem:
@pytest.fixture(autouse=True)
def monkeypatch_datetime(self, monkeypatch, MockDatetime):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
def test_factory_exception(self):
with pytest.raises(ItemTypeError):
diff --git a/skore/tests/unit/item/test_primitive_item.py b/skore/tests/unit/item/test_primitive_item.py
index 2a97f6eff..babcdf2b8 100644
--- a/skore/tests/unit/item/test_primitive_item.py
+++ b/skore/tests/unit/item/test_primitive_item.py
@@ -1,5 +1,5 @@
import pytest
-from skore.item import ItemTypeError, PrimitiveItem
+from skore.persistence.item import ItemTypeError, PrimitiveItem
class TestPrimitiveItem:
@@ -16,7 +16,7 @@ class TestPrimitiveItem:
],
)
def test_factory(self, monkeypatch, mock_nowstr, MockDatetime, primitive):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
item = PrimitiveItem.factory(primitive)
@@ -43,7 +43,7 @@ def test_factory_exception(self):
def test_get_serializable_dict(
self, monkeypatch, mock_nowstr, MockDatetime, primitive
):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
item = PrimitiveItem.factory(primitive)
serializable = item.as_serializable_dict()
diff --git a/skore/tests/unit/item/test_sklearn_base_estimator_item.py b/skore/tests/unit/item/test_sklearn_base_estimator_item.py
index 469a705bc..d0cbe93cf 100644
--- a/skore/tests/unit/item/test_sklearn_base_estimator_item.py
+++ b/skore/tests/unit/item/test_sklearn_base_estimator_item.py
@@ -1,7 +1,7 @@
import pytest
import sklearn.svm
import skops.io
-from skore.item import ItemTypeError, SklearnBaseEstimatorItem
+from skore.persistence.item import ItemTypeError, SklearnBaseEstimatorItem
class Estimator(sklearn.svm.SVC):
@@ -11,7 +11,7 @@ class Estimator(sklearn.svm.SVC):
class TestSklearnBaseEstimatorItem:
@pytest.fixture(autouse=True)
def monkeypatch_datetime(self, monkeypatch, MockDatetime):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
def test_factory_exception(self):
with pytest.raises(ItemTypeError):
diff --git a/skore/tests/unit/item/test_skrub_table_report_item.py b/skore/tests/unit/item/test_skrub_table_report_item.py
index eef317096..d35236166 100644
--- a/skore/tests/unit/item/test_skrub_table_report_item.py
+++ b/skore/tests/unit/item/test_skrub_table_report_item.py
@@ -1,12 +1,12 @@
import pytest
from pandas import DataFrame
-from skore.item import ItemTypeError, SkrubTableReportItem
+from skore.persistence.item import ItemTypeError, SkrubTableReportItem
from skrub import TableReport
class TestSkrubTableReportItem:
def test_factory(self, monkeypatch, mock_nowstr, MockDatetime):
- monkeypatch.setattr("skore.item.item.datetime", MockDatetime)
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
monkeypatch.setattr("secrets.token_hex", lambda: "azertyuiop")
df = DataFrame(dict(a=[1, 2], b=["one", "two"], c=[11.1, 11.1]))
diff --git a/skore/tests/unit/persistence/test_disk.py b/skore/tests/unit/persistence/test_disk.py
index 77838978e..7c747c400 100644
--- a/skore/tests/unit/persistence/test_disk.py
+++ b/skore/tests/unit/persistence/test_disk.py
@@ -2,7 +2,7 @@
import shutil
from pathlib import Path
-from skore.persistence.disk_cache_storage import DiskCacheStorage
+from skore.persistence.storage import DiskCacheStorage
def test_disk_storage(tmp_path: Path):
diff --git a/skore/tests/unit/persistence/test_memory.py b/skore/tests/unit/persistence/test_memory.py
index f28dcc73d..53608496e 100644
--- a/skore/tests/unit/persistence/test_memory.py
+++ b/skore/tests/unit/persistence/test_memory.py
@@ -1,4 +1,4 @@
-from skore.persistence.in_memory_storage import InMemoryStorage
+from skore.persistence.storage import InMemoryStorage
def test_in_memory_storage():
diff --git a/skore/tests/unit/test_project.py b/skore/tests/unit/test_project.py
index 88b256e54..662ed5e90 100644
--- a/skore/tests/unit/test_project.py
+++ b/skore/tests/unit/test_project.py
@@ -17,6 +17,7 @@
ProjectAlreadyExistsError,
ProjectCreationError,
)
+from skore.persistence.view.view import View
from skore.project import (
Project,
create,
@@ -24,7 +25,6 @@
)
from skore.project.create import _validate_project_name
from skore.project.load import ProjectLoadError
-from skore.view.view import View
def test_put_string_item(in_memory_project):
@@ -262,17 +262,15 @@ def test_put_several_nested(in_memory_project):
assert in_memory_project.get("a") == {"b": "baz"}
-def test_put_several_error(in_memory_project):
- """If some key-value pairs are wrong, add all that are valid and print a warning."""
- with pytest.raises(NotImplementedError):
- in_memory_project.put(
- {
- "a": "foo",
- "b": (lambda: "unsupported object"),
- }
- )
+def test_put_several_lambda(in_memory_project):
+ in_memory_project.put(
+ {
+ "a": "foo",
+ "b": (lambda: "unsupported object"),
+ }
+ )
- assert in_memory_project.list_item_keys() == ["a"]
+ assert in_memory_project.list_item_keys() == ["a", "b"]
def test_put_key_is_a_tuple(in_memory_project):
diff --git a/skore/tests/unit/view/test_view_repository.py b/skore/tests/unit/view/test_view_repository.py
index ed46d271d..87c21cea1 100644
--- a/skore/tests/unit/view/test_view_repository.py
+++ b/skore/tests/unit/view/test_view_repository.py
@@ -1,7 +1,7 @@
import pytest
-from skore.persistence.in_memory_storage import InMemoryStorage
-from skore.view.view import View
-from skore.view.view_repository import ViewRepository
+from skore.persistence.repository import ViewRepository
+from skore.persistence.storage import InMemoryStorage
+from skore.persistence.view.view import View
@pytest.fixture
From 9a3698ca002365d992b07fceb2129ccbdd29644d Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Thu, 9 Jan 2025 10:48:45 +0100
Subject: [PATCH 07/22] Change the way a pickle item is exported to UI
---
skore/src/skore/persistence/item/pickle_item.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py
index 42a0ace5d..2f6d8ff55 100644
--- a/skore/src/skore/persistence/item/pickle_item.py
+++ b/skore/src/skore/persistence/item/pickle_item.py
@@ -25,3 +25,9 @@ def object(self) -> Any:
@classmethod
def factory(cls, object: Any) -> PickleItem:
return cls(dumps(object))
+
+ def as_serializable_dict(self):
+ return super().as_serializable_dict() | {
+ "media_type": "text/markdown",
+ "value": repr(self.object),
+ }
From 8643aa9b7d66e148441b235bb67f9ef4e7f40108 Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Fri, 10 Jan 2025 11:30:09 +0100
Subject: [PATCH 08/22] Add tests
---
skore/src/skore/persistence/item/__init__.py | 6 +--
.../persistence/repository/item_repository.py | 26 +-----------
skore/tests/integration/ui/test_ui.py | 29 ++++++++++++++
skore/tests/unit/item/test_pickle_item.py | 40 +++++++++++++++++++
skore/tests/unit/project/test_project.py | 9 +----
5 files changed, 76 insertions(+), 34 deletions(-)
create mode 100644 skore/tests/unit/item/test_pickle_item.py
diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py
index 7b2d2f002..79b0eb138 100644
--- a/skore/src/skore/persistence/item/__init__.py
+++ b/skore/src/skore/persistence/item/__init__.py
@@ -43,7 +43,7 @@ def object_to_item(object: Any) -> Item:
# correct type. If not, they throw a ItemTypeError exception.
return cls.factory(object)
- return PickleItem(object)
+ return PickleItem.factory(object)
def item_to_object(item: Item) -> Any:
@@ -51,9 +51,9 @@ def item_to_object(item: Item) -> Any:
return item.primitive
elif isinstance(item, NumpyArrayItem):
return item.array
- elif isinstance(item, PandasDataFrameItem) or isinstance(item, PolarsDataFrameItem):
+ elif isinstance(item, (PandasDataFrameItem, PolarsDataFrameItem)):
return item.dataframe
- elif isinstance(item, PandasSeriesItem) or isinstance(item, PolarsSeriesItem):
+ elif isinstance(item, (PandasSeriesItem, PolarsSeriesItem)):
return item.series
elif isinstance(item, SklearnBaseEstimatorItem):
return item.estimator
diff --git a/skore/src/skore/persistence/repository/item_repository.py b/skore/src/skore/persistence/repository/item_repository.py
index aea849f80..dd99258ef 100644
--- a/skore/src/skore/persistence/repository/item_repository.py
+++ b/skore/src/skore/persistence/repository/item_repository.py
@@ -8,17 +8,7 @@
from typing import TYPE_CHECKING
-from skore.persistence.item import (
- CrossValidationItem,
- MediaItem,
- NumpyArrayItem,
- PandasDataFrameItem,
- PandasSeriesItem,
- PolarsDataFrameItem,
- PolarsSeriesItem,
- PrimitiveItem,
- SklearnBaseEstimatorItem,
-)
+import skore.persistence.item
if TYPE_CHECKING:
from skore.persistence.item import Item
@@ -35,18 +25,6 @@ class ItemRepository:
storage as a map from keys to *lists of* values.
"""
- ITEM_CLASS_NAME_TO_ITEM_CLASS = {
- "MediaItem": MediaItem,
- "NumpyArrayItem": NumpyArrayItem,
- "PandasDataFrameItem": PandasDataFrameItem,
- "PandasSeriesItem": PandasSeriesItem,
- "PolarsDataFrameItem": PolarsDataFrameItem,
- "PolarsSeriesItem": PolarsSeriesItem,
- "PrimitiveItem": PrimitiveItem,
- "CrossValidationItem": CrossValidationItem,
- "SklearnBaseEstimatorItem": SklearnBaseEstimatorItem,
- }
-
def __init__(self, storage: AbstractStorage):
"""
Initialize the ItemRepository with a storage system.
@@ -68,7 +46,7 @@ def __deconstruct_item(item: Item) -> dict:
@staticmethod
def __construct_item(value) -> Item:
item_class_name = value["item_class_name"]
- item_class = ItemRepository.ITEM_CLASS_NAME_TO_ITEM_CLASS[item_class_name]
+ item_class = getattr(skore.persistence.item, item_class_name)
item = value["item"]
return item_class(**item)
diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py
index a4d641ee7..360761a6f 100644
--- a/skore/tests/integration/ui/test_ui.py
+++ b/skore/tests/integration/ui/test_ui.py
@@ -337,3 +337,32 @@ def now(*args, **kwargs):
("5", 5),
("4", 5),
]
+
+
+def test_get_items_with_pickle_item(
+ monkeypatch,
+ MockDatetime,
+ mock_nowstr,
+ client,
+ in_memory_project,
+):
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
+ in_memory_project.put("pickle", object)
+
+ response = client.get("/api/project/items")
+
+ assert response.status_code == 200
+ assert response.json() == {
+ "items": {
+ "pickle": [
+ {
+ "created_at": mock_nowstr,
+ "updated_at": mock_nowstr,
+ "name": "pickle",
+ "media_type": "text/markdown",
+ "value": "<class 'object'>",
+ },
+ ],
+ },
+ "views": {},
+ }
diff --git a/skore/tests/unit/item/test_pickle_item.py b/skore/tests/unit/item/test_pickle_item.py
new file mode 100644
index 000000000..acbc3924e
--- /dev/null
+++ b/skore/tests/unit/item/test_pickle_item.py
@@ -0,0 +1,40 @@
+import pickle
+
+import pytest
+from skore.persistence.item import PickleItem
+
+
+class TestPickleItem:
+ @pytest.fixture(autouse=True)
+ def monkeypatch_datetime(self, monkeypatch, MockDatetime):
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
+
+ @pytest.mark.parametrize("object", [0, 0.0, int, True, [0], {0: 0}])
+ def test_factory(self, mock_nowstr, object):
+ item = PickleItem.factory(object)
+
+ assert item.pickle_bytes == pickle.dumps(object)
+ assert item.created_at == mock_nowstr
+ assert item.updated_at == mock_nowstr
+
+ def test_object(self, mock_nowstr):
+ item1 = PickleItem.factory(int)
+ item2 = PickleItem(
+ pickle_bytes=pickle.dumps(int),
+ created_at=mock_nowstr,
+ updated_at=mock_nowstr,
+ )
+
+ assert item1.object is int
+ assert item2.object is int
+
+ def test_get_serializable_dict(self, mock_nowstr):
+ item = PickleItem.factory(int)
+ serializable = item.as_serializable_dict()
+
+ assert serializable == {
+ "updated_at": mock_nowstr,
+ "created_at": mock_nowstr,
+ "media_type": "text/markdown",
+ "value": repr(int),
+ }
diff --git a/skore/tests/unit/project/test_project.py b/skore/tests/unit/project/test_project.py
index a1da72d71..13414b7e5 100644
--- a/skore/tests/unit/project/test_project.py
+++ b/skore/tests/unit/project/test_project.py
@@ -235,13 +235,8 @@ def test_put_several_nested(in_memory_project):
assert in_memory_project.get("a") == {"b": "baz"}
-def test_put_several_lambda(in_memory_project):
- in_memory_project.put(
- {
- "a": "foo",
- "b": (lambda: "unsupported object"),
- }
- )
+def test_put_several_complex(in_memory_project):
+ in_memory_project.put({"a": int, "b": float})
assert in_memory_project.list_item_keys() == ["a", "b"]
From 80df8dbdf67fc488d7c90ca5ea2c0702ca535cda Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Fri, 10 Jan 2025 15:05:48 +0100
Subject: [PATCH 09/22] Add docstring
---
skore/src/skore/persistence/item/__init__.py | 1 +
.../src/skore/persistence/item/pickle_item.py | 41 +++++++++++++++++++
.../skore/persistence/repository/__init__.py | 2 +
.../src/skore/persistence/storage/__init__.py | 2 +
4 files changed, 46 insertions(+)
diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py
index 79b0eb138..fca1e5c18 100644
--- a/skore/src/skore/persistence/item/__init__.py
+++ b/skore/src/skore/persistence/item/__init__.py
@@ -47,6 +47,7 @@ def object_to_item(object: Any) -> Item:
def item_to_object(item: Item) -> Any:
+ """Transform an Item into its original object."""
if isinstance(item, PrimitiveItem):
return item.primitive
elif isinstance(item, NumpyArrayItem):
diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py
index 2f6d8ff55..30c627cde 100644
--- a/skore/src/skore/persistence/item/pickle_item.py
+++ b/skore/src/skore/persistence/item/pickle_item.py
@@ -1,3 +1,9 @@
+"""PickleItem.
+
+This module defines the PickleItem class, which is used to persist objects that cannot
+be persisted otherwise.
+"""
+
from __future__ import annotations
from functools import cached_property
@@ -8,25 +14,60 @@
class PickleItem(Item):
+ """
+ A class to represent any object item.
+
+ This class is generally used to persist objects that cannot be persisted otherwise.
+ It encapsulates the object with its pickle representation, its creation and update
+ timestamps.
+ """
+
def __init__(
self,
pickle_bytes: bytes,
created_at: str | None = None,
updated_at: str | None = None,
):
+ """
+ Initialize a PickleItem.
+
+ Parameters
+ ----------
+ pickle_bytes : bytes
+ The raw bytes of the object pickle representation.
+ created_at : str
+ The creation timestamp in ISO format.
+ updated_at : str
+ The last update timestamp in ISO format.
+ """
super().__init__(created_at, updated_at)
self.pickle_bytes = pickle_bytes
@cached_property
def object(self) -> Any:
+ """The object restored by loading ``pickle_bytes``."""
return loads(self.pickle_bytes)
@classmethod
def factory(cls, object: Any) -> PickleItem:
+ """
+ Create a new PickleItem with any object.
+
+ Parameters
+ ----------
+ object : Any
+ The object to store.
+
+ Returns
+ -------
+ PickleItem
+ A new PickleItem instance.
+ """
return cls(dumps(object))
def as_serializable_dict(self):
+ """Get a JSON serializable representation of the item."""
return super().as_serializable_dict() | {
"media_type": "text/markdown",
"value": repr(self.object),
diff --git a/skore/src/skore/persistence/repository/__init__.py b/skore/src/skore/persistence/repository/__init__.py
index c088fdacf..845bb362c 100644
--- a/skore/src/skore/persistence/repository/__init__.py
+++ b/skore/src/skore/persistence/repository/__init__.py
@@ -1,3 +1,5 @@
+"""Provide a set of classes responsible for manipulating items and views."""
+
from .item_repository import ItemRepository
from .view_repository import ViewRepository
diff --git a/skore/src/skore/persistence/storage/__init__.py b/skore/src/skore/persistence/storage/__init__.py
index 21b6a3b1a..75bb67428 100644
--- a/skore/src/skore/persistence/storage/__init__.py
+++ b/skore/src/skore/persistence/storage/__init__.py
@@ -1,3 +1,5 @@
+"""Provide a set of storage classes for storing and retrieving data."""
+
from .abstract_storage import AbstractStorage
from .disk_cache_storage import DiskCacheStorage
from .in_memory_storage import InMemoryStorage
From c8f3ecfa0b5dd7f26bf4e72954d61be319bf83e7 Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Fri, 10 Jan 2025 17:29:32 +0100
Subject: [PATCH 10/22] WIP - Change examples
---
.../getting_started/plot_tracking_items.py | 2 +-
.../plot_working_with_projects.py | 18 ++++++------------
.../model_evaluation/plot_cross_validate.py | 2 +-
3 files changed, 8 insertions(+), 14 deletions(-)
diff --git a/examples/getting_started/plot_tracking_items.py b/examples/getting_started/plot_tracking_items.py
index 11bfe30aa..ec8e1a5fb 100644
--- a/examples/getting_started/plot_tracking_items.py
+++ b/examples/getting_started/plot_tracking_items.py
@@ -73,7 +73,7 @@
# We retrieve the history of the ``my_int`` item:
# %%
-item_histories = my_project.get_item_versions("my_int")
+item_histories = my_project.get_item_versions("my_int") # TO CHANGE /!\
# %%
# We can print the first history (first iteration) of this item:
diff --git a/examples/getting_started/plot_working_with_projects.py b/examples/getting_started/plot_working_with_projects.py
index 0b16b2074..3d672df1d 100644
--- a/examples/getting_started/plot_working_with_projects.py
+++ b/examples/getting_started/plot_working_with_projects.py
@@ -119,25 +119,19 @@ def my_func(x):
)
# %%
-# Moreover, we can also explicitly tell skore the media type of an object, for example
-# in HTML:
+# Moreover, we can also explicitly tell skore the way we want to display an object, for
+# example in HTML:
# %%
-from skore.item import MediaItem
-my_project.put_item(
+my_project.put(
"my_string_3",
- MediaItem.factory(
- "Title
bold, italic, etc.
", media_type="text/html"
- ),
+ "Title
bold, italic, etc.",
+ display_as="html",
)
# %%
-# .. note::
-# We used :func:`~skore.Project.put_item` instead of :func:`~skore.Project.put`.
-
-# %%
-# Note that the media type is only used for the UI, and not in this notebook at hand:
+# Note that ``display_as`` is only used for the UI, and not in the notebook at hand:
# %%
my_project.get("my_string_3")
diff --git a/examples/model_evaluation/plot_cross_validate.py b/examples/model_evaluation/plot_cross_validate.py
index 9422e37b0..a47b6b74f 100644
--- a/examples/model_evaluation/plot_cross_validate.py
+++ b/examples/model_evaluation/plot_cross_validate.py
@@ -157,7 +157,7 @@
# %%
# We can also access the plot after we have stored the ``CrossValidationReporter``:
my_project.put("cross_validation_regression", reporter)
-cv_item = my_project.get_item("cross_validation_regression")
+cv_item = my_project.get_item("cross_validation_regression") # TO CHANGE /!\
cv_item.plots["Scores"]
# %%
From c4b2ee8cdac0bdf69ef814bd5fef4e3cf23b1c85 Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Mon, 13 Jan 2025 16:22:18 +0100
Subject: [PATCH 11/22] Rename CrossValidationItem to
CrossValidationReporterItem
---
skore/src/skore/persistence/item/__init__.py | 10 +++++-----
...ation_item.py => cross_validation_reporter_item.py} | 0
2 files changed, 5 insertions(+), 5 deletions(-)
rename skore/src/skore/persistence/item/{cross_validation_item.py => cross_validation_reporter_item.py} (100%)
diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py
index fca1e5c18..bc3890ce5 100644
--- a/skore/src/skore/persistence/item/__init__.py
+++ b/skore/src/skore/persistence/item/__init__.py
@@ -6,7 +6,7 @@
from typing import Any
from . import skrub_table_report_item as SkrubTableReportItem
-from .cross_validation_item import CrossValidationItem
+from .cross_validation_reporter_item import CrossValidationReporterItem
from .item import Item, ItemTypeError
from .media_item import MediaItem
from .numpy_array_item import NumpyArrayItem
@@ -31,7 +31,7 @@ def object_to_item(object: Any) -> Item:
SklearnBaseEstimatorItem,
MediaItem,
SkrubTableReportItem,
- CrossValidationItem,
+ CrossValidationReporterItem,
):
with suppress(ImportError, ItemTypeError):
# ImportError:
@@ -58,8 +58,8 @@ def item_to_object(item: Item) -> Any:
return item.series
elif isinstance(item, SklearnBaseEstimatorItem):
return item.estimator
- elif isinstance(item, CrossValidationItem):
- return item.cv_results_serialized
+ elif isinstance(item, CrossValidationReporterItem):
+ return item.reporter
elif isinstance(item, MediaItem):
return item.media_bytes
elif isinstance(item, PickleItem):
@@ -69,7 +69,7 @@ def item_to_object(item: Item) -> Any:
__all__ = [
- "CrossValidationItem",
+ "CrossValidationReporterItem",
"Item",
"MediaItem",
"NumpyArrayItem",
diff --git a/skore/src/skore/persistence/item/cross_validation_item.py b/skore/src/skore/persistence/item/cross_validation_reporter_item.py
similarity index 100%
rename from skore/src/skore/persistence/item/cross_validation_item.py
rename to skore/src/skore/persistence/item/cross_validation_reporter_item.py
From 4d08a193321bbaa1e9d7a31e7868a76cbd24df5d Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Mon, 13 Jan 2025 17:06:18 +0100
Subject: [PATCH 12/22] CrossValidationReporterItem's factory now pickles the
reporter
---
.../item/cross_validation_reporter_item.py | 369 +++++++-----------
.../sklearn/test_cross_validate.py | 4 +-
skore/tests/integration/ui/test_ui.py | 5 +-
...=> test_cross_validation_reporter_item.py} | 58 ++-
4 files changed, 172 insertions(+), 264 deletions(-)
rename skore/tests/unit/item/{test_cross_validation_item.py => test_cross_validation_reporter_item.py} (80%)
diff --git a/skore/src/skore/persistence/item/cross_validation_reporter_item.py b/skore/src/skore/persistence/item/cross_validation_reporter_item.py
index 23c81dab9..263e1528c 100644
--- a/skore/src/skore/persistence/item/cross_validation_reporter_item.py
+++ b/skore/src/skore/persistence/item/cross_validation_reporter_item.py
@@ -1,20 +1,21 @@
-"""CrossValidationItem class.
+"""CrossValidationReporterItem.
-This class represents the output of a cross-validation workflow.
+This module defines the CrossValidationReporterItem class, which is used to persist
+reporters of cross-validation.
"""
from __future__ import annotations
import contextlib
-import copy
import dataclasses
import hashlib
import importlib
import json
+import pickle
import re
import statistics
from functools import cached_property
-from typing import TYPE_CHECKING, Any, Literal, TypedDict, Union
+from typing import TYPE_CHECKING, Literal, Optional, TypedDict
import numpy
import plotly.graph_objects
@@ -27,8 +28,6 @@
if TYPE_CHECKING:
import sklearn.base
- CVSplitter = Any
-
class EstimatorParamInfo(TypedDict):
"""Information about an estimator parameter."""
@@ -43,6 +42,12 @@ class EstimatorInfo(TypedDict):
params: dict[str, EstimatorParamInfo]
+HUMANIZED_PLOT_NAMES = {
+ "scores": "Scores",
+ "timing": "Timings",
+}
+
+
def _hash_numpy(arr: numpy.ndarray) -> str:
"""Compute a hash string from a numpy array.
@@ -117,86 +122,122 @@ def _params_to_str(estimator_info) -> str:
return "\n".join(params_list)
-# Data used for training, passed as input to scikit-learn
-Data = Any
-# Target used for training, passed as input to scikit-learn
-Target = Any
+def _estimator_info(estimator: sklearn.base.BaseEstimator) -> EstimatorInfo:
+ estimator_params = (
+ estimator.get_params() if hasattr(estimator, "get_params") else {}
+ )
+
+ name = estimator.__class__.__name__
+ module = estimator.__module__
+
+ # Figure out the default parameters of the estimator,
+ # so that we can highlight the non-default ones in the UI
+
+ # This is done by instantiating the class with no arguments and
+ # computing the diff between the default and ours
+ try:
+ estimator_module = importlib.import_module(module)
+ EstimatorClass = getattr(estimator_module, name)
+ default_estimator_params = EstimatorClass().get_params()
+ except Exception:
+ default_estimator_params = {}
+
+ final_estimator_params: dict[str, EstimatorParamInfo] = {}
+ for k, v in estimator_params.items():
+ param_is_default: bool = (
+ k in default_estimator_params and default_estimator_params[k] == v
+ )
+ final_estimator_params[str(k)] = {
+ "value": repr(v),
+ "default": param_is_default,
+ }
+ return {
+ "name": name,
+ "module": module,
+ "params": final_estimator_params,
+ }
-class CrossValidationItem(Item):
- """
- A class to represent the output of a cross-validation workflow.
- This class encapsulates the output of the
- :func:`sklearn.model_selection.cross_validate` function along with its creation and
- update timestamps.
- """
+class CrossValidationReporterItem(Item):
+ """Class to persist the reporter of cross-validation."""
def __init__(
self,
- cv_results_serialized: dict,
- estimator_info: EstimatorInfo,
- X_info: dict,
- y_info: Union[dict, None],
- plots_bytes: dict[str, bytes],
- cv_info: dict,
- created_at: Union[str, None] = None,
- updated_at: Union[str, None] = None,
+ reporter_bytes: bytes,
+ created_at: Optional[str] = None,
+ updated_at: Optional[str] = None,
):
"""
- Initialize a CrossValidationItem.
+ Initialize a CrossValidationReporterItem.
Parameters
----------
- cv_results_serialized : dict
- The dict output of the :func:`sklearn.model_selection.cross_validate`
- function, in a form suitable for serialization.
- estimator_info : dict
- The estimator that was cross-validated.
- X_info : dict
- A summary of the data, input of the
- :func:`sklearn.model_selection.cross_validate` function.
- y_info : dict
- A summary of the target, input of the
- :func:`sklearn.model_selection.cross_validate` function.
- plots_bytes : dict[str, bytes]
- A collection of plots of the cross-validation results, in the form of bytes.
- cv_info: dict
- A dict containing cross validation splitting strategy params.
- created_at : str
+ reporter_bytes : bytes
+ The raw bytes of the reporter pickle representation.
+ created_at : str, optional
The creation timestamp in ISO format.
- updated_at : str
+ updated_at : str, optional
The last update timestamp in ISO format.
"""
super().__init__(created_at, updated_at)
- self.cv_results_serialized = cv_results_serialized
- self.estimator_info = estimator_info
- self.X_info = X_info
- self.y_info = y_info
- self.plots_bytes = plots_bytes
- self.cv_info = cv_info
+ self.reporter_bytes = reporter_bytes
- def as_serializable_dict(self):
- """Get a serializable dict from the item.
+ @classmethod
+ def factory(cls, reporter: CrossValidationReporter) -> CrossValidationReporterItem:
+ """
+ Create a CrossValidationReporterItem instance from a CrossValidationReporter.
- Derived class must call their super implementation
- and merge the result with their output.
+ Parameters
+ ----------
+ reporter : CrossValidationReporter
+
+ Returns
+ -------
+ CrossValidationReporterItem
+ A new CrossValidationReporterItem instance.
"""
+ if not isinstance(reporter, CrossValidationReporter):
+ raise ItemTypeError(f"Type '{reporter.__class__}' is not supported.")
+
+ instance = cls(pickle.dumps(reporter))
+
+ # add reporter as cached property
+ instance.reporter = reporter
+
+ return instance
+
+ @cached_property
+ def reporter(self) -> CrossValidationReporter:
+ """The CrossValidationReporter unpickled from ``reporter_bytes``."""
+ return pickle.loads(self.reporter_bytes)
+
+ def as_serializable_dict(self):
+ """Get a serializable dict from the item."""
# Get tabular results (the cv results in a dataframe-like structure)
- cv_results = copy.deepcopy(self.cv_results_serialized)
- cv_results.pop("indices", None)
-
- metrics_names = list(cv_results.keys())
- tabular_results = {
- "name": "Cross validation results",
- "columns": metrics_names,
- "data": list(zip(*cv_results.values())),
- "favorability": [_metric_favorability(m) for m in metrics_names],
+ cv_results = {
+ key: value.tolist()
+ for key, value in self.reporter.cv_results.items()
+ if (
+ key != "estimator"
+ and key != "indices"
+ and isinstance(value, numpy.ndarray)
+ )
}
+ metrics_names = list(cv_results)
+ tabular_results = [
+ {
+ "name": "Cross validation results",
+ "columns": metrics_names,
+ "data": list(zip(*cv_results.values())),
+ "favorability": [_metric_favorability(m) for m in metrics_names],
+ }
+ ]
+
# Get scalar results (summary statistics of the cv results)
- mean_cv_results = [
+ scalar_results = [
{
"name": _metric_title(k),
"value": statistics.mean(v),
@@ -206,163 +247,22 @@ def as_serializable_dict(self):
for k, v in cv_results.items()
]
- scalar_results = mean_cv_results
-
- params_as_str = _params_to_str(self.estimator_info)
-
# If the estimator is from sklearn, make the class name a hyperlink
# to the relevant docs
- name = self.estimator_info["name"]
- module = re.sub(r"\.\_.+", "", self.estimator_info["module"])
+ estimator_info = _estimator_info(self.reporter.estimator)
+ name = estimator_info["name"]
+ module = re.sub(r"\.\_.+", "", estimator_info["module"])
if module.startswith("sklearn"):
doc_url = f"https://scikit-learn.org/stable/modules/generated/{module}.{name}.html"
doc_link = f'{name}
'
else:
doc_link = f"`{name}`"
+ params_as_str = _params_to_str(estimator_info)
estimator_params_as_str = f"{doc_link}\n{params_as_str}"
- # Get cross-validation details
- cv_params_as_str = ", ".join(f"{k}: *{v}*" for k, v in self.cv_info.items())
-
- r = super().as_serializable_dict()
- sections = [
- {
- "title": "Model",
- "icon": "icon-square-cursor",
- "items": [
- {
- "name": "Estimator parameters",
- "description": "Core model configuration used for training",
- "value": estimator_params_as_str,
- },
- {
- "name": "Cross-validation parameters",
- "description": "Controls how data is split and validated",
- "value": cv_params_as_str,
- },
- ],
- }
- ]
- value = {
- "scalar_results": scalar_results,
- "tabular_results": [tabular_results],
- "plots": [
- {
- "name": plot_name,
- "value": json.loads(plot_bytes.decode("utf-8")),
- }
- for plot_name, plot_bytes in self.plots_bytes.items()
- ],
- "sections": sections,
- }
- r.update(
- {
- "media_type": "application/vnd.skore.cross_validation+json",
- "value": value,
- }
- )
- return r
-
- @staticmethod
- def _estimator_info(estimator: sklearn.base.BaseEstimator) -> EstimatorInfo:
- estimator_params = (
- estimator.get_params() if hasattr(estimator, "get_params") else {}
- )
-
- name = estimator.__class__.__name__
- module = estimator.__module__
-
- # Figure out the default parameters of the estimator,
- # so that we can highlight the non-default ones in the UI
-
- # This is done by instantiating the class with no arguments and
- # computing the diff between the default and ours
- try:
- estimator_module = importlib.import_module(module)
- EstimatorClass = getattr(estimator_module, name)
- default_estimator_params = EstimatorClass().get_params()
- except Exception:
- default_estimator_params = {}
-
- final_estimator_params: dict[str, EstimatorParamInfo] = {}
- for k, v in estimator_params.items():
- param_is_default: bool = (
- k in default_estimator_params and default_estimator_params[k] == v
- )
- final_estimator_params[str(k)] = {
- "value": repr(v),
- "default": param_is_default,
- }
-
- return {
- "name": name,
- "module": module,
- "params": final_estimator_params,
- }
-
- @classmethod
- def factory(cls, reporter: CrossValidationReporter) -> CrossValidationItem:
- """
- Create a new CrossValidationItem instance from a CrossValidationReporter.
-
- Parameters
- ----------
- reporter : CrossValidationReporter
-
- Returns
- -------
- CrossValidationItem
- A new CrossValidationItem instance.
- """
- if not isinstance(reporter, CrossValidationReporter):
- raise ItemTypeError(
- f"Type '{reporter.__class__}' is not supported, "
- f"only '{CrossValidationReporter.__name__}' is."
- )
-
- cv_results = reporter._cv_results
- estimator = reporter.estimator
- X = reporter.X
- y = reporter.y
- plots = reporter.plots
- cv = reporter.cv
-
- cv_results_serialized = {}
- for k, v in cv_results.items():
- if k == "estimator":
- continue
- if k == "indices":
- cv_results_serialized["indices"] = {
- "train": tuple(arr.tolist() for arr in v["train"]),
- "test": tuple(arr.tolist() for arr in v["test"]),
- }
- if isinstance(v, numpy.ndarray):
- cv_results_serialized[k] = v.tolist()
-
- estimator_info = CrossValidationItem._estimator_info(estimator)
-
- y_array = y if isinstance(y, numpy.ndarray) else numpy.array(y)
- y_info = None if y is None else {"hash": _hash_numpy(y_array)}
-
- X_array = X if isinstance(X, numpy.ndarray) else numpy.array(X)
- X_info = {
- "nb_rows": X_array.shape[0],
- "nb_cols": X_array.shape[1],
- "hash": _hash_numpy(X_array),
- }
-
- humanized_plot_names = {
- "scores": "Scores",
- "timing": "Timings",
- }
- plots_bytes = {
- humanized_plot_names[plot_name]: (
- plotly.io.to_json(plot, engine="json").encode("utf-8")
- )
- for plot_name, plot in dataclasses.asdict(plots).items()
- }
-
+ # Get cross-validation details
+ cv = self.reporter.cv
cv_info: dict[str, str] = {}
if isinstance(cv, int):
cv_info["n_splits"] = repr(cv)
@@ -376,19 +276,40 @@ def factory(cls, reporter: CrossValidationReporter) -> CrossValidationItem:
attr = getattr(cv, attr_name)
cv_info[attr_name] = repr(attr)
- return cls(
- cv_results_serialized=cv_results_serialized,
- estimator_info=estimator_info,
- X_info=X_info,
- y_info=y_info,
- plots_bytes=plots_bytes,
- cv_info=cv_info,
- )
+ cv_params_as_str = ", ".join(f"{k}: *{v}*" for k, v in cv_info.items())
- @cached_property
- def plots(self) -> dict:
- """Various plots of the cross-validation results."""
- return {
- name: plotly.io.from_json(plot_bytes.decode("utf-8"))
- for name, plot_bytes in self.plots_bytes.items()
+ # Assemble the serializable value: results, plots and sections
+ value = {
+ "scalar_results": scalar_results,
+ "tabular_results": tabular_results,
+ "plots": [
+ {
+ "name": HUMANIZED_PLOT_NAMES[plot_name],
+ "value": json.loads(plotly.io.to_json(plot, engine="json")),
+ }
+ for plot_name, plot in dataclasses.asdict(self.reporter.plots).items()
+ ],
+ "sections": [
+ {
+ "title": "Model",
+ "icon": "icon-square-cursor",
+ "items": [
+ {
+ "name": "Estimator parameters",
+ "description": "Core model configuration used for training",
+ "value": estimator_params_as_str,
+ },
+ {
+ "name": "Cross-validation parameters",
+ "description": "Controls how data is split and validated",
+ "value": cv_params_as_str,
+ },
+ ],
+ }
+ ],
+ }
+
+ return super().as_serializable_dict() | {
+ "media_type": "application/vnd.skore.cross_validation+json",
+ "value": value,
}
diff --git a/skore/tests/integration/sklearn/test_cross_validate.py b/skore/tests/integration/sklearn/test_cross_validate.py
index e39f6ba58..00758f751 100644
--- a/skore/tests/integration/sklearn/test_cross_validate.py
+++ b/skore/tests/integration/sklearn/test_cross_validate.py
@@ -10,7 +10,7 @@
from sklearn.multiclass import OneVsOneClassifier
from sklearn.svm import SVC
from skore import CrossValidationReporter
-from skore.persistence.item.cross_validation_item import CrossValidationItem
+from skore.persistence.item import CrossValidationReporterItem
from skore.sklearn.cross_validation.cross_validation_helpers import _get_scorers_to_add
@@ -201,7 +201,7 @@ def test_cross_validation_reporter(in_memory_project, fixture_name, request):
in_memory_project.put("cross-validation", reporter)
retrieved_item = in_memory_project.item_repository.get_item("cross-validation")
- assert isinstance(retrieved_item, CrossValidationItem)
+ assert isinstance(retrieved_item, CrossValidationReporterItem)
@pytest.mark.parametrize(
diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py
index 360761a6f..44b15d7a3 100644
--- a/skore/tests/integration/ui/test_ui.py
+++ b/skore/tests/integration/ui/test_ui.py
@@ -167,7 +167,7 @@ def _fake_cross_validate(*args, **kwargs):
monkeypatch.setattr("sklearn.model_selection.cross_validate", _fake_cross_validate)
-def test_serialize_cross_validation_item(
+def test_serialize_cross_validation_reporter_item(
client,
in_memory_project,
monkeypatch,
@@ -176,9 +176,6 @@ def test_serialize_cross_validation_item(
fake_cross_validate,
):
monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
- monkeypatch.setattr(
- "skore.persistence.item.cross_validation_item.CrossValidationItem.plots", {}
- )
monkeypatch.setattr(
"skore.sklearn.cross_validation.cross_validation_reporter.plot_cross_validation_compare_scores",
lambda _: {},
diff --git a/skore/tests/unit/item/test_cross_validation_item.py b/skore/tests/unit/item/test_cross_validation_reporter_item.py
similarity index 80%
rename from skore/tests/unit/item/test_cross_validation_item.py
rename to skore/tests/unit/item/test_cross_validation_reporter_item.py
index 909eff872..973fc933d 100644
--- a/skore/tests/unit/item/test_cross_validation_item.py
+++ b/skore/tests/unit/item/test_cross_validation_reporter_item.py
@@ -1,13 +1,13 @@
from dataclasses import dataclass
+from pickle import dumps
import numpy
import plotly.graph_objects
import pytest
from sklearn.model_selection import StratifiedKFold
from skore.persistence.item import ItemTypeError
-from skore.persistence.item.cross_validation_item import (
- CrossValidationItem,
- _hash_numpy,
+from skore.persistence.item.cross_validation_reporter_item import (
+ CrossValidationReporterItem,
_metric_favorability,
)
from skore.sklearn.cross_validation import CrossValidationReporter
@@ -27,7 +27,7 @@ class FakeEstimatorNoGetParams:
@dataclass
class FakeCrossValidationReporter(CrossValidationReporter):
- _cv_results = {
+ cv_results = {
"test_score": numpy.array([1, 2, 3]),
"estimator": [FakeEstimator(), FakeEstimator(), FakeEstimator()],
"fit_time": [1, 2, 3],
@@ -44,7 +44,7 @@ class FakeCrossValidationReporter(CrossValidationReporter):
@dataclass
class FakeCrossValidationReporterNoGetParams(CrossValidationReporter):
- _cv_results = {
+ cv_results = {
"test_score": numpy.array([1, 2, 3]),
"estimator": [
FakeEstimatorNoGetParams(),
@@ -63,14 +63,14 @@ class FakeCrossValidationReporterNoGetParams(CrossValidationReporter):
cv = StratifiedKFold(n_splits=5)
-class TestCrossValidationItem:
+class TestCrossValidationReporterItem:
@pytest.fixture(autouse=True)
def monkeypatch_datetime(self, monkeypatch, MockDatetime):
monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
def test_factory_exception(self):
with pytest.raises(ItemTypeError):
- CrossValidationItem.factory(None)
+ CrossValidationReporterItem.factory(None)
@pytest.mark.parametrize(
"reporter",
@@ -82,42 +82,32 @@ def test_factory_exception(self):
],
)
def test_factory(self, mock_nowstr, reporter):
- item = CrossValidationItem.factory(reporter)
-
- assert item.cv_results_serialized == {"test_score": [1, 2, 3]}
- assert item.estimator_info == {
- "name": reporter.estimator.__class__.__name__,
- "params": (
- {}
- if isinstance(reporter.estimator, FakeEstimatorNoGetParams)
- else {"alpha": {"value": "3", "default": True}}
- ),
- "module": "tests.unit.item.test_cross_validation_item",
- }
- assert item.X_info == {
- "nb_cols": 1,
- "nb_rows": 1,
- "hash": _hash_numpy(FakeCrossValidationReporter.X),
- }
- assert item.y_info == {"hash": _hash_numpy(FakeCrossValidationReporter.y)}
- assert item.cv_info == {
- "n_splits": "5",
- "random_state": "None",
- "shuffle": "False",
- }
- assert isinstance(item.plots_bytes, dict)
- assert isinstance(item.plots, dict)
+ item = CrossValidationReporterItem.factory(reporter)
+
+ assert item.reporter_bytes == dumps(reporter)
assert item.created_at == mock_nowstr
assert item.updated_at == mock_nowstr
+ def test_reporter(self, mock_nowstr):
+ reporter = FakeCrossValidationReporter()
+ item1 = CrossValidationReporterItem.factory(reporter)
+ item2 = CrossValidationReporterItem(
+ reporter_bytes=dumps(reporter),
+ created_at=mock_nowstr,
+ updated_at=mock_nowstr,
+ )
+
+ assert item1.reporter == reporter
+ assert item2.reporter == reporter
+
def test_get_serializable_dict(self, monkeypatch, mock_nowstr):
monkeypatch.setattr(
- "skore.persistence.item.cross_validation_item.CrossValidationReporter",
+ "skore.persistence.item.cross_validation_reporter_item.CrossValidationReporter",
FakeCrossValidationReporter,
)
reporter = FakeCrossValidationReporter()
- item = CrossValidationItem.factory(reporter)
+ item = CrossValidationReporterItem.factory(reporter)
serializable = item.as_serializable_dict()
assert serializable["updated_at"] == mock_nowstr
From 43a296bd956d5841695bb6293f39f21f45a45e8a Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Tue, 14 Jan 2025 15:34:55 +0100
Subject: [PATCH 13/22] [skip ci] Use `reporter._cv_results` instead of
`reporter.cv_results`
---
.../persistence/item/cross_validation_reporter_item.py | 8 ++------
.../unit/item/test_cross_validation_reporter_item.py | 4 ++--
2 files changed, 4 insertions(+), 8 deletions(-)
diff --git a/skore/src/skore/persistence/item/cross_validation_reporter_item.py b/skore/src/skore/persistence/item/cross_validation_reporter_item.py
index 263e1528c..258a30be2 100644
--- a/skore/src/skore/persistence/item/cross_validation_reporter_item.py
+++ b/skore/src/skore/persistence/item/cross_validation_reporter_item.py
@@ -218,12 +218,8 @@ def as_serializable_dict(self):
# Get tabular results (the cv results in a dataframe-like structure)
cv_results = {
key: value.tolist()
- for key, value in self.reporter.cv_results.items()
- if (
- key != "estimator"
- and key != "indices"
- and isinstance(value, numpy.ndarray)
- )
+ for key, value in self.reporter._cv_results.items()
+ if key not in ("estimator", "indices") and isinstance(value, numpy.ndarray)
}
metrics_names = list(cv_results)
diff --git a/skore/tests/unit/item/test_cross_validation_reporter_item.py b/skore/tests/unit/item/test_cross_validation_reporter_item.py
index 973fc933d..0979d5fc3 100644
--- a/skore/tests/unit/item/test_cross_validation_reporter_item.py
+++ b/skore/tests/unit/item/test_cross_validation_reporter_item.py
@@ -27,7 +27,7 @@ class FakeEstimatorNoGetParams:
@dataclass
class FakeCrossValidationReporter(CrossValidationReporter):
- cv_results = {
+ _cv_results = {
"test_score": numpy.array([1, 2, 3]),
"estimator": [FakeEstimator(), FakeEstimator(), FakeEstimator()],
"fit_time": [1, 2, 3],
@@ -44,7 +44,7 @@ class FakeCrossValidationReporter(CrossValidationReporter):
@dataclass
class FakeCrossValidationReporterNoGetParams(CrossValidationReporter):
- cv_results = {
+ _cv_results = {
"test_score": numpy.array([1, 2, 3]),
"estimator": [
FakeEstimatorNoGetParams(),
From 459e5488917bd833bce32c76bf78abbb3432d4f7 Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Tue, 14 Jan 2025 17:45:21 +0100
Subject: [PATCH 14/22] Add `display_as` parameter
---
skore/src/skore/persistence/item/__init__.py | 79 +++++++++++++------
.../src/skore/persistence/item/media_item.py | 11 ++-
skore/src/skore/project/project.py | 32 +++++---
skore/tests/unit/project/test_display_as.py | 35 ++++++++
4 files changed, 120 insertions(+), 37 deletions(-)
create mode 100644 skore/tests/unit/project/test_display_as.py
diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py
index bc3890ce5..aa9ea9319 100644
--- a/skore/src/skore/persistence/item/__init__.py
+++ b/skore/src/skore/persistence/item/__init__.py
@@ -3,12 +3,12 @@
from __future__ import annotations
from contextlib import suppress
-from typing import Any
+from typing import Any, Literal, Optional
from . import skrub_table_report_item as SkrubTableReportItem
from .cross_validation_reporter_item import CrossValidationReporterItem
from .item import Item, ItemTypeError
-from .media_item import MediaItem
+from .media_item import MediaItem, MediaType
from .numpy_array_item import NumpyArrayItem
from .pandas_dataframe_item import PandasDataFrameItem
from .pandas_series_item import PandasSeriesItem
@@ -19,31 +19,60 @@
from .sklearn_base_estimator_item import SklearnBaseEstimatorItem
-def object_to_item(object: Any) -> Item:
+def object_to_item(
+ object: Any,
+ /,
+ *,
+ note: Optional[str] = None,
+ display_as: Optional[Literal["HTML", "MARKDOWN", "SVG"]] = None,
+) -> Item:
"""Transform an object into an Item."""
- for cls in (
- PrimitiveItem,
- PandasDataFrameItem,
- PandasSeriesItem,
- PolarsDataFrameItem,
- PolarsSeriesItem,
- NumpyArrayItem,
- SklearnBaseEstimatorItem,
- MediaItem,
- SkrubTableReportItem,
- CrossValidationReporterItem,
- ):
- with suppress(ImportError, ItemTypeError):
- # ImportError:
- # The factories are responsible to import third-party libraries in a
- # lazy way. If library is missing, an ImportError exception will
- # automatically be thrown.
- # ItemTypeError:
- # The factories are responsible for checking that parameters are of the
- # correct type. If not, they throw a ItemTypeError exception.
- return cls.factory(object)
+ if display_as is not None:
+ if not isinstance(object, str):
+ raise TypeError("`object` must be a str if `display_as` is specified")
- return PickleItem.factory(object)
+ if display_as not in MediaType.__members__:
+ raise ValueError(f"`display_as` must be in {list(MediaType.__members__)}")
+
+ item = MediaItem.factory_str(
+ media=object,
+ media_type=MediaType[display_as].value,
+ )
+ else:
+ for cls in (
+ PrimitiveItem,
+ PandasDataFrameItem,
+ PandasSeriesItem,
+ PolarsDataFrameItem,
+ PolarsSeriesItem,
+ NumpyArrayItem,
+ SklearnBaseEstimatorItem,
+ MediaItem,
+ SkrubTableReportItem,
+ CrossValidationReporterItem,
+ ):
+ with suppress(ImportError, ItemTypeError):
+ # ImportError:
+ # The factories are responsible to import third-party libraries in a
+ # lazy way. If library is missing, an ImportError exception will
+ # automatically be thrown.
+ # ItemTypeError:
+ # The factories are responsible for checking that parameters are of
+ # the correct type. If not, they throw a ItemTypeError exception.
+ item = cls.factory(object)
+ break
+ else:
+ item = PickleItem.factory(object)
+
+ if not isinstance(note, (type(None), str)):
+ raise TypeError(f"`note` must be a string (found '{type(note)}')")
+
+ # Since the item classes are now private, and to avoid having to pass the `note`
+ # parameter in the factories of each item class, we define the content of the
+ # `note` attribute dynamically.
+ item.note = note
+
+ return item
def item_to_object(item: Item) -> Any:
diff --git a/skore/src/skore/persistence/item/media_item.py b/skore/src/skore/persistence/item/media_item.py
index 3736a2345..e57434743 100644
--- a/skore/src/skore/persistence/item/media_item.py
+++ b/skore/src/skore/persistence/item/media_item.py
@@ -6,6 +6,7 @@
from __future__ import annotations
import base64
+from enum import Enum
from io import BytesIO
from typing import TYPE_CHECKING, Any, Union
@@ -25,6 +26,14 @@ def lazy_is_instance(object: Any, cls_fullname: str) -> bool:
}
+class MediaType(Enum):
+ """Enum mapping ``display_as`` aliases to their MIME media types."""
+
+ HTML = "text/html"
+ MARKDOWN = "text/markdown"
+ SVG = "image/svg+xml"
+
+
class MediaItem(Item):
"""
A class to represent a media item.
@@ -165,7 +174,7 @@ def factory_str(cls, media: str, media_type: str = "text/markdown") -> MediaItem
media : str
The string content to store.
media_type : str, optional
- The MIME type of the media content, by default "text/html".
+ The MIME type of the media content, by default "text/markdown".
Returns
-------
diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py
index cb335feb5..365a7c97e 100644
--- a/skore/src/skore/project/project.py
+++ b/skore/src/skore/project/project.py
@@ -3,7 +3,7 @@
from __future__ import annotations
import logging
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from skore.persistence.item import item_to_object, object_to_item
@@ -73,7 +73,14 @@ def __init__(
self.item_repository = item_repository
self.view_repository = view_repository
- def put(self, key: str, value: Any, *, note: Optional[str] = None):
+ def put(
+ self,
+ key: str,
+ value: Any,
+ *,
+ note: Optional[str] = None,
+ display_as: Optional[Literal["HTML", "MARKDOWN", "SVG"]] = None,
+ ):
"""Add a key-value pair to the Project.
If an item with the same key already exists, its value is replaced by the new
@@ -85,8 +92,11 @@ def put(self, key: str, value: Any, *, note: Optional[str] = None):
The key to associate with ``value`` in the Project.
value : Any
The value to associate with ``key`` in the Project.
- note : str or None, optional
+ note : str, optional
A note to attach with the item.
+ display_as : {"HTML", "MARKDOWN", "SVG"}, optional
+ Used in combination with a string value, it customizes the way the value is
+ displayed in the interface.
Raises
------
@@ -99,14 +109,14 @@ def put(self, key: str, value: Any, *, note: Optional[str] = None):
if not isinstance(key, str):
raise TypeError(f"Key must be a string (found '{type(key)}')")
- item = object_to_item(value)
-
- if note is not None:
- if not isinstance(note, str):
- raise TypeError(f"Note must be a string (found '{type(note)}')")
- item.note = note
-
- self.item_repository.put_item(key, item)
+ self.item_repository.put_item(
+ key,
+ object_to_item(
+ value,
+ note=note,
+ display_as=display_as,
+ ),
+ )
def get(self, key: str) -> Any:
"""Get the value corresponding to ``key`` from the Project.
diff --git a/skore/tests/unit/project/test_display_as.py b/skore/tests/unit/project/test_display_as.py
new file mode 100644
index 000000000..50833bc5a
--- /dev/null
+++ b/skore/tests/unit/project/test_display_as.py
@@ -0,0 +1,35 @@
+import pytest
+from skore.persistence.item import MediaItem, PrimitiveItem
+
+
+@pytest.fixture(autouse=True)
+def monkeypatch_datetime(monkeypatch, MockDatetime):
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
+
+
+def test_str_without_display_as(in_memory_project, mock_nowstr):
+ in_memory_project.put("key", "")
+
+ item = in_memory_project.item_repository.get_item("key")
+
+ assert isinstance(item, PrimitiveItem)
+ assert item.primitive == ""
+
+
+def test_str_with_display_as(in_memory_project, mock_nowstr):
+ in_memory_project.put("key", "", display_as="MARKDOWN")
+
+ item = in_memory_project.item_repository.get_item("key")
+
+ assert isinstance(item, MediaItem)
+ assert item.media_bytes == b""
+ assert item.media_encoding == "utf-8"
+ assert item.media_type == "text/markdown"
+
+
+def test_exception(in_memory_project, mock_nowstr):
+ with pytest.raises(TypeError):
+ in_memory_project.put("key", 1, display_as="MARKDOWN")
+
+ with pytest.raises(ValueError):
+ in_memory_project.put("key", "", display_as="")
From daebfe6b1f590d5f580f2e2fcb7ab678a6bedc30 Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Wed, 15 Jan 2025 10:22:25 +0100
Subject: [PATCH 15/22] Update examples
---
examples/model_evaluation/plot_cross_validate.py | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/examples/model_evaluation/plot_cross_validate.py b/examples/model_evaluation/plot_cross_validate.py
index a47b6b74f..cb0ab5e3c 100644
--- a/examples/model_evaluation/plot_cross_validate.py
+++ b/examples/model_evaluation/plot_cross_validate.py
@@ -155,15 +155,16 @@
reporter.plots.scores
# %%
-# We can also access the plot after we have stored the ``CrossValidationReporter``:
-my_project.put("cross_validation_regression", reporter)
-cv_item = my_project.get_item("cross_validation_regression") # TO CHANGE /!\
-cv_item.plots["Scores"]
+# We can put the reporter in the project, and retrieve it as is:
+my_project.put("cross_validation_reporter", reporter)
+
+reporter = my_project.get("cross_validation_reporter")
+reporter.plots.scores
# %%
# .. note::
#
-# If we put a cross-validation item in a skore project, we get some nice
+# If we put a cross-validation reporter in a skore project, we get some nice
# information in the UI:
#
# .. image:: https://media.githubusercontent.com/media/probabl-ai/skore/main/sphinx/_static/images/2024_12_12_skore_demo_comp.gif
From 3084bb039ca17d92f24b7ad32c23543181b1ff60 Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Wed, 15 Jan 2025 12:48:59 +0100
Subject: [PATCH 16/22] [skip ci] refactorize `get_item_versions` to be item
agnostic
---
.../getting_started/plot_tracking_items.py | 23 +++----
skore/src/skore/project/project.py | 67 ++++++++++---------
skore/src/skore/ui/project_routes.py | 13 ++--
skore/tests/integration/ui/test_ui.py | 2 +-
skore/tests/unit/project/test_project.py | 65 +++++++++++-------
5 files changed, 93 insertions(+), 77 deletions(-)
diff --git a/examples/getting_started/plot_tracking_items.py b/examples/getting_started/plot_tracking_items.py
index ec8e1a5fb..9eefc83e2 100644
--- a/examples/getting_started/plot_tracking_items.py
+++ b/examples/getting_started/plot_tracking_items.py
@@ -73,17 +73,14 @@
# We retrieve the history of the ``my_int`` item:
# %%
-item_histories = my_project.get_item_versions("my_int") # TO CHANGE /!\
+history = my_project.get("my_int", latest=False, metadata=True)
# %%
# We can print the first history (first iteration) of this item:
# %%
-passed_item = item_histories[0]
-print(passed_item)
-print(passed_item.primitive)
-print(passed_item.created_at)
-print(passed_item.updated_at)
+
+print(history[0])
# %%
# Let us construct a dataframe with the values and last updated times:
@@ -92,13 +89,13 @@
import numpy as np
import pandas as pd
-list_primitive, list_created_at, list_updated_at = zip(
- *[(elem.primitive, elem.created_at, elem.updated_at) for elem in item_histories]
+list_value, list_created_at, list_updated_at = zip(
+ *[(version["value"], history[0]["date"], version["date"]) for version in history]
)
df_track = pd.DataFrame(
{
- "primitive": list_primitive,
+ "value": list_value,
"created_at": list_created_at,
"updated_at": list_updated_at,
}
@@ -111,9 +108,9 @@
# :language: python
#
# Notice that the ``created_at`` dates are the same for all iterations because they
-# correspond to the same item, but the ``updated_at`` dates are spaced by 0.1 second
-# (approximately) as we used :python:`time.sleep(0.1)` between each
-# :func:`~skore.Project.put`.
+# correspond to the date of the first version of the item, but the ``updated_at`` dates
+# are spaced by 0.1 second (approximately) as we used :python:`time.sleep(0.1)` between
+# each :func:`~skore.Project.put`.
# %%
# We can now track the value of the item over time:
@@ -124,7 +121,7 @@
fig = px.line(
df_track,
x="iteration_number",
- y="primitive",
+ y="value",
hover_data=df_track.columns,
markers=True,
)
diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py
index 365a7c97e..6271d8785 100644
--- a/skore/src/skore/project/project.py
+++ b/skore/src/skore/project/project.py
@@ -9,7 +9,6 @@
if TYPE_CHECKING:
from skore.persistence import (
- Item,
ItemRepository,
View,
ViewRepository,
@@ -118,55 +117,57 @@ def put(
),
)
- def get(self, key: str) -> Any:
- """Get the value corresponding to ``key`` from the Project.
-
- Parameters
- ----------
- key : str
- The key corresponding to the item to get.
-
- Raises
- ------
- KeyError
- If the key does not correspond to any item.
- """
- return item_to_object(self.item_repository.get_item(key))
-
- def get_item_versions(self, key: str) -> list[Item]:
- """
- Get all the versions of an item associated with ``key`` from the Project.
-
- The list is ordered from oldest to newest :func:`~skore.Project.put` date.
+ def get(self, key, *, latest=True, metadata=False):
+ """Get the value associated to ``key`` from the Project.
Parameters
----------
key : str
The key corresponding to the item to get.
+ latest : boolean, optional
+ Get the latest value or all the values associated to ``key``, default True.
+ metadata : boolean, optional
+ Get the metadata in addition to the value, default False.
Returns
-------
- list[Item]
- The list of items corresponding to ``key``.
+ value : any
+ Value associated to ``key``, when latest=True and metadata=False.
+ value_and_metadata : dict
+ Value associated to ``key`` with its metadata, when latest=True and metadata=True.
+ list_of_values : list[any]
+ Values associated to ``key``, ordered by date, when latest=False.
+ list_of_values_and_metadata : list[dict]
+ Values associated to ``key`` with their metadata, ordered by date, when
+ latest=False and metadata=True.
Raises
------
KeyError
- If the key does not correspond to any item.
+ If the key is not in the project.
"""
- return self.item_repository.get_item_versions(key)
+ if not metadata:
- def list_item_keys(self) -> list[str]:
- """List all item keys in the Project.
+ def dto(item):
+ return item_to_object(item)
- Returns
- -------
- list[str]
- The list of item keys. The list is empty if there is no item.
- """
+ else:
+
+ def dto(item):
+ return {
+ "value": item_to_object(item),
+ "date": item.updated_at,
+ "note": item.note,
+ }
+
+ if latest:
+ return dto(self.item_repository.get_item(key))
+ return list(map(dto, self.item_repository.get_item_versions(key)))
+
+ def keys(self) -> list[str]:
return self.item_repository.keys()
- def delete_item(self, key: str):
+ def delete(self, key: str):
"""Delete the item corresponding to ``key`` from the Project.
Parameters
diff --git a/skore/src/skore/ui/project_routes.py b/skore/src/skore/ui/project_routes.py
index 79eff5ce6..e3ac728cb 100644
--- a/skore/src/skore/ui/project_routes.py
+++ b/skore/src/skore/ui/project_routes.py
@@ -49,12 +49,15 @@ def __item_as_serializable(name: str, item: Item) -> SerializableItem:
def __project_as_serializable(project: Project) -> SerializableProject:
items = {
key: [
- __item_as_serializable(key, item) for item in project.get_item_versions(key)
+ __item_as_serializable(key, item)
+ for item in project.item_repository.get_item_versions(key)
]
- for key in project.list_item_keys()
+ for key in project.item_repository.keys()
}
- views = {key: project.get_view(key).layout for key in project.list_view_keys()}
+ views = {
+ key: project.get_view(key).layout for key in project.view_repository.keys()
+ }
return SerializableProject(
items=items,
@@ -112,8 +115,8 @@ async def get_activity(
return sorted(
(
__item_as_serializable(key, version)
- for key in project.list_item_keys()
- for version in project.get_item_versions(key)
+ for key in project.item_repository.keys()
+ for version in project.item_repository.get_item_versions(key)
if datetime.fromisoformat(version.updated_at) > after
),
key=operator.attrgetter("updated_at"),
diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py
index 44b15d7a3..9d9a481bd 100644
--- a/skore/tests/integration/ui/test_ui.py
+++ b/skore/tests/integration/ui/test_ui.py
@@ -33,7 +33,7 @@ def test_get_items(client, in_memory_project):
in_memory_project.put("test", "version_1")
in_memory_project.put("test", "version_2")
- items = in_memory_project.get_item_versions("test")
+ items = in_memory_project.item_repository.get_item_versions("test")
response = client.get("/api/project/items")
assert response.status_code == 200
diff --git a/skore/tests/unit/project/test_project.py b/skore/tests/unit/project/test_project.py
index f5a842011..48b657534 100644
--- a/skore/tests/unit/project/test_project.py
+++ b/skore/tests/unit/project/test_project.py
@@ -19,6 +19,11 @@
from skore.project.create import _create, _validate_project_name
+@pytest.fixture(autouse=True)
+def monkeypatch_datetime(monkeypatch, MockDatetime):
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
+
+
def test_put_string_item(in_memory_project):
in_memory_project.put("string_item", "Hello, World!")
assert in_memory_project.get("string_item") == "Hello, World!"
@@ -145,19 +150,19 @@ def test_put(in_memory_project):
in_memory_project.put("key3", 3)
in_memory_project.put("key4", 4)
- assert in_memory_project.list_item_keys() == ["key1", "key2", "key3", "key4"]
+ assert in_memory_project.keys() == ["key1", "key2", "key3", "key4"]
def test_put_kwargs(in_memory_project):
in_memory_project.put(key="key1", value=1)
- assert in_memory_project.list_item_keys() == ["key1"]
+ assert in_memory_project.keys() == ["key1"]
def test_put_wrong_key_type(in_memory_project):
with pytest.raises(TypeError):
in_memory_project.put(key=2, value=1)
- assert in_memory_project.list_item_keys() == []
+ assert in_memory_project.keys() == []
def test_put_twice(in_memory_project):
@@ -167,39 +172,49 @@ def test_put_twice(in_memory_project):
assert in_memory_project.get("key2") == 5
-def test_get(in_memory_project):
- in_memory_project.put("key1", 1)
- assert in_memory_project.get("key1") == 1
-
- with pytest.raises(KeyError):
- in_memory_project.get("key2")
-
-
-def test_get_item_versions(in_memory_project):
- in_memory_project.put("key", 1)
- in_memory_project.put("key", 2)
+def test_get(in_memory_project, mock_nowstr):
+ in_memory_project.put("key", 1, note="1")
+ in_memory_project.put("key", 2, note="2")
- items = in_memory_project.get_item_versions("key")
+ assert in_memory_project.get("key") == 2
+ assert in_memory_project.get("key", latest=True, metadata=False) == 2
+ assert in_memory_project.get("key", latest=False, metadata=False) == [1, 2]
+ assert in_memory_project.get("key", latest=False, metadata=True) == [
+ {
+ "value": 1,
+ "date": mock_nowstr,
+ "note": "1",
+ },
+ {
+ "value": 2,
+ "date": mock_nowstr,
+ "note": "2",
+ },
+ ]
+ assert in_memory_project.get("key", latest=True, metadata=True) == {
+ "value": 2,
+ "date": mock_nowstr,
+ "note": "2",
+ }
- assert len(items) == 2
- assert items[0].primitive == 1
- assert items[1].primitive == 2
+ with pytest.raises(KeyError):
+ in_memory_project.get("")
def test_delete(in_memory_project):
in_memory_project.put("key1", 1)
- in_memory_project.delete_item("key1")
+ in_memory_project.delete("key1")
- assert in_memory_project.list_item_keys() == []
+ assert in_memory_project.keys() == []
with pytest.raises(KeyError):
- in_memory_project.delete_item("key2")
+ in_memory_project.delete("key2")
def test_keys(in_memory_project):
in_memory_project.put("key1", 1)
in_memory_project.put("key2", 2)
- assert in_memory_project.list_item_keys() == ["key1", "key2"]
+ assert in_memory_project.keys() == ["key1", "key2"]
def test_view(in_memory_project):
@@ -222,7 +237,7 @@ def test_put_several_complex(in_memory_project):
in_memory_project.put("a", int)
in_memory_project.put("b", float)
- assert in_memory_project.list_item_keys() == ["a", "b"]
+ assert in_memory_project.keys() == ["a", "b"]
def test_put_key_is_a_tuple(in_memory_project):
@@ -230,7 +245,7 @@ def test_put_key_is_a_tuple(in_memory_project):
with pytest.raises(TypeError):
in_memory_project.put(("a", "foo"), ("b", "bar"))
- assert in_memory_project.list_item_keys() == []
+ assert in_memory_project.keys() == []
def test_put_key_is_a_set(in_memory_project):
@@ -238,7 +253,7 @@ def test_put_key_is_a_set(in_memory_project):
with pytest.raises(TypeError):
in_memory_project.put(set(), "hello")
- assert in_memory_project.list_item_keys() == []
+ assert in_memory_project.keys() == []
def test_put_wrong_key_and_value_raise(in_memory_project):
From 602b8941b2d248d3c0f758f4ef0108d338f5b4fa Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Wed, 15 Jan 2025 13:03:17 +0100
Subject: [PATCH 17/22] [skip ci] Fix examples and sphinx
---
examples/getting_started/plot_working_with_projects.py | 8 ++++----
sphinx/api.rst | 2 --
2 files changed, 4 insertions(+), 6 deletions(-)
diff --git a/examples/getting_started/plot_working_with_projects.py b/examples/getting_started/plot_working_with_projects.py
index 4897b4738..ef65603de 100644
--- a/examples/getting_started/plot_working_with_projects.py
+++ b/examples/getting_started/plot_working_with_projects.py
@@ -71,20 +71,20 @@
# see :ref:`example_tracking_items`.
# %%
-# By using the :func:`~skore.Project.delete_item` method, we can also delete an object
+# By using the :func:`~skore.Project.delete` method, we can also delete an object
# so that our skore UI does not become cluttered:
# %%
my_project.put("my_int_2", 10)
# %%
-my_project.delete_item("my_int_2")
+my_project.delete("my_int_2")
# %%
# We can display all the keys in our project:
# %%
-my_project.list_item_keys()
+my_project.keys()
# %%
# Storing strings and texts
@@ -127,7 +127,7 @@ def my_func(x):
my_project.put(
"my_string_3",
"Title
bold, italic, etc.",
- display_as="html",
+ display_as="HTML",
)
# %%
diff --git a/sphinx/api.rst b/sphinx/api.rst
index b5d5527d4..09a3bfd6a 100644
--- a/sphinx/api.rst
+++ b/sphinx/api.rst
@@ -20,7 +20,6 @@ These functions and classes are meant for managing a Project.
:caption: Managing a project
Project
- item.primitive_item.PrimitiveItem
open
Get assistance when developing ML/DS projects
@@ -35,7 +34,6 @@ These functions and classes enhance scikit-learn's ones.
train_test_split
CrossValidationReporter
- item.cross_validation_item.CrossValidationItem
Report for a single estimator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
From 71aec33a9c85dc78df37f8d24c3e920cdc1e878d Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Wed, 15 Jan 2025 14:59:28 +0100
Subject: [PATCH 18/22] [skip ci] Hide view API
---
skore/src/skore/project/create.py | 2 +-
skore/src/skore/project/project.py | 50 ------------------------
skore/src/skore/ui/project_routes.py | 7 ++--
skore/tests/integration/ui/test_ui.py | 2 +-
skore/tests/unit/project/test_project.py | 17 --------
5 files changed, 6 insertions(+), 72 deletions(-)
diff --git a/skore/src/skore/project/create.py b/skore/src/skore/project/create.py
index e5a7e92de..fa6eac7f2 100644
--- a/skore/src/skore/project/create.py
+++ b/skore/src/skore/project/create.py
@@ -143,7 +143,7 @@ def _create(
) from e
p = _load(project_directory)
- p.put_view("default", View(layout=[]))
+ p.view_repository.put_view("default", View(layout=[]))
console.rule("[bold cyan]skore[/bold cyan]")
console.print(f"Project file '{project_directory}' was successfully created.")
diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py
index 6271d8785..acba73822 100644
--- a/skore/src/skore/project/project.py
+++ b/skore/src/skore/project/project.py
@@ -10,7 +10,6 @@
if TYPE_CHECKING:
from skore.persistence import (
ItemRepository,
- View,
ViewRepository,
)
@@ -182,55 +181,6 @@ def delete(self, key: str):
"""
self.item_repository.delete_item(key)
- def put_view(self, key: str, view: View):
- """Add a view to the Project."""
- self.view_repository.put_view(key, view)
-
- def get_view(self, key: str) -> View:
- """Get the view corresponding to ``key`` from the Project.
-
- Parameters
- ----------
- key : str
- The key of the item to get.
-
- Returns
- -------
- View
- The view corresponding to ``key``.
-
- Raises
- ------
- KeyError
- If the key does not correspond to any view.
- """
- return self.view_repository.get_view(key)
-
- def delete_view(self, key: str):
- """Delete the view corresponding to ``key`` from the Project.
-
- Parameters
- ----------
- key : str
- The key corresponding to the view to delete.
-
- Raises
- ------
- KeyError
- If the key does not correspond to any view.
- """
- return self.view_repository.delete_view(key)
-
- def list_view_keys(self) -> list[str]:
- """List all view keys in the Project.
-
- Returns
- -------
- list[str]
- The list of view keys. The list is empty if there is no view.
- """
- return self.view_repository.keys()
-
def set_note(self, key: str, message: str, *, version=-1):
"""Attach a note to key ``key``.
diff --git a/skore/src/skore/ui/project_routes.py b/skore/src/skore/ui/project_routes.py
index e3ac728cb..5606bbd39 100644
--- a/skore/src/skore/ui/project_routes.py
+++ b/skore/src/skore/ui/project_routes.py
@@ -56,7 +56,8 @@ def __project_as_serializable(project: Project) -> SerializableProject:
}
views = {
- key: project.get_view(key).layout for key in project.view_repository.keys()
+ key: project.view_repository.get_view(key).layout
+ for key in project.view_repository.keys()
}
return SerializableProject(
@@ -81,7 +82,7 @@ async def put_view(request: Request, key: str, layout: Layout):
project: Project = request.app.state.project
view = View(layout=layout)
- project.put_view(key, view)
+ project.view_repository.put_view(key, view)
return __project_as_serializable(project)
@@ -92,7 +93,7 @@ async def delete_view(request: Request, key: str):
project: Project = request.app.state.project
try:
- project.delete_view(key)
+ project.view_repository.delete_view(key)
except KeyError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="View not found"
diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py
index 9d9a481bd..c5170a958 100644
--- a/skore/tests/integration/ui/test_ui.py
+++ b/skore/tests/integration/ui/test_ui.py
@@ -60,7 +60,7 @@ def test_put_view_layout(client):
def test_delete_view(client, in_memory_project):
- in_memory_project.put_view("hello", View(layout=[]))
+ in_memory_project.view_repository.put_view("hello", View(layout=[]))
response = client.delete("/api/project/views?key=hello")
assert response.status_code == 202
diff --git a/skore/tests/unit/project/test_project.py b/skore/tests/unit/project/test_project.py
index 48b657534..cf2f7225e 100644
--- a/skore/tests/unit/project/test_project.py
+++ b/skore/tests/unit/project/test_project.py
@@ -15,7 +15,6 @@
InvalidProjectNameError,
ProjectCreationError,
)
-from skore.persistence.view.view import View
from skore.project.create import _create, _validate_project_name
@@ -217,22 +216,6 @@ def test_keys(in_memory_project):
assert in_memory_project.keys() == ["key1", "key2"]
-def test_view(in_memory_project):
- layout = ["key1", "key2"]
-
- view = View(layout=layout)
-
- in_memory_project.put_view("view", view)
- assert in_memory_project.get_view("view") == view
-
-
-def test_list_view_keys(in_memory_project):
- view = View(layout=[])
-
- in_memory_project.put_view("view", view)
- assert in_memory_project.list_view_keys() == ["view"]
-
-
def test_put_several_complex(in_memory_project):
in_memory_project.put("a", int)
in_memory_project.put("b", float)
From a85d2f18b06130c95e3245f2c6dc9f9e66ed206c Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Wed, 15 Jan 2025 15:10:04 +0100
Subject: [PATCH 19/22] [skip ci] Fix linter
---
.../persistence/repository/item_repository.py | 14 ++++++++++-
.../persistence/repository/view_repository.py | 16 +++++++++++--
skore/src/skore/project/project.py | 23 ++++++++++++++++++-
skore/src/skore/ui/project_routes.py | 6 ++---
4 files changed, 52 insertions(+), 7 deletions(-)
diff --git a/skore/src/skore/persistence/repository/item_repository.py b/skore/src/skore/persistence/repository/item_repository.py
index 601c7b20b..97e29e174 100644
--- a/skore/src/skore/persistence/repository/item_repository.py
+++ b/skore/src/skore/persistence/repository/item_repository.py
@@ -6,6 +6,7 @@
from __future__ import annotations
+from collections.abc import Iterator
from typing import TYPE_CHECKING, Union
import skore.persistence.item
@@ -129,10 +130,21 @@ def keys(self) -> list[str]:
Returns
-------
list[str]
- A list of all keys in the storage.
+ A list of all keys.
"""
return list(self.storage.keys())
+ def __iter__(self) -> Iterator[str]:
+ """
+ Yield the keys of items stored in the repository.
+
+ Returns
+ -------
+ Iterator[str]
+ An iterator yielding all keys.
+ """
+ yield from self.storage
+
def set_item_note(self, key: str, message: str, *, version=-1):
"""Attach a note to key ``key``.
diff --git a/skore/src/skore/persistence/repository/view_repository.py b/skore/src/skore/persistence/repository/view_repository.py
index 8ec7dcfc9..02f42baa8 100644
--- a/skore/src/skore/persistence/repository/view_repository.py
+++ b/skore/src/skore/persistence/repository/view_repository.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+from collections.abc import Iterator
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -65,11 +66,22 @@ def delete_view(self, key: str):
def keys(self) -> list[str]:
"""
- Get all keys of items stored in the repository.
+ Get all keys of views stored in the repository.
Returns
-------
list[str]
- A list of all keys in the storage.
+ A list of all keys.
"""
return list(self.storage.keys())
+
+ def __iter__(self) -> Iterator[str]:
+ """
+ Yield the keys of views stored in the repository.
+
+ Returns
+ -------
+ Iterator[str]
+ An iterator yielding all keys.
+ """
+ yield from self.storage
diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py
index acba73822..d485e3f72 100644
--- a/skore/src/skore/project/project.py
+++ b/skore/src/skore/project/project.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import logging
+from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from skore.persistence.item import item_to_object, object_to_item
@@ -133,7 +134,8 @@ def get(self, key, *, latest=True, metadata=False):
value : any
Value associated to ``key``, when latest=True and metadata=False.
value_and_metadata : dict
- Value associated to ``key`` with its metadata, when latest=True and metadata=True.
+ Value associated to ``key`` with its metadata, when latest=True and
+ metadata=True.
list_of_values : list[any]
Values associated to ``key``, ordered by date, when latest=False.
list_of_values_and_metadata : list[dict]
@@ -164,8 +166,27 @@ def dto(item):
return list(map(dto, self.item_repository.get_item_versions(key)))
def keys(self) -> list[str]:
+ """
+ Get all keys of items stored in the project.
+
+ Returns
+ -------
+ list[str]
+ A list of all keys.
+ """
return self.item_repository.keys()
+ def __iter__(self) -> Iterator[str]:
+ """
+ Yield the keys of items stored in the project.
+
+ Returns
+ -------
+ Iterator[str]
+ An iterator yielding all keys.
+ """
+ yield from self.item_repository
+
def delete(self, key: str):
"""Delete the item corresponding to ``key`` from the Project.
diff --git a/skore/src/skore/ui/project_routes.py b/skore/src/skore/ui/project_routes.py
index 5606bbd39..04d1e88e1 100644
--- a/skore/src/skore/ui/project_routes.py
+++ b/skore/src/skore/ui/project_routes.py
@@ -52,12 +52,12 @@ def __project_as_serializable(project: Project) -> SerializableProject:
__item_as_serializable(key, item)
for item in project.item_repository.get_item_versions(key)
]
- for key in project.item_repository.keys()
+ for key in project.item_repository
}
views = {
key: project.view_repository.get_view(key).layout
- for key in project.view_repository.keys()
+ for key in project.view_repository
}
return SerializableProject(
@@ -116,7 +116,7 @@ async def get_activity(
return sorted(
(
__item_as_serializable(key, version)
- for key in project.item_repository.keys()
+ for key in project.item_repository
for version in project.item_repository.get_item_versions(key)
if datetime.fromisoformat(version.updated_at) > after
),
From fe9c497f613ae7403d16c3f311ba384871bb4f2e Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Thu, 16 Jan 2025 15:41:32 +0100
Subject: [PATCH 20/22] [skip ci] Leave pillow from MediaItem to
PillowImageItem
---
skore/src/skore/persistence/item/__init__.py | 5 +
.../item/cross_validation_reporter_item.py | 10 +-
.../src/skore/persistence/item/media_item.py | 54 ++--------
.../src/skore/persistence/item/pickle_item.py | 3 +-
.../persistence/item/pillow_image_item.py | 102 ++++++++++++++++++
skore/tests/integration/ui/test_ui.py | 48 +++++++--
skore/tests/unit/item/test_media_item.py | 18 ----
.../tests/unit/item/test_pillow_image_item.py | 56 ++++++++++
skore/tests/unit/project/test_project.py | 18 ++--
9 files changed, 222 insertions(+), 92 deletions(-)
create mode 100644 skore/src/skore/persistence/item/pillow_image_item.py
create mode 100644 skore/tests/unit/item/test_pillow_image_item.py
diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py
index aa9ea9319..de24446be 100644
--- a/skore/src/skore/persistence/item/__init__.py
+++ b/skore/src/skore/persistence/item/__init__.py
@@ -13,6 +13,7 @@
from .pandas_dataframe_item import PandasDataFrameItem
from .pandas_series_item import PandasSeriesItem
from .pickle_item import PickleItem
+from .pillow_image_item import PillowImageItem
from .polars_dataframe_item import PolarsDataFrameItem
from .polars_series_item import PolarsSeriesItem
from .primitive_item import PrimitiveItem
@@ -50,6 +51,7 @@ def object_to_item(
MediaItem,
SkrubTableReportItem,
CrossValidationReporterItem,
+ PillowImageItem,
):
with suppress(ImportError, ItemTypeError):
# ImportError:
@@ -91,6 +93,8 @@ def item_to_object(item: Item) -> Any:
return item.reporter
elif isinstance(item, MediaItem):
return item.media_bytes
+ elif isinstance(item, PillowImageItem):
+ return item.image
elif isinstance(item, PickleItem):
return item.object
else:
@@ -105,6 +109,7 @@ def item_to_object(item: Item) -> Any:
"PandasDataFrameItem",
"PandasSeriesItem",
"PickleItem",
+ "PillowImageItem",
"PolarsDataFrameItem",
"PolarsSeriesItem",
"PrimitiveItem",
diff --git a/skore/src/skore/persistence/item/cross_validation_reporter_item.py b/skore/src/skore/persistence/item/cross_validation_reporter_item.py
index ecad01237..775133e69 100644
--- a/skore/src/skore/persistence/item/cross_validation_reporter_item.py
+++ b/skore/src/skore/persistence/item/cross_validation_reporter_item.py
@@ -14,7 +14,6 @@
import pickle
import re
import statistics
-from functools import cached_property
from typing import TYPE_CHECKING, Literal, Optional, TypedDict
import numpy
@@ -204,14 +203,9 @@ def factory(cls, reporter: CrossValidationReporter) -> CrossValidationReporterIt
if not isinstance(reporter, CrossValidationReporter):
raise ItemTypeError(f"Type '{reporter.__class__}' is not supported.")
- instance = cls(pickle.dumps(reporter))
+ return cls(pickle.dumps(reporter))
- # add reporter as cached property
- instance.reporter = reporter
-
- return instance
-
- @cached_property
+ @property
def reporter(self) -> CrossValidationReporter:
"""The CrossValidationReporter from the persistence."""
return pickle.loads(self.reporter_bytes)
diff --git a/skore/src/skore/persistence/item/media_item.py b/skore/src/skore/persistence/item/media_item.py
index e57434743..e2a5cfb0a 100644
--- a/skore/src/skore/persistence/item/media_item.py
+++ b/skore/src/skore/persistence/item/media_item.py
@@ -1,6 +1,6 @@
"""MediaItem.
-This module defines the MediaItem class, which represents media items.
+This module defines the MediaItem class, used to persist media.
"""
from __future__ import annotations
@@ -15,12 +15,11 @@
if TYPE_CHECKING:
from altair.vegalite.v5.schema.core import TopLevelSpec as Altair
from matplotlib.figure import Figure as Matplotlib
- from PIL.Image import Image as Pillow
from plotly.basedatatypes import BaseFigure as Plotly
def lazy_is_instance(object: Any, cls_fullname: str) -> bool:
- """Return True if object is an instance of a class named `cls_fullname`."""
+ """Return True if object is an instance of `cls_fullname`."""
return cls_fullname in {
f"{cls.__module__}.{cls.__name__}" for cls in object.__class__.__mro__
}
@@ -65,8 +64,8 @@ def __init__(
The creation timestamp in ISO format.
updated_at : str, optional
The last update timestamp in ISO format.
- note : Union[str, None]
- An optional note.
+ note : str, optional
+ A note.
"""
super().__init__(created_at, updated_at, note)
@@ -75,12 +74,7 @@ def __init__(
self.media_type = media_type
def as_serializable_dict(self):
- """Get a serializable dict from the item.
-
- Derived class must call their super implementation
- and merge the result with their output.
- """
- d = super().as_serializable_dict()
+ """Return item as a JSONable dict to export to frontend."""
if "text" in self.media_type:
value = self.media_bytes.decode(encoding=self.media_encoding)
media_type = f"{self.media_type}"
@@ -88,13 +82,10 @@ def as_serializable_dict(self):
value = base64.b64encode(self.media_bytes).decode()
media_type = f"{self.media_type};base64"
- d.update(
- {
- "media_type": media_type,
- "value": value,
- }
- )
- return d
+ return super().as_serializable_dict() | {
+ "media_type": media_type,
+ "value": value,
+ }
@classmethod
def factory(cls, media, *args, **kwargs):
@@ -127,8 +118,6 @@ def factory(cls, media, *args, **kwargs):
return cls.factory_altair(media, *args, **kwargs)
if lazy_is_instance(media, "matplotlib.figure.Figure"):
return cls.factory_matplotlib(media, *args, **kwargs)
- if lazy_is_instance(media, "PIL.Image.Image"):
- return cls.factory_pillow(media, *args, **kwargs)
if lazy_is_instance(media, "plotly.basedatatypes.BaseFigure"):
return cls.factory_plotly(media, *args, **kwargs)
@@ -237,31 +226,6 @@ def factory_matplotlib(cls, media: Matplotlib) -> MediaItem:
media_type="image/svg+xml",
)
- @classmethod
- def factory_pillow(cls, media: Pillow) -> MediaItem:
- """
- Create a new MediaItem instance from a Pillow image.
-
- Parameters
- ----------
- media : Pillow
- The Pillow image to store.
-
- Returns
- -------
- MediaItem
- A new MediaItem instance.
- """
- with BytesIO() as stream:
- media.save(stream, format="png")
- media_bytes = stream.getvalue()
-
- return cls(
- media_bytes=media_bytes,
- media_encoding="utf-8",
- media_type="image/png",
- )
-
@classmethod
def factory_plotly(cls, media: Plotly) -> MediaItem:
"""
diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py
index d2678f545..a80fea3de 100644
--- a/skore/src/skore/persistence/item/pickle_item.py
+++ b/skore/src/skore/persistence/item/pickle_item.py
@@ -6,7 +6,6 @@
from __future__ import annotations
-from functools import cached_property
from pickle import dumps, loads
from typing import Any, Optional
@@ -47,7 +46,7 @@ def __init__(
self.pickle_bytes = pickle_bytes
- @cached_property
+ @property
def object(self) -> Any:
"""The object from the persistence."""
return loads(self.pickle_bytes)
diff --git a/skore/src/skore/persistence/item/pillow_image_item.py b/skore/src/skore/persistence/item/pillow_image_item.py
new file mode 100644
index 000000000..c67cde339
--- /dev/null
+++ b/skore/src/skore/persistence/item/pillow_image_item.py
@@ -0,0 +1,102 @@
+"""PillowImageItem.
+
+This module defines the PillowImageItem class, used to persist Pillow images.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+from .item import Item, ItemTypeError
+from .media_item import lazy_is_instance
+
+if TYPE_CHECKING:
+ import PIL.Image
+
+
+class PillowImageItem(Item):
+ """A class used to persist a Pillow image."""
+
+ def __init__(
+ self,
+ image_bytes: bytes,
+ image_mode: str,
+ image_size: tuple[int],
+ created_at: Optional[str] = None,
+ updated_at: Optional[str] = None,
+ note: Optional[str] = None,
+ ):
+ """
+ Initialize a PillowImageItem.
+
+ Parameters
+ ----------
+ image_bytes : bytes
+ The raw bytes of the Pillow image.
+ image_mode : str
+ The image mode.
+ image_size : tuple[int]
+ The image size in pixels, as ``(width, height)``.
+ created_at : str, optional
+ The creation timestamp in ISO format.
+ updated_at : str, optional
+ The last update timestamp in ISO format.
+ note : str, optional
+ A note.
+ """
+ super().__init__(created_at, updated_at, note)
+
+ self.image_bytes = image_bytes
+ self.image_mode = image_mode
+ self.image_size = image_size
+
+ @classmethod
+ def factory(cls, image: PIL.Image.Image) -> PillowImageItem:
+ """
+ Create a new PillowImageItem instance from a Pillow image.
+
+ Parameters
+ ----------
+ image : PIL.Image.Image
+ The Pillow image to store.
+
+ Returns
+ -------
+ PillowImageItem
+ A new PillowImageItem instance.
+ """
+ if not lazy_is_instance(image, "PIL.Image.Image"):
+ raise ItemTypeError(f"Type '{image.__class__}' is not supported.")
+
+ return cls(
+ image_bytes=image.tobytes(),
+ image_mode=image.mode,
+ image_size=image.size,
+ )
+
+ @property
+ def image(self) -> PIL.Image.Image:
+ """The image from the persistence."""
+ import PIL.Image
+
+ return PIL.Image.frombytes(
+ mode=self.image_mode,
+ size=self.image_size,
+ data=self.image_bytes,
+ )
+
+ def as_serializable_dict(self):
+ """Return item as a JSONable dict to export to frontend."""
+ import base64
+ import io
+
+ with io.BytesIO() as stream:
+ self.image.save(stream, format="png")
+
+ png_bytes = stream.getvalue()
+ png_bytes_b64 = base64.b64encode(png_bytes).decode()
+
+ return super().as_serializable_dict() | {
+ "media_type": "image/png;base64",
+ "value": png_bytes_b64,
+ }
diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py
index c5170a958..b9470a50d 100644
--- a/skore/tests/integration/ui/test_ui.py
+++ b/skore/tests/integration/ui/test_ui.py
@@ -1,4 +1,6 @@
+import base64
import datetime
+import io
import json
import numpy
@@ -20,6 +22,11 @@ def client(in_memory_project):
return TestClient(app=create_app(project=in_memory_project))
+@pytest.fixture
+def monkeypatch_datetime(monkeypatch, MockDatetime):
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
+
+
def test_app_state(client):
assert client.app.state.project is not None
@@ -130,19 +137,42 @@ def test_serialize_sklearn_estimator(client, in_memory_project):
assert project["items"]["estimator"][0]["value"] is not None
-def test_serialize_media_item(client, in_memory_project):
- imarray = numpy.random.rand(100, 100, 3) * 255
- img = Image.fromarray(imarray.astype("uint8")).convert("RGBA")
- in_memory_project.put("img", img)
+def test_serialize_pillow_item(
+ client,
+ in_memory_project,
+ monkeypatch_datetime,
+ mock_nowstr,
+):
+ image_array = numpy.random.rand(100, 100, 3) * 255
+ image = Image.fromarray(image_array.astype("uint8")).convert("RGBA")
+
+ with io.BytesIO() as stream:
+ image.save(stream, format="png")
- html = "éપUœALDXIWDŸΩΩ
"
- in_memory_project.put("html", html)
+ png_bytes = stream.getvalue()
+ png_bytes_b64 = base64.b64encode(png_bytes).decode()
+ in_memory_project.put("image", image)
response = client.get("/api/project/items")
+
assert response.status_code == 200
- project = response.json()
- assert "image" in project["items"]["img"][0]["media_type"]
- assert project["items"]["html"][0]["value"] == html
+ assert response.json() == {
+ "views": {},
+ "items": {
+ "image": [
+ {
+ "name": "image",
+ "media_type": "image/png;base64",
+ "value": png_bytes_b64,
+ "updated_at": mock_nowstr,
+ "created_at": mock_nowstr,
+ }
+ ]
+ },
+ }
+
+
+def test_serialize_media_item(client, in_memory_project): ...
@pytest.fixture
diff --git a/skore/tests/unit/item/test_media_item.py b/skore/tests/unit/item/test_media_item.py
index 320da8d97..0794b50bc 100644
--- a/skore/tests/unit/item/test_media_item.py
+++ b/skore/tests/unit/item/test_media_item.py
@@ -1,8 +1,5 @@
-import io
-
import altair
import matplotlib.pyplot
-import PIL as pillow
import plotly.graph_objects as go
import pytest
from skore.persistence.item import ItemTypeError, MediaItem
@@ -62,21 +59,6 @@ def test_factory_matplotlib(self, mock_nowstr):
assert item.created_at == mock_nowstr
assert item.updated_at == mock_nowstr
- def test_factory_pillow(self, mock_nowstr):
- image = pillow.Image.new("RGB", (100, 100), color="red")
-
- with io.BytesIO() as stream:
- image.save(stream, format="png")
- image_bytes = stream.getvalue()
-
- item = MediaItem.factory(image)
-
- assert item.media_bytes == image_bytes
- assert item.media_encoding == "utf-8"
- assert item.media_type == "image/png"
- assert item.created_at == mock_nowstr
- assert item.updated_at == mock_nowstr
-
def test_factory_plotly(self, mock_nowstr):
figure = go.Figure(data=[go.Bar(x=[1, 2, 3], y=[1, 3, 2])])
figure_bytes = figure.to_json().encode("utf-8")
diff --git a/skore/tests/unit/item/test_pillow_image_item.py b/skore/tests/unit/item/test_pillow_image_item.py
new file mode 100644
index 000000000..3e43fa94b
--- /dev/null
+++ b/skore/tests/unit/item/test_pillow_image_item.py
@@ -0,0 +1,56 @@
+import base64
+import io
+
+import PIL.Image
+import pytest
+from skore.persistence.item import ItemTypeError, PillowImageItem
+
+
+class TestPillowImageItem:
+ @pytest.fixture(autouse=True)
+ def monkeypatch_datetime(self, monkeypatch, MockDatetime):
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
+
+ def test_factory(self, mock_nowstr):
+ image = PIL.Image.new("RGB", (100, 100), color="red")
+ item = PillowImageItem.factory(image)
+
+ assert item.image_bytes == image.tobytes()
+ assert item.image_mode == image.mode
+ assert item.image_size == image.size
+ assert item.created_at == mock_nowstr
+ assert item.updated_at == mock_nowstr
+
+ def test_factory_exception(self):
+ with pytest.raises(ItemTypeError):
+ PillowImageItem.factory(None)
+
+ def test_image(self):
+ image = PIL.Image.new("RGB", (100, 100), color="red")
+ item1 = PillowImageItem.factory(image)
+ item2 = PillowImageItem(
+ image_bytes=image.tobytes(),
+ image_mode=image.mode,
+ image_size=image.size,
+ )
+
+ assert item1.image == image
+ assert item2.image == image
+
+ def test_as_serializable_dict(self, mock_nowstr):
+ image = PIL.Image.new("RGB", (100, 100), color="red")
+ item = PillowImageItem.factory(image)
+
+ with io.BytesIO() as stream:
+ image.save(stream, format="png")
+
+ png_bytes = stream.getvalue()
+ png_bytes_b64 = base64.b64encode(png_bytes).decode()
+
+ assert item.as_serializable_dict() == {
+ "updated_at": mock_nowstr,
+ "created_at": mock_nowstr,
+ "note": None,
+ "media_type": "image/png;base64",
+ "value": png_bytes_b64,
+ }
diff --git a/skore/tests/unit/project/test_project.py b/skore/tests/unit/project/test_project.py
index cf2f7225e..eb9790169 100644
--- a/skore/tests/unit/project/test_project.py
+++ b/skore/tests/unit/project/test_project.py
@@ -1,5 +1,3 @@
-from io import BytesIO
-
import altair
import numpy
import numpy.testing
@@ -124,14 +122,14 @@ def test_put_vega_chart(in_memory_project):
def test_put_pil_image(in_memory_project):
- # Add a PIL Image
- pil_image = Image.new("RGB", (100, 100), color="red")
- with BytesIO() as output:
- # FIXME: Not JPEG!
- pil_image.save(output, format="jpeg")
-
- in_memory_project.put("pil_image", pil_image) # MediaItem (PNG)
- assert isinstance(in_memory_project.get("pil_image"), bytes)
+ image1 = Image.new("RGB", (100, 100), color="red")
+ image2 = Image.new("RGBA", (150, 150), color="blue")
+
+ in_memory_project.put("image1", image1)
+ in_memory_project.put("image2", image2)
+
+ assert in_memory_project.get("image1") == image1
+ assert in_memory_project.get("image2") == image2
def test_put_rf_model(in_memory_project, monkeypatch):
From 4550767b1b8bd367caf2c3d7f456dc34895bdcbd Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Thu, 16 Jan 2025 15:59:02 +0100
Subject: [PATCH 21/22] [skip ci] Allow to pass created_at/updated_at/note via
factory to constructor in PillowImageItem and PickleItem
---
skore/src/skore/persistence/item/pickle_item.py | 4 ++--
skore/src/skore/persistence/item/pillow_image_item.py | 3 ++-
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py
index a80fea3de..1074f474c 100644
--- a/skore/src/skore/persistence/item/pickle_item.py
+++ b/skore/src/skore/persistence/item/pickle_item.py
@@ -52,7 +52,7 @@ def object(self) -> Any:
return loads(self.pickle_bytes)
@classmethod
- def factory(cls, object: Any) -> PickleItem:
+ def factory(cls, object: Any, /, **kwargs) -> PickleItem:
"""
Create a new PickleItem with any object.
@@ -66,7 +66,7 @@ def factory(cls, object: Any) -> PickleItem:
PickleItem
A new PickleItem instance.
"""
- return cls(dumps(object))
+ return cls(dumps(object), **kwargs)
def as_serializable_dict(self):
"""Get a JSON serializable representation of the item."""
diff --git a/skore/src/skore/persistence/item/pillow_image_item.py b/skore/src/skore/persistence/item/pillow_image_item.py
index c67cde339..20f53a5f8 100644
--- a/skore/src/skore/persistence/item/pillow_image_item.py
+++ b/skore/src/skore/persistence/item/pillow_image_item.py
@@ -51,7 +51,7 @@ def __init__(
self.image_size = image_size
@classmethod
- def factory(cls, image: PIL.Image.Image) -> PillowImageItem:
+ def factory(cls, image: PIL.Image.Image, /, **kwargs) -> PillowImageItem:
"""
Create a new PillowImageItem instance from a Pillow image.
@@ -72,6 +72,7 @@ def factory(cls, image: PIL.Image.Image) -> PillowImageItem:
image_bytes=image.tobytes(),
image_mode=image.mode,
image_size=image.size,
+ **kwargs,
)
@property
From 42d9b86169faa132b02047cb2294c5d6f2b33257 Mon Sep 17 00:00:00 2001
From: Thomas S
Date: Thu, 16 Jan 2025 17:11:18 +0100
Subject: [PATCH 22/22] [skip ci] Leave plotly from MediaItem to
PlotlyFigureItem
---
skore/src/skore/persistence/item/__init__.py | 7 ++
.../src/skore/persistence/item/media_item.py | 28 -------
.../persistence/item/plotly_figure_item.py | 84 +++++++++++++++++++
skore/tests/integration/ui/test_ui.py | 32 +++++++
skore/tests/unit/item/test_media_item.py | 13 ---
.../unit/item/test_plotly_figure_item.py | 51 +++++++++++
skore/tests/unit/project/test_project.py | 12 ++-
7 files changed, 185 insertions(+), 42 deletions(-)
create mode 100644 skore/src/skore/persistence/item/plotly_figure_item.py
create mode 100644 skore/tests/unit/item/test_plotly_figure_item.py
diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py
index de24446be..3ed8abcd5 100644
--- a/skore/src/skore/persistence/item/__init__.py
+++ b/skore/src/skore/persistence/item/__init__.py
@@ -14,6 +14,7 @@
from .pandas_series_item import PandasSeriesItem
from .pickle_item import PickleItem
from .pillow_image_item import PillowImageItem
+from .plotly_figure_item import PlotlyFigureItem
from .polars_dataframe_item import PolarsDataFrameItem
from .polars_series_item import PolarsSeriesItem
from .primitive_item import PrimitiveItem
@@ -52,6 +53,7 @@ def object_to_item(
SkrubTableReportItem,
CrossValidationReporterItem,
PillowImageItem,
+ PlotlyFigureItem,
):
with suppress(ImportError, ItemTypeError):
# ImportError:
@@ -74,6 +76,8 @@ def object_to_item(
# `note` attribute dynamically.
item.note = note
+ # TODO: set `note` in each item class instead of assigning it dynamically here.
+
return item
@@ -95,6 +99,8 @@ def item_to_object(item: Item) -> Any:
return item.media_bytes
elif isinstance(item, PillowImageItem):
return item.image
+ elif isinstance(item, PlotlyFigureItem):
+ return item.figure
elif isinstance(item, PickleItem):
return item.object
else:
@@ -110,6 +116,7 @@ def item_to_object(item: Item) -> Any:
"PandasSeriesItem",
"PickleItem",
"PillowImageItem",
+ "PlotlyFigureItem",
"PolarsDataFrameItem",
"PolarsSeriesItem",
"PrimitiveItem",
diff --git a/skore/src/skore/persistence/item/media_item.py b/skore/src/skore/persistence/item/media_item.py
index e2a5cfb0a..0ebb8858f 100644
--- a/skore/src/skore/persistence/item/media_item.py
+++ b/skore/src/skore/persistence/item/media_item.py
@@ -15,7 +15,6 @@
if TYPE_CHECKING:
from altair.vegalite.v5.schema.core import TopLevelSpec as Altair
from matplotlib.figure import Figure as Matplotlib
- from plotly.basedatatypes import BaseFigure as Plotly
def lazy_is_instance(object: Any, cls_fullname: str) -> bool:
@@ -118,8 +117,6 @@ def factory(cls, media, *args, **kwargs):
return cls.factory_altair(media, *args, **kwargs)
if lazy_is_instance(media, "matplotlib.figure.Figure"):
return cls.factory_matplotlib(media, *args, **kwargs)
- if lazy_is_instance(media, "plotly.basedatatypes.BaseFigure"):
- return cls.factory_plotly(media, *args, **kwargs)
raise ItemTypeError(f"Type '{media.__class__}' is not supported.")
@@ -225,28 +222,3 @@ def factory_matplotlib(cls, media: Matplotlib) -> MediaItem:
media_encoding="utf-8",
media_type="image/svg+xml",
)
-
- @classmethod
- def factory_plotly(cls, media: Plotly) -> MediaItem:
- """
- Create a new MediaItem instance from a Plotly figure.
-
- Parameters
- ----------
- media : Plotly
- The Plotly figure to store.
-
- Returns
- -------
- MediaItem
- A new MediaItem instance.
- """
- import plotly.io
-
- media_bytes = plotly.io.to_json(media, engine="json").encode("utf-8")
-
- return cls(
- media_bytes=media_bytes,
- media_encoding="utf-8",
- media_type="application/vnd.plotly.v1+json",
- )
diff --git a/skore/src/skore/persistence/item/plotly_figure_item.py b/skore/src/skore/persistence/item/plotly_figure_item.py
new file mode 100644
index 000000000..a705ca48a
--- /dev/null
+++ b/skore/src/skore/persistence/item/plotly_figure_item.py
@@ -0,0 +1,84 @@
+"""PlotlyFigureItem.
+
+This module defines the PlotlyFigureItem class, used to persist Plotly figures.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+from .item import Item, ItemTypeError
+from .media_item import lazy_is_instance
+
+if TYPE_CHECKING:
+ import plotly.basedatatypes as plotly
+
+
+class PlotlyFigureItem(Item):
+ """A class used to persist a Plotly figure."""
+
+ def __init__(
+ self,
+ figure_str: str,
+ created_at: Optional[str] = None,
+ updated_at: Optional[str] = None,
+ note: Optional[str] = None,
+ ):
+ """
+ Initialize a PlotlyFigureItem.
+
+ Parameters
+ ----------
+ figure_str : str
+ The JSON str of the Plotly figure.
+ created_at : str, optional
+ The creation timestamp in ISO format.
+ updated_at : str, optional
+ The last update timestamp in ISO format.
+ note : str, optional
+ A note.
+ """
+ super().__init__(created_at, updated_at, note)
+
+ self.figure_str = figure_str
+
+ @classmethod
+ def factory(cls, figure: plotly.BaseFigure, /, **kwargs) -> PlotlyFigureItem:
+ """
+ Create a new PlotlyFigureItem instance from a Plotly figure.
+
+ Parameters
+ ----------
+ figure : plotly.basedatatypes.BaseFigure
+ The Plotly figure to store.
+
+ Returns
+ -------
+ PlotlyFigureItem
+ A new PlotlyFigureItem instance.
+ """
+ if not lazy_is_instance(figure, "plotly.basedatatypes.BaseFigure"):
+ raise ItemTypeError(f"Type '{figure.__class__}' is not supported.")
+
+ import plotly.io
+
+ return cls(plotly.io.to_json(figure, engine="json"), **kwargs)
+
+ @property
+ def figure(self) -> plotly.BaseFigure:
+ """The figure from the persistence."""
+ import plotly.io
+
+ return plotly.io.from_json(self.figure_str)
+
+ def as_serializable_dict(self):
+ """Return item as a JSONable dict to export to frontend."""
+ import base64
+
+ figure_bytes = self.figure_str.encode("utf-8")
+ figure_bytes_b64 = base64.b64encode(figure_bytes).decode()
+
+ return super().as_serializable_dict() | {
+ "media_type": "application/vnd.plotly.v1+json;base64",
+ "value": figure_bytes_b64,
+ }
diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py
index b9470a50d..d0a02d57d 100644
--- a/skore/tests/integration/ui/test_ui.py
+++ b/skore/tests/integration/ui/test_ui.py
@@ -172,6 +172,38 @@ def test_serialize_pillow_item(
}
+def test_serialize_plotly_item(
+ client,
+ in_memory_project,
+ monkeypatch_datetime,
+ mock_nowstr,
+):
+ bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2])
+ figure = plotly.graph_objects.Figure(data=[bar])
+ figure_str = plotly.io.to_json(figure, engine="json")
+ figure_bytes = figure_str.encode("utf-8")
+ figure_bytes_b64 = base64.b64encode(figure_bytes).decode()
+
+ in_memory_project.put("figure", figure)
+ response = client.get("/api/project/items")
+
+ assert response.status_code == 200
+ assert response.json() == {
+ "views": {},
+ "items": {
+ "figure": [
+ {
+ "name": "figure",
+ "media_type": "application/vnd.plotly.v1+json;base64",
+ "value": figure_bytes_b64,
+ "updated_at": mock_nowstr,
+ "created_at": mock_nowstr,
+ }
+ ]
+ },
+ }
+
+
def test_serialize_media_item(client, in_memory_project): ...
diff --git a/skore/tests/unit/item/test_media_item.py b/skore/tests/unit/item/test_media_item.py
index 0794b50bc..0db89c082 100644
--- a/skore/tests/unit/item/test_media_item.py
+++ b/skore/tests/unit/item/test_media_item.py
@@ -1,6 +1,5 @@
import altair
import matplotlib.pyplot
-import plotly.graph_objects as go
import pytest
from skore.persistence.item import ItemTypeError, MediaItem
@@ -59,18 +58,6 @@ def test_factory_matplotlib(self, mock_nowstr):
assert item.created_at == mock_nowstr
assert item.updated_at == mock_nowstr
- def test_factory_plotly(self, mock_nowstr):
- figure = go.Figure(data=[go.Bar(x=[1, 2, 3], y=[1, 3, 2])])
- figure_bytes = figure.to_json().encode("utf-8")
-
- item = MediaItem.factory(figure)
-
- assert item.media_bytes == figure_bytes
- assert item.media_encoding == "utf-8"
- assert item.media_type == "application/vnd.plotly.v1+json"
- assert item.created_at == mock_nowstr
- assert item.updated_at == mock_nowstr
-
def test_get_serializable_dict(self, mock_nowstr):
item = MediaItem.factory("")
diff --git a/skore/tests/unit/item/test_plotly_figure_item.py b/skore/tests/unit/item/test_plotly_figure_item.py
new file mode 100644
index 000000000..9d1f8f18e
--- /dev/null
+++ b/skore/tests/unit/item/test_plotly_figure_item.py
@@ -0,0 +1,51 @@
+import base64
+
+import plotly.graph_objects
+import plotly.io
+import pytest
+from skore.persistence.item import ItemTypeError, PlotlyFigureItem
+
+
+class TestPlotlyFigureItem:
+ @pytest.fixture(autouse=True)
+ def monkeypatch_datetime(self, monkeypatch, MockDatetime):
+ monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime)
+
+ def test_factory(self, mock_nowstr):
+ bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2])
+ figure = plotly.graph_objects.Figure(data=[bar])
+ item = PlotlyFigureItem.factory(figure)
+
+ assert item.figure_str == plotly.io.to_json(figure, engine="json")
+ assert item.created_at == mock_nowstr
+ assert item.updated_at == mock_nowstr
+
+ def test_factory_exception(self):
+ with pytest.raises(ItemTypeError):
+ PlotlyFigureItem.factory(None)
+
+ def test_figure(self):
+ bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2])
+ figure = plotly.graph_objects.Figure(data=[bar])
+ item1 = PlotlyFigureItem.factory(figure)
+ item2 = PlotlyFigureItem(plotly.io.to_json(figure, engine="json"))
+
+ assert item1.figure == figure
+ assert item2.figure == figure
+
+ def test_as_serializable_dict(self, mock_nowstr):
+ bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2])
+ figure = plotly.graph_objects.Figure(data=[bar])
+ figure_str = plotly.io.to_json(figure, engine="json")
+ figure_bytes = figure_str.encode("utf-8")
+ figure_bytes_b64 = base64.b64encode(figure_bytes).decode()
+
+ item = PlotlyFigureItem.factory(figure)
+
+ assert item.as_serializable_dict() == {
+ "updated_at": mock_nowstr,
+ "created_at": mock_nowstr,
+ "note": None,
+ "media_type": "application/vnd.plotly.v1+json;base64",
+ "value": figure_bytes_b64,
+ }
diff --git a/skore/tests/unit/project/test_project.py b/skore/tests/unit/project/test_project.py
index eb9790169..c37a8fdf0 100644
--- a/skore/tests/unit/project/test_project.py
+++ b/skore/tests/unit/project/test_project.py
@@ -3,6 +3,7 @@
import numpy.testing
import pandas
import pandas.testing
+import plotly
import polars
import polars.testing
import pytest
@@ -121,7 +122,7 @@ def test_put_vega_chart(in_memory_project):
assert isinstance(in_memory_project.get("vega_chart"), bytes)
-def test_put_pil_image(in_memory_project):
+def test_put_pillow_image(in_memory_project):
image1 = Image.new("RGB", (100, 100), color="red")
image2 = Image.new("RGBA", (150, 150), color="blue")
@@ -132,6 +133,15 @@ def test_put_pil_image(in_memory_project):
assert in_memory_project.get("image2") == image2
+def test_put_plotly_figure(in_memory_project):
+ bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2])
+ figure = plotly.graph_objects.Figure(data=[bar])
+
+ in_memory_project.put("figure", figure)
+
+ assert in_memory_project.get("figure") == figure
+
+
def test_put_rf_model(in_memory_project, monkeypatch):
# Add a scikit-learn model
monkeypatch.setattr("sklearn.utils.estimator_html_repr", lambda _: "")