From 666f0946f16489de11c747fe3c7f296b035021a7 Mon Sep 17 00:00:00 2001 From: Thomas S Date: Tue, 7 Jan 2025 15:43:54 +0100 Subject: [PATCH 01/22] [skip ci] Move `item/` to `persistence/item/` --- skore/src/skore/{ => persistence}/item/__init__.py | 0 skore/src/skore/{ => persistence}/item/cross_validation_item.py | 0 skore/src/skore/{ => persistence}/item/item.py | 0 skore/src/skore/{ => persistence}/item/media_item.py | 0 skore/src/skore/{ => persistence}/item/numpy_array_item.py | 0 skore/src/skore/{ => persistence}/item/pandas_dataframe_item.py | 0 skore/src/skore/{ => persistence}/item/pandas_series_item.py | 0 skore/src/skore/{ => persistence}/item/polars_dataframe_item.py | 0 skore/src/skore/{ => persistence}/item/polars_series_item.py | 0 skore/src/skore/{ => persistence}/item/primitive_item.py | 0 .../skore/{ => persistence}/item/sklearn_base_estimator_item.py | 0 skore/src/skore/{ => persistence}/item/skrub_table_report_item.py | 0 .../src/skore/{ => persistence}/item/standalone_widget.html.jinja | 0 skore/src/skore/persistence/repository/__init__.py | 0 .../src/skore/{item => persistence/repository}/item_repository.py | 0 skore/src/skore/persistence/{ => storage}/abstract_storage.py | 0 skore/src/skore/persistence/{ => storage}/disk_cache_storage.py | 0 skore/src/skore/persistence/{ => storage}/in_memory_storage.py | 0 18 files changed, 0 insertions(+), 0 deletions(-) rename skore/src/skore/{ => persistence}/item/__init__.py (100%) rename skore/src/skore/{ => persistence}/item/cross_validation_item.py (100%) rename skore/src/skore/{ => persistence}/item/item.py (100%) rename skore/src/skore/{ => persistence}/item/media_item.py (100%) rename skore/src/skore/{ => persistence}/item/numpy_array_item.py (100%) rename skore/src/skore/{ => persistence}/item/pandas_dataframe_item.py (100%) rename skore/src/skore/{ => persistence}/item/pandas_series_item.py (100%) rename skore/src/skore/{ => persistence}/item/polars_dataframe_item.py (100%) rename 
skore/src/skore/{ => persistence}/item/polars_series_item.py (100%) rename skore/src/skore/{ => persistence}/item/primitive_item.py (100%) rename skore/src/skore/{ => persistence}/item/sklearn_base_estimator_item.py (100%) rename skore/src/skore/{ => persistence}/item/skrub_table_report_item.py (100%) rename skore/src/skore/{ => persistence}/item/standalone_widget.html.jinja (100%) create mode 100644 skore/src/skore/persistence/repository/__init__.py rename skore/src/skore/{item => persistence/repository}/item_repository.py (100%) rename skore/src/skore/persistence/{ => storage}/abstract_storage.py (100%) rename skore/src/skore/persistence/{ => storage}/disk_cache_storage.py (100%) rename skore/src/skore/persistence/{ => storage}/in_memory_storage.py (100%) diff --git a/skore/src/skore/item/__init__.py b/skore/src/skore/persistence/item/__init__.py similarity index 100% rename from skore/src/skore/item/__init__.py rename to skore/src/skore/persistence/item/__init__.py diff --git a/skore/src/skore/item/cross_validation_item.py b/skore/src/skore/persistence/item/cross_validation_item.py similarity index 100% rename from skore/src/skore/item/cross_validation_item.py rename to skore/src/skore/persistence/item/cross_validation_item.py diff --git a/skore/src/skore/item/item.py b/skore/src/skore/persistence/item/item.py similarity index 100% rename from skore/src/skore/item/item.py rename to skore/src/skore/persistence/item/item.py diff --git a/skore/src/skore/item/media_item.py b/skore/src/skore/persistence/item/media_item.py similarity index 100% rename from skore/src/skore/item/media_item.py rename to skore/src/skore/persistence/item/media_item.py diff --git a/skore/src/skore/item/numpy_array_item.py b/skore/src/skore/persistence/item/numpy_array_item.py similarity index 100% rename from skore/src/skore/item/numpy_array_item.py rename to skore/src/skore/persistence/item/numpy_array_item.py diff --git a/skore/src/skore/item/pandas_dataframe_item.py 
b/skore/src/skore/persistence/item/pandas_dataframe_item.py similarity index 100% rename from skore/src/skore/item/pandas_dataframe_item.py rename to skore/src/skore/persistence/item/pandas_dataframe_item.py diff --git a/skore/src/skore/item/pandas_series_item.py b/skore/src/skore/persistence/item/pandas_series_item.py similarity index 100% rename from skore/src/skore/item/pandas_series_item.py rename to skore/src/skore/persistence/item/pandas_series_item.py diff --git a/skore/src/skore/item/polars_dataframe_item.py b/skore/src/skore/persistence/item/polars_dataframe_item.py similarity index 100% rename from skore/src/skore/item/polars_dataframe_item.py rename to skore/src/skore/persistence/item/polars_dataframe_item.py diff --git a/skore/src/skore/item/polars_series_item.py b/skore/src/skore/persistence/item/polars_series_item.py similarity index 100% rename from skore/src/skore/item/polars_series_item.py rename to skore/src/skore/persistence/item/polars_series_item.py diff --git a/skore/src/skore/item/primitive_item.py b/skore/src/skore/persistence/item/primitive_item.py similarity index 100% rename from skore/src/skore/item/primitive_item.py rename to skore/src/skore/persistence/item/primitive_item.py diff --git a/skore/src/skore/item/sklearn_base_estimator_item.py b/skore/src/skore/persistence/item/sklearn_base_estimator_item.py similarity index 100% rename from skore/src/skore/item/sklearn_base_estimator_item.py rename to skore/src/skore/persistence/item/sklearn_base_estimator_item.py diff --git a/skore/src/skore/item/skrub_table_report_item.py b/skore/src/skore/persistence/item/skrub_table_report_item.py similarity index 100% rename from skore/src/skore/item/skrub_table_report_item.py rename to skore/src/skore/persistence/item/skrub_table_report_item.py diff --git a/skore/src/skore/item/standalone_widget.html.jinja b/skore/src/skore/persistence/item/standalone_widget.html.jinja similarity index 100% rename from 
skore/src/skore/item/standalone_widget.html.jinja rename to skore/src/skore/persistence/item/standalone_widget.html.jinja diff --git a/skore/src/skore/persistence/repository/__init__.py b/skore/src/skore/persistence/repository/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/skore/src/skore/item/item_repository.py b/skore/src/skore/persistence/repository/item_repository.py similarity index 100% rename from skore/src/skore/item/item_repository.py rename to skore/src/skore/persistence/repository/item_repository.py diff --git a/skore/src/skore/persistence/abstract_storage.py b/skore/src/skore/persistence/storage/abstract_storage.py similarity index 100% rename from skore/src/skore/persistence/abstract_storage.py rename to skore/src/skore/persistence/storage/abstract_storage.py diff --git a/skore/src/skore/persistence/disk_cache_storage.py b/skore/src/skore/persistence/storage/disk_cache_storage.py similarity index 100% rename from skore/src/skore/persistence/disk_cache_storage.py rename to skore/src/skore/persistence/storage/disk_cache_storage.py diff --git a/skore/src/skore/persistence/in_memory_storage.py b/skore/src/skore/persistence/storage/in_memory_storage.py similarity index 100% rename from skore/src/skore/persistence/in_memory_storage.py rename to skore/src/skore/persistence/storage/in_memory_storage.py From 6811b6917ee306e1f9b48b4c89752f37050f81df Mon Sep 17 00:00:00 2001 From: Thomas S Date: Tue, 7 Jan 2025 15:55:35 +0100 Subject: [PATCH 02/22] [skip ci] Move `view/` to `persistence/view` --- .../src/skore/{view => persistence/repository}/view_repository.py | 0 skore/src/skore/{ => persistence}/view/__init__.py | 0 skore/src/skore/{ => persistence}/view/view.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename skore/src/skore/{view => persistence/repository}/view_repository.py (100%) rename skore/src/skore/{ => persistence}/view/__init__.py (100%) rename skore/src/skore/{ => persistence}/view/view.py (100%) diff --git 
a/skore/src/skore/view/view_repository.py b/skore/src/skore/persistence/repository/view_repository.py similarity index 100% rename from skore/src/skore/view/view_repository.py rename to skore/src/skore/persistence/repository/view_repository.py diff --git a/skore/src/skore/view/__init__.py b/skore/src/skore/persistence/view/__init__.py similarity index 100% rename from skore/src/skore/view/__init__.py rename to skore/src/skore/persistence/view/__init__.py diff --git a/skore/src/skore/view/view.py b/skore/src/skore/persistence/view/view.py similarity index 100% rename from skore/src/skore/view/view.py rename to skore/src/skore/persistence/view/view.py From c08f284fe97449731600ba9bf5901c0edc0b2d6b Mon Sep 17 00:00:00 2001 From: Thomas S Date: Tue, 7 Jan 2025 16:22:43 +0100 Subject: [PATCH 03/22] [skip ci] Remove `put_item`/`get_item` --- skore/src/skore/persistence/item/__init__.py | 27 +++++- skore/src/skore/project/project.py | 92 ++++---------------- 2 files changed, 43 insertions(+), 76 deletions(-) diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py index b7c89a040..0735647f5 100644 --- a/skore/src/skore/persistence/item/__init__.py +++ b/skore/src/skore/persistence/item/__init__.py @@ -43,7 +43,32 @@ def object_to_item(object: Any) -> Item: # correct type. If not, they throw a ItemTypeError exception. 
return cls.factory(object) - raise NotImplementedError(f"Type '{object.__class__}' is not supported.") + return PickleItem(object) + + +def item_to_object(item: Item) -> Any: + if isinstance(item, PrimitiveItem): + return item.primitive + elif isinstance(item, NumpyArrayItem): + return item.array + elif isinstance(item, PandasDataFrameItem): + return item.dataframe + elif isinstance(item, PandasSeriesItem): + return item.series + elif isinstance(item, PolarsDataFrameItem): + return item.dataframe + elif isinstance(item, PolarsSeriesItem): + return item.series + elif isinstance(item, SklearnBaseEstimatorItem): + return item.estimator + elif isinstance(item, CrossValidationItem): + return item.cv_results_serialized + elif isinstance(item, MediaItem): + return item.media_bytes + elif isinstance(item, PickleItem): + return repr(item.pickle_bytes) + else: + raise ValueError(f"Item {item} is not a known item type.") __all__ = [ diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py index 289b04d10..299409425 100644 --- a/skore/src/skore/project/project.py +++ b/skore/src/skore/project/project.py @@ -1,28 +1,18 @@ """Define a Project.""" -import logging -from typing import Any, Optional, Union - -from skore.item import ( - CrossValidationItem, - Item, - ItemRepository, - MediaItem, - NumpyArrayItem, - PandasDataFrameItem, - PandasSeriesItem, - PolarsDataFrameItem, - PolarsSeriesItem, - PrimitiveItem, - SklearnBaseEstimatorItem, - object_to_item, -) -from skore.view.view import View -from skore.view.view_repository import ViewRepository - -logger = logging.getLogger(__name__) -logger.addHandler(logging.NullHandler()) # Default to no output -logger.setLevel(logging.INFO) +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional, Union + +from skore.persistence import item_to_object, object_to_item + +if TYPE_CHECKING: + from skore.persistence import ( + Item, + ItemRepository, + View, + ViewRepository, + ) MISSING = 
object() @@ -128,28 +118,21 @@ def put(self, key: Union[str, dict[str, Any]], value: Optional[Any] = MISSING): If the value type is not supported. """ if value is not MISSING: - key_to_item = {key: value} + key_to_value = {key: value} elif isinstance(key, dict): - key_to_item = key + key_to_value = key else: raise TypeError( f"Bad parameters. " f"When value is not specified, key must be a dict (found '{type(key)}')" ) - for key, value in key_to_item.items(): + for key, value in key_to_value.items(): if not isinstance(key, str): raise TypeError(f"Key must be a string (found '{type(key)}')") self.item_repository.put_item(key, object_to_item(value)) - def put_item(self, key: str, item: Item): - """Add an Item to the Project.""" - if not isinstance(key, str): - raise TypeError(f"Key must be a string (found '{type(key)}')") - - self.item_repository.put_item(key, item) - def get(self, key: str) -> Any: """Get the value corresponding to ``key`` from the Project. @@ -163,48 +146,7 @@ def get(self, key: str) -> Any: KeyError If the key does not correspond to any item. """ - item = self.get_item(key) - - if isinstance(item, PrimitiveItem): - return item.primitive - elif isinstance(item, NumpyArrayItem): - return item.array - elif isinstance(item, PandasDataFrameItem): - return item.dataframe - elif isinstance(item, PandasSeriesItem): - return item.series - elif isinstance(item, PolarsDataFrameItem): - return item.dataframe - elif isinstance(item, PolarsSeriesItem): - return item.series - elif isinstance(item, SklearnBaseEstimatorItem): - return item.estimator - elif isinstance(item, CrossValidationItem): - return item.cv_results_serialized - elif isinstance(item, MediaItem): - return item.media_bytes - else: - raise ValueError(f"Item {item} is not a known item type.") - - def get_item(self, key: str) -> Item: - """Get the item corresponding to ``key`` from the Project. - - Parameters - ---------- - key : str - The key corresponding to the item to get. 
- - Returns - ------- - item : Item - The Item corresponding to ``key``. - - Raises - ------ - KeyError - If the key does not correspond to any item. - """ - return self.item_repository.get_item(key) + return item_to_object(self.item_repository.get_item(key)) def get_item_versions(self, key: str) -> list[Item]: """ From 0277edb1f87125c48e4cb33b24f98d0274a29fba Mon Sep 17 00:00:00 2001 From: Thomas S Date: Tue, 7 Jan 2025 16:40:14 +0100 Subject: [PATCH 04/22] [skip ci] Add `PickleItem` --- skore/src/skore/persistence/item/__init__.py | 37 +++++++++---------- .../src/skore/persistence/item/pickle_item.py | 25 +++++++++++++ 2 files changed, 43 insertions(+), 19 deletions(-) create mode 100644 skore/src/skore/persistence/item/pickle_item.py diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py index 0735647f5..57e3b4c10 100644 --- a/skore/src/skore/persistence/item/__init__.py +++ b/skore/src/skore/persistence/item/__init__.py @@ -5,18 +5,19 @@ from contextlib import suppress from typing import Any -from skore.item import skrub_table_report_item as SkrubTableReportItem -from skore.item.cross_validation_item import CrossValidationItem -from skore.item.item import Item, ItemTypeError -from skore.item.item_repository import ItemRepository -from skore.item.media_item import MediaItem -from skore.item.numpy_array_item import NumpyArrayItem -from skore.item.pandas_dataframe_item import PandasDataFrameItem -from skore.item.pandas_series_item import PandasSeriesItem -from skore.item.polars_dataframe_item import PolarsDataFrameItem -from skore.item.polars_series_item import PolarsSeriesItem -from skore.item.primitive_item import PrimitiveItem -from skore.item.sklearn_base_estimator_item import SklearnBaseEstimatorItem +from .cross_validation_item import CrossValidationItem +from .item import Item, ItemTypeError +from .item_repository import ItemRepository +from .media_item import MediaItem +from .numpy_array_item import 
NumpyArrayItem +from .pandas_dataframe_item import PandasDataFrameItem +from .pandas_series_item import PandasSeriesItem +from .pickle_item import PickleItem +from .polars_dataframe_item import PolarsDataFrameItem +from .polars_series_item import PolarsSeriesItem +from .primitive_item import PrimitiveItem +from .sklearn_base_estimator_item import SklearnBaseEstimatorItem +from .skrub_table_report_item import SkrubTableReportItem def object_to_item(object: Any) -> Item: @@ -51,13 +52,9 @@ def item_to_object(item: Item) -> Any: return item.primitive elif isinstance(item, NumpyArrayItem): return item.array - elif isinstance(item, PandasDataFrameItem): + elif isinstance(item, PandasDataFrameItem) or isinstance(item, PolarsDataFrameItem): return item.dataframe - elif isinstance(item, PandasSeriesItem): - return item.series - elif isinstance(item, PolarsDataFrameItem): - return item.dataframe - elif isinstance(item, PolarsSeriesItem): + elif isinstance(item, PandasSeriesItem) or isinstance(item, PolarsSeriesItem): return item.series elif isinstance(item, SklearnBaseEstimatorItem): return item.estimator @@ -66,7 +63,7 @@ def item_to_object(item: Item) -> Any: elif isinstance(item, MediaItem): return item.media_bytes elif isinstance(item, PickleItem): - return repr(item.pickle_bytes) + return item.object else: raise ValueError(f"Item {item} is not a known item type.") @@ -79,10 +76,12 @@ def item_to_object(item: Item) -> Any: "NumpyArrayItem", "PandasDataFrameItem", "PandasSeriesItem", + "PickleItem", "PolarsDataFrameItem", "PolarsSeriesItem", "PrimitiveItem", "SklearnBaseEstimatorItem", "SkrubTableReportItem", + "item_to_object", "object_to_item", ] diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py new file mode 100644 index 000000000..ce9019488 --- /dev/null +++ b/skore/src/skore/persistence/item/pickle_item.py @@ -0,0 +1,25 @@ +from functools import cached_property +from pickle import dumps, loads +from typing 
import Any + +from skore.item.item import Item + + +class PickleItem(Item): + def __init__( + self, + pickle_bytes: bytes, + created_at: str | None = None, + updated_at: str | None = None, + ): + super().__init__(created_at, updated_at) + + self.pickle_bytes = pickle_bytes + + @cached_property + def object(self) -> Any: + return loads(self.pickle_bytes) + + @classmethod + def factory(cls, object: Any) -> PickleItem: + return cls(dumps(object)) From f4552e97f40e02af2adec7e3ec0a69eedd3c67c3 Mon Sep 17 00:00:00 2001 From: Thomas S Date: Tue, 7 Jan 2025 16:52:33 +0100 Subject: [PATCH 05/22] [skip ci] Fix import of `SkrubTableReportItem` --- skore/src/skore/persistence/item/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py index 57e3b4c10..770ce5979 100644 --- a/skore/src/skore/persistence/item/__init__.py +++ b/skore/src/skore/persistence/item/__init__.py @@ -5,6 +5,8 @@ from contextlib import suppress from typing import Any +import skrub_table_report_item as SkrubTableReportItem + from .cross_validation_item import CrossValidationItem from .item import Item, ItemTypeError from .item_repository import ItemRepository @@ -17,7 +19,6 @@ from .polars_series_item import PolarsSeriesItem from .primitive_item import PrimitiveItem from .sklearn_base_estimator_item import SklearnBaseEstimatorItem -from .skrub_table_report_item import SkrubTableReportItem def object_to_item(object: Any) -> Item: From aff2a5f0c8db2edfa9cfcd58f7f7cb4d286ef425 Mon Sep 17 00:00:00 2001 From: Thomas S Date: Wed, 8 Jan 2025 12:05:51 +0100 Subject: [PATCH 06/22] Fix imports and tests Co-authored-by: Auguste Baum commit 86281ca475f59c2fa43760b243c366c23c38b787 Author: Auguste Baum Date: Wed Jan 8 12:00:46 2025 +0100 mob next [ci-skip] [ci skip] [skip ci] lastFile:skore/tests/integration/ui/test_ui.py commit 9f6f1931584de22ab08a9e4a9779a5f17dee4501 Author: Thomas S Date: Wed Jan 8 
11:52:12 2025 +0100 mob next [ci-skip] [ci skip] [skip ci] lastFile:skore/tests/integration/ui/test_ui.py commit cd385f66a805390893a6e415761ab1c0ea65550e Author: Auguste Baum Date: Wed Jan 8 11:41:10 2025 +0100 mob next [ci-skip] [ci skip] [skip ci] lastFile:skore/src/skore/persistence/item/sklearn_base_estimator_item.py commit 41c53559266f139c785211fc97a43e1e2b281203 Author: Thomas S Date: Wed Jan 8 11:27:16 2025 +0100 mob next [ci-skip] [ci skip] [skip ci] lastFile:skore/tests/integration/sklearn/test_cross_validate.py commit a7d956edfeab6ef56020aaae8752e59f7f4e3491 Author: Auguste Baum Date: Wed Jan 8 11:14:59 2025 +0100 mob next [ci-skip] [ci skip] [skip ci] lastFile:skore/tests/unit/view/test_view_repository.py commit 38446cf8499b2d3bcf79d371b00e374a916dd77a Author: Thomas S Date: Wed Jan 8 11:01:43 2025 +0100 mob next [ci-skip] [ci skip] [skip ci] lastFile:skore/src/skore/persistence/item/pickle_item.py commit a6e2a7914ada1e003c5295ccf233c73238972dfc Author: Thomas S Date: Wed Jan 8 10:51:07 2025 +0100 mob start [ci-skip] [ci skip] [skip ci] Co-authored-by: Auguste Baum --- skore/src/skore/persistence/item/__init__.py | 5 +- .../persistence/item/cross_validation_item.py | 3 +- .../src/skore/persistence/item/media_item.py | 2 +- .../persistence/item/numpy_array_item.py | 2 +- .../persistence/item/pandas_dataframe_item.py | 2 +- .../persistence/item/pandas_series_item.py | 2 +- .../src/skore/persistence/item/pickle_item.py | 4 +- .../persistence/item/polars_dataframe_item.py | 2 +- .../persistence/item/polars_series_item.py | 2 +- .../skore/persistence/item/primitive_item.py | 2 +- .../item/sklearn_base_estimator_item.py | 7 ++- .../item/skrub_table_report_item.py | 4 +- .../skore/persistence/repository/__init__.py | 7 +++ .../persistence/repository/item_repository.py | 27 +++++----- .../src/skore/persistence/storage/__init__.py | 9 ++++ skore/src/skore/project/create.py | 2 +- skore/src/skore/project/load.py | 8 +-- skore/src/skore/project/project.py | 7 ++- 
skore/src/skore/ui/project_routes.py | 4 +- skore/tests/conftest.py | 5 +- .../sklearn/test_cross_validate.py | 4 +- skore/tests/integration/ui/test_ui.py | 51 ++++++++++++------- .../unit/item/test_cross_validation_item.py | 8 +-- skore/tests/unit/item/test_item_repository.py | 3 +- skore/tests/unit/item/test_media_item.py | 4 +- .../tests/unit/item/test_numpy_array_item.py | 4 +- .../unit/item/test_pandas_dataframe_item.py | 4 +- .../unit/item/test_pandas_series_item.py | 4 +- .../unit/item/test_polars_dataframe_item.py | 6 +-- .../unit/item/test_polars_series_item.py | 4 +- skore/tests/unit/item/test_primitive_item.py | 6 +-- .../item/test_sklearn_base_estimator_item.py | 4 +- .../unit/item/test_skrub_table_report_item.py | 4 +- skore/tests/unit/persistence/test_disk.py | 2 +- skore/tests/unit/persistence/test_memory.py | 2 +- skore/tests/unit/test_project.py | 20 ++++---- skore/tests/unit/view/test_view_repository.py | 6 +-- 37 files changed, 142 insertions(+), 100 deletions(-) create mode 100644 skore/src/skore/persistence/storage/__init__.py diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py index 770ce5979..7b2d2f002 100644 --- a/skore/src/skore/persistence/item/__init__.py +++ b/skore/src/skore/persistence/item/__init__.py @@ -5,11 +5,9 @@ from contextlib import suppress from typing import Any -import skrub_table_report_item as SkrubTableReportItem - +from . 
import skrub_table_report_item as SkrubTableReportItem from .cross_validation_item import CrossValidationItem from .item import Item, ItemTypeError -from .item_repository import ItemRepository from .media_item import MediaItem from .numpy_array_item import NumpyArrayItem from .pandas_dataframe_item import PandasDataFrameItem @@ -72,7 +70,6 @@ def item_to_object(item: Item) -> Any: __all__ = [ "CrossValidationItem", "Item", - "ItemRepository", "MediaItem", "NumpyArrayItem", "PandasDataFrameItem", diff --git a/skore/src/skore/persistence/item/cross_validation_item.py b/skore/src/skore/persistence/item/cross_validation_item.py index e4e4c4a2a..1121117be 100644 --- a/skore/src/skore/persistence/item/cross_validation_item.py +++ b/skore/src/skore/persistence/item/cross_validation_item.py @@ -20,9 +20,10 @@ import plotly.graph_objects import plotly.io -from skore.item.item import Item, ItemTypeError from skore.sklearn.cross_validation import CrossValidationReporter +from .item import Item, ItemTypeError + if TYPE_CHECKING: import sklearn.base diff --git a/skore/src/skore/persistence/item/media_item.py b/skore/src/skore/persistence/item/media_item.py index b83de93e4..b0c957bd6 100644 --- a/skore/src/skore/persistence/item/media_item.py +++ b/skore/src/skore/persistence/item/media_item.py @@ -9,7 +9,7 @@ from io import BytesIO from typing import TYPE_CHECKING, Any -from skore.item.item import Item, ItemTypeError +from .item import Item, ItemTypeError if TYPE_CHECKING: from altair.vegalite.v5.schema.core import TopLevelSpec as Altair diff --git a/skore/src/skore/persistence/item/numpy_array_item.py b/skore/src/skore/persistence/item/numpy_array_item.py index 59d80bf66..5ebe0c82c 100644 --- a/skore/src/skore/persistence/item/numpy_array_item.py +++ b/skore/src/skore/persistence/item/numpy_array_item.py @@ -9,7 +9,7 @@ from json import dumps, loads from typing import TYPE_CHECKING -from skore.item.item import Item, ItemTypeError +from .item import Item, ItemTypeError if 
TYPE_CHECKING: import numpy diff --git a/skore/src/skore/persistence/item/pandas_dataframe_item.py b/skore/src/skore/persistence/item/pandas_dataframe_item.py index f61fef26c..3d124677e 100644 --- a/skore/src/skore/persistence/item/pandas_dataframe_item.py +++ b/skore/src/skore/persistence/item/pandas_dataframe_item.py @@ -9,7 +9,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from skore.item.item import Item, ItemTypeError +from .item import Item, ItemTypeError if TYPE_CHECKING: import pandas diff --git a/skore/src/skore/persistence/item/pandas_series_item.py b/skore/src/skore/persistence/item/pandas_series_item.py index cd27ebea3..bfc52457c 100644 --- a/skore/src/skore/persistence/item/pandas_series_item.py +++ b/skore/src/skore/persistence/item/pandas_series_item.py @@ -9,7 +9,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from skore.item.item import Item, ItemTypeError +from .item import Item, ItemTypeError if TYPE_CHECKING: import pandas diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py index ce9019488..42a0ace5d 100644 --- a/skore/src/skore/persistence/item/pickle_item.py +++ b/skore/src/skore/persistence/item/pickle_item.py @@ -1,8 +1,10 @@ +from __future__ import annotations + from functools import cached_property from pickle import dumps, loads from typing import Any -from skore.item.item import Item +from .item import Item class PickleItem(Item): diff --git a/skore/src/skore/persistence/item/polars_dataframe_item.py b/skore/src/skore/persistence/item/polars_dataframe_item.py index d42292dbe..dc6b05533 100644 --- a/skore/src/skore/persistence/item/polars_dataframe_item.py +++ b/skore/src/skore/persistence/item/polars_dataframe_item.py @@ -9,7 +9,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from skore.item.item import Item, ItemTypeError +from .item import Item, ItemTypeError if TYPE_CHECKING: import polars 
diff --git a/skore/src/skore/persistence/item/polars_series_item.py b/skore/src/skore/persistence/item/polars_series_item.py index 7d846f9a9..ba0dd3f8c 100644 --- a/skore/src/skore/persistence/item/polars_series_item.py +++ b/skore/src/skore/persistence/item/polars_series_item.py @@ -9,7 +9,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from skore.item.item import Item, ItemTypeError +from .item import Item, ItemTypeError if TYPE_CHECKING: import polars diff --git a/skore/src/skore/persistence/item/primitive_item.py b/skore/src/skore/persistence/item/primitive_item.py index 3e7d1e1a5..039032312 100644 --- a/skore/src/skore/persistence/item/primitive_item.py +++ b/skore/src/skore/persistence/item/primitive_item.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING -from skore.item.item import Item, ItemTypeError +from .item import Item, ItemTypeError if TYPE_CHECKING: from typing import Union diff --git a/skore/src/skore/persistence/item/sklearn_base_estimator_item.py b/skore/src/skore/persistence/item/sklearn_base_estimator_item.py index d40b61bba..40321970e 100644 --- a/skore/src/skore/persistence/item/sklearn_base_estimator_item.py +++ b/skore/src/skore/persistence/item/sklearn_base_estimator_item.py @@ -9,7 +9,7 @@ from functools import cached_property from typing import TYPE_CHECKING -from skore.item.item import Item, ItemTypeError +from .item import Item, ItemTypeError if TYPE_CHECKING: import sklearn.base @@ -101,11 +101,14 @@ def factory(cls, estimator: sklearn.base.BaseEstimator) -> SklearnBaseEstimatorI """ import sklearn.base import sklearn.utils - import skops.io if not isinstance(estimator, sklearn.base.BaseEstimator): raise ItemTypeError(f"Type '{estimator.__class__}' is not supported.") + # This line is only needed if we know `estimator` has the right type, so we do + # it after the type check + import skops.io + estimator_html_repr = sklearn.utils.estimator_html_repr(estimator) estimator_skops = skops.io.dumps(estimator) 
estimator_skops_untrusted_types = skops.io.get_untrusted_types( diff --git a/skore/src/skore/persistence/item/skrub_table_report_item.py b/skore/src/skore/persistence/item/skrub_table_report_item.py index b411256f2..07eae5bbd 100644 --- a/skore/src/skore/persistence/item/skrub_table_report_item.py +++ b/skore/src/skore/persistence/item/skrub_table_report_item.py @@ -4,8 +4,8 @@ from typing import TYPE_CHECKING -from skore.item.item import ItemTypeError -from skore.item.media_item import MediaItem +from .item import ItemTypeError +from .media_item import MediaItem if TYPE_CHECKING: from skrub import TableReport diff --git a/skore/src/skore/persistence/repository/__init__.py b/skore/src/skore/persistence/repository/__init__.py index e69de29bb..c088fdacf 100644 --- a/skore/src/skore/persistence/repository/__init__.py +++ b/skore/src/skore/persistence/repository/__init__.py @@ -0,0 +1,7 @@ +from .item_repository import ItemRepository +from .view_repository import ViewRepository + +__all__ = [ + "ItemRepository", + "ViewRepository", +] diff --git a/skore/src/skore/persistence/repository/item_repository.py b/skore/src/skore/persistence/repository/item_repository.py index 74210964f..aea849f80 100644 --- a/skore/src/skore/persistence/repository/item_repository.py +++ b/skore/src/skore/persistence/repository/item_repository.py @@ -8,20 +8,21 @@ from typing import TYPE_CHECKING +from skore.persistence.item import ( + CrossValidationItem, + MediaItem, + NumpyArrayItem, + PandasDataFrameItem, + PandasSeriesItem, + PolarsDataFrameItem, + PolarsSeriesItem, + PrimitiveItem, + SklearnBaseEstimatorItem, +) + if TYPE_CHECKING: - from skore.item.item import Item - from skore.persistence.abstract_storage import AbstractStorage - - -from skore.item.cross_validation_item import CrossValidationItem -from skore.item.media_item import MediaItem -from skore.item.numpy_array_item import NumpyArrayItem -from skore.item.pandas_dataframe_item import PandasDataFrameItem -from 
skore.item.pandas_series_item import PandasSeriesItem -from skore.item.polars_dataframe_item import PolarsDataFrameItem -from skore.item.polars_series_item import PolarsSeriesItem -from skore.item.primitive_item import PrimitiveItem -from skore.item.sklearn_base_estimator_item import SklearnBaseEstimatorItem + from skore.persistence.item import Item + from skore.persistence.storage import AbstractStorage class ItemRepository: diff --git a/skore/src/skore/persistence/storage/__init__.py b/skore/src/skore/persistence/storage/__init__.py new file mode 100644 index 000000000..21b6a3b1a --- /dev/null +++ b/skore/src/skore/persistence/storage/__init__.py @@ -0,0 +1,9 @@ +from .abstract_storage import AbstractStorage +from .disk_cache_storage import DiskCacheStorage +from .in_memory_storage import InMemoryStorage + +__all__ = [ + "AbstractStorage", + "DiskCacheStorage", + "InMemoryStorage", +] diff --git a/skore/src/skore/project/create.py b/skore/src/skore/project/create.py index bee8449d5..fcffce9d0 100644 --- a/skore/src/skore/project/create.py +++ b/skore/src/skore/project/create.py @@ -11,10 +11,10 @@ ProjectCreationError, ProjectPermissionError, ) +from skore.persistence.view.view import View from skore.project.load import load from skore.project.project import Project, logger from skore.utils._logger import logger_context -from skore.view.view import View def _validate_project_name(project_name: str) -> tuple[bool, Optional[Exception]]: diff --git a/skore/src/skore/project/load.py b/skore/src/skore/project/load.py index c78948690..0e8e7f064 100644 --- a/skore/src/skore/project/load.py +++ b/skore/src/skore/project/load.py @@ -3,10 +3,12 @@ from pathlib import Path from typing import Union -from skore.item import ItemRepository -from skore.persistence.disk_cache_storage import DirectoryDoesNotExist, DiskCacheStorage +from skore.persistence.repository import ItemRepository, ViewRepository +from skore.persistence.storage.disk_cache_storage import ( + 
DirectoryDoesNotExist, + DiskCacheStorage, +) from skore.project.project import Project -from skore.view.view_repository import ViewRepository class ProjectLoadError(Exception): diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py index 299409425..3cca337f2 100644 --- a/skore/src/skore/project/project.py +++ b/skore/src/skore/project/project.py @@ -2,9 +2,10 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Any, Optional, Union -from skore.persistence import item_to_object, object_to_item +from skore.persistence.item import item_to_object, object_to_item if TYPE_CHECKING: from skore.persistence import ( @@ -15,6 +16,10 @@ ) +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) # Default to no output +logger.setLevel(logging.INFO) + MISSING = object() diff --git a/skore/src/skore/ui/project_routes.py b/skore/src/skore/ui/project_routes.py index 040e5dbca..79eff5ce6 100644 --- a/skore/src/skore/ui/project_routes.py +++ b/skore/src/skore/ui/project_routes.py @@ -9,9 +9,9 @@ from fastapi import APIRouter, HTTPException, Request, status -from skore.item import Item +from skore.persistence.item import Item +from skore.persistence.view.view import Layout, View from skore.project import Project -from skore.view.view import Layout, View router = APIRouter(prefix="/project") diff --git a/skore/tests/conftest.py b/skore/tests/conftest.py index 6b2e3e0ad..61642fb5b 100644 --- a/skore/tests/conftest.py +++ b/skore/tests/conftest.py @@ -1,10 +1,9 @@ from datetime import datetime, timezone import pytest -from skore.item.item_repository import ItemRepository -from skore.persistence.in_memory_storage import InMemoryStorage +from skore.persistence.repository import ItemRepository, ViewRepository +from skore.persistence.storage import InMemoryStorage from skore.project import Project -from skore.view.view_repository import ViewRepository @pytest.fixture diff --git 
a/skore/tests/integration/sklearn/test_cross_validate.py b/skore/tests/integration/sklearn/test_cross_validate.py index 8c8992728..e39f6ba58 100644 --- a/skore/tests/integration/sklearn/test_cross_validate.py +++ b/skore/tests/integration/sklearn/test_cross_validate.py @@ -10,7 +10,7 @@ from sklearn.multiclass import OneVsOneClassifier from sklearn.svm import SVC from skore import CrossValidationReporter -from skore.item.cross_validation_item import CrossValidationItem +from skore.persistence.item.cross_validation_item import CrossValidationItem from skore.sklearn.cross_validation.cross_validation_helpers import _get_scorers_to_add @@ -200,7 +200,7 @@ def test_cross_validation_reporter(in_memory_project, fixture_name, request): in_memory_project.put("cross-validation", reporter) - retrieved_item = in_memory_project.get_item("cross-validation") + retrieved_item = in_memory_project.item_repository.get_item("cross-validation") assert isinstance(retrieved_item, CrossValidationItem) diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py index f467ff699..f5f984e26 100644 --- a/skore/tests/integration/ui/test_ui.py +++ b/skore/tests/integration/ui/test_ui.py @@ -1,7 +1,9 @@ import datetime +import json import numpy import pandas +import plotly import polars import pytest from fastapi.testclient import TestClient @@ -9,9 +11,8 @@ from sklearn.linear_model import Lasso from sklearn.model_selection import KFold from skore import CrossValidationReporter -from skore.item.media_item import MediaItem +from skore.persistence.view.view import View from skore.ui.app import create_app -from skore.view.view import View @pytest.fixture @@ -136,16 +137,12 @@ def test_serialize_media_item(client, in_memory_project): html = "

éપUœALDXIWDŸΩΩ

" in_memory_project.put("html", html) - in_memory_project.put_item( - "media html", MediaItem.factory_str(html, media_type="text/html") - ) response = client.get("/api/project/items") assert response.status_code == 200 project = response.json() assert "image" in project["items"]["img"][0]["media_type"] assert project["items"]["html"][0]["value"] == html - assert project["items"]["media html"][0]["value"] == html @pytest.fixture @@ -178,9 +175,21 @@ def test_serialize_cross_validation_item( mock_nowstr, fake_cross_validate, ): - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) + monkeypatch.setattr( + "skore.persistence.item.cross_validation_item.CrossValidationItem.plots", {} + ) + monkeypatch.setattr( + "skore.sklearn.cross_validation.cross_validation_reporter.plot_cross_validation_compare_scores", + lambda _: {}, + ) + monkeypatch.setattr( + "skore.sklearn.cross_validation.cross_validation_reporter.plot_cross_validation_timing_normalized", + lambda _: {}, + ) monkeypatch.setattr( - "skore.item.cross_validation_item.CrossValidationItem.plots", {} + "skore.sklearn.cross_validation.cross_validation_reporter.plot_cross_validation_timing", + lambda _: {}, ) def prepare_cv(): @@ -196,17 +205,12 @@ def prepare_cv(): reporter = CrossValidationReporter(model, X, y, cv=KFold(3)) in_memory_project.put("cv", reporter) - # Mock the item to make the plot empty - item = in_memory_project.get_item("cv") - item.plots_bytes = {"compare_scores": b"{}"} - in_memory_project.put_item("cv_mocked", item) - response = client.get("/api/project/items") assert response.status_code == 200 project = response.json() expected = { - "name": "cv_mocked", + "name": "cv", "media_type": "application/vnd.skore.cross_validation+json", "value": { "scalar_results": [ @@ -227,7 +231,20 @@ def prepare_cv(): ], } ], - "plots": [{"name": "compare_scores", "value": {}}], + "plots": [ + { + "name": "Scores", + "value": 
json.loads(plotly.io.to_json({}, engine="json")), + }, + { + "name": "Timings", + "value": json.loads(plotly.io.to_json({}, engine="json")), + }, + { + "name": "Normalized timings", + "value": json.loads(plotly.io.to_json({}, engine="json")), + }, + ], "sections": [ { "title": "Model", @@ -266,7 +283,7 @@ def prepare_cv(): "updated_at": mock_nowstr, "created_at": mock_nowstr, } - actual = project["items"]["cv_mocked"][0] + actual = project["items"]["cv"][0] assert expected == actual @@ -282,7 +299,7 @@ def now(*args, **kwargs): MockDatetime.NOW += MockDatetime.TIMEDELTA return MockDatetime.NOW - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) for i in range(5): in_memory_project.put(str(i), i) diff --git a/skore/tests/unit/item/test_cross_validation_item.py b/skore/tests/unit/item/test_cross_validation_item.py index 6ee82a3a7..af691fc06 100644 --- a/skore/tests/unit/item/test_cross_validation_item.py +++ b/skore/tests/unit/item/test_cross_validation_item.py @@ -4,9 +4,9 @@ import plotly.graph_objects import pytest from sklearn.model_selection import StratifiedKFold -from skore.item.cross_validation_item import ( +from skore.persistence.item import ItemTypeError +from skore.persistence.item.cross_validation_item import ( CrossValidationItem, - ItemTypeError, _hash_numpy, ) from skore.sklearn.cross_validation import CrossValidationReporter @@ -67,7 +67,7 @@ class FakeCrossValidationReporterNoGetParams(CrossValidationReporter): class TestCrossValidationItem: @pytest.fixture(autouse=True) def monkeypatch_datetime(self, monkeypatch, MockDatetime): - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) def test_factory_exception(self): with pytest.raises(ItemTypeError): @@ -111,7 +111,7 @@ def test_factory(self, mock_nowstr, reporter): def test_get_serializable_dict(self, monkeypatch, mock_nowstr): 
monkeypatch.setattr( - "skore.item.cross_validation_item.CrossValidationReporter", + "skore.persistence.item.cross_validation_item.CrossValidationReporter", FakeCrossValidationReporter, ) diff --git a/skore/tests/unit/item/test_item_repository.py b/skore/tests/unit/item/test_item_repository.py index 79b93548d..b2672bd30 100644 --- a/skore/tests/unit/item/test_item_repository.py +++ b/skore/tests/unit/item/test_item_repository.py @@ -1,7 +1,8 @@ from datetime import datetime, timezone import pytest -from skore.item import ItemRepository, MediaItem +from skore.persistence.item import MediaItem +from skore.persistence.repository import ItemRepository class TestItemRepository: diff --git a/skore/tests/unit/item/test_media_item.py b/skore/tests/unit/item/test_media_item.py index 97d4b579e..024f2fb5a 100644 --- a/skore/tests/unit/item/test_media_item.py +++ b/skore/tests/unit/item/test_media_item.py @@ -5,13 +5,13 @@ import PIL as pillow import plotly.graph_objects as go import pytest -from skore.item import ItemTypeError, MediaItem +from skore.persistence.item import ItemTypeError, MediaItem class TestMediaItem: @pytest.fixture(autouse=True) def monkeypatch_datetime(self, monkeypatch, MockDatetime): - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) def test_factory_exception(self): with pytest.raises(ItemTypeError): diff --git a/skore/tests/unit/item/test_numpy_array_item.py b/skore/tests/unit/item/test_numpy_array_item.py index a793cfed1..ed1c29317 100644 --- a/skore/tests/unit/item/test_numpy_array_item.py +++ b/skore/tests/unit/item/test_numpy_array_item.py @@ -2,13 +2,13 @@ import numpy import pytest -from skore.item import ItemTypeError, NumpyArrayItem +from skore.persistence.item import ItemTypeError, NumpyArrayItem class TestNumpyArrayItem: @pytest.fixture(autouse=True) def monkeypatch_datetime(self, monkeypatch, MockDatetime): - 
monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) def test_factory_exception(self): with pytest.raises(ItemTypeError): diff --git a/skore/tests/unit/item/test_pandas_dataframe_item.py b/skore/tests/unit/item/test_pandas_dataframe_item.py index da9b8f80a..2b6b0edf9 100644 --- a/skore/tests/unit/item/test_pandas_dataframe_item.py +++ b/skore/tests/unit/item/test_pandas_dataframe_item.py @@ -2,13 +2,13 @@ import pytest from pandas import DataFrame, Index, MultiIndex from pandas.testing import assert_frame_equal -from skore.item import ItemTypeError, PandasDataFrameItem +from skore.persistence.item import ItemTypeError, PandasDataFrameItem class TestPandasDataFrameItem: @pytest.fixture(autouse=True) def monkeypatch_datetime(self, monkeypatch, MockDatetime): - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) def test_factory_exception(self): with pytest.raises(ItemTypeError): diff --git a/skore/tests/unit/item/test_pandas_series_item.py b/skore/tests/unit/item/test_pandas_series_item.py index eb160f765..4a2e396ac 100644 --- a/skore/tests/unit/item/test_pandas_series_item.py +++ b/skore/tests/unit/item/test_pandas_series_item.py @@ -2,13 +2,13 @@ import pytest from pandas import Index, MultiIndex, Series from pandas.testing import assert_series_equal -from skore.item import ItemTypeError, PandasSeriesItem +from skore.persistence.item import ItemTypeError, PandasSeriesItem class TestPandasSeriesItem: @pytest.fixture(autouse=True) def monkeypatch_datetime(self, monkeypatch, MockDatetime): - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) def test_factory_exception(self): with pytest.raises(ItemTypeError): diff --git a/skore/tests/unit/item/test_polars_dataframe_item.py 
b/skore/tests/unit/item/test_polars_dataframe_item.py index 8be5250fc..0335e7d06 100644 --- a/skore/tests/unit/item/test_polars_dataframe_item.py +++ b/skore/tests/unit/item/test_polars_dataframe_item.py @@ -2,14 +2,14 @@ import pytest from polars import DataFrame from polars.testing import assert_frame_equal -from skore.item import ItemTypeError, PolarsDataFrameItem -from skore.item.polars_dataframe_item import PolarsToJSONError +from skore.persistence.item import ItemTypeError, PolarsDataFrameItem +from skore.persistence.item.polars_dataframe_item import PolarsToJSONError class TestPolarsDataFrameItem: @pytest.fixture(autouse=True) def monkeypatch_datetime(self, monkeypatch, MockDatetime): - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) def test_factory_exception(self): with pytest.raises(ItemTypeError): diff --git a/skore/tests/unit/item/test_polars_series_item.py b/skore/tests/unit/item/test_polars_series_item.py index 8ca235c41..1e40a0c3f 100644 --- a/skore/tests/unit/item/test_polars_series_item.py +++ b/skore/tests/unit/item/test_polars_series_item.py @@ -2,13 +2,13 @@ import pytest from polars import Series from polars.testing import assert_series_equal -from skore.item import ItemTypeError, PolarsSeriesItem +from skore.persistence.item import ItemTypeError, PolarsSeriesItem class TestPolarsSeriesItem: @pytest.fixture(autouse=True) def monkeypatch_datetime(self, monkeypatch, MockDatetime): - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) def test_factory_exception(self): with pytest.raises(ItemTypeError): diff --git a/skore/tests/unit/item/test_primitive_item.py b/skore/tests/unit/item/test_primitive_item.py index 2a97f6eff..babcdf2b8 100644 --- a/skore/tests/unit/item/test_primitive_item.py +++ b/skore/tests/unit/item/test_primitive_item.py @@ -1,5 +1,5 @@ import pytest -from 
skore.item import ItemTypeError, PrimitiveItem +from skore.persistence.item import ItemTypeError, PrimitiveItem class TestPrimitiveItem: @@ -16,7 +16,7 @@ class TestPrimitiveItem: ], ) def test_factory(self, monkeypatch, mock_nowstr, MockDatetime, primitive): - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) item = PrimitiveItem.factory(primitive) @@ -43,7 +43,7 @@ def test_factory_exception(self): def test_get_serializable_dict( self, monkeypatch, mock_nowstr, MockDatetime, primitive ): - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) item = PrimitiveItem.factory(primitive) serializable = item.as_serializable_dict() diff --git a/skore/tests/unit/item/test_sklearn_base_estimator_item.py b/skore/tests/unit/item/test_sklearn_base_estimator_item.py index 469a705bc..d0cbe93cf 100644 --- a/skore/tests/unit/item/test_sklearn_base_estimator_item.py +++ b/skore/tests/unit/item/test_sklearn_base_estimator_item.py @@ -1,7 +1,7 @@ import pytest import sklearn.svm import skops.io -from skore.item import ItemTypeError, SklearnBaseEstimatorItem +from skore.persistence.item import ItemTypeError, SklearnBaseEstimatorItem class Estimator(sklearn.svm.SVC): @@ -11,7 +11,7 @@ class Estimator(sklearn.svm.SVC): class TestSklearnBaseEstimatorItem: @pytest.fixture(autouse=True) def monkeypatch_datetime(self, monkeypatch, MockDatetime): - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) def test_factory_exception(self): with pytest.raises(ItemTypeError): diff --git a/skore/tests/unit/item/test_skrub_table_report_item.py b/skore/tests/unit/item/test_skrub_table_report_item.py index eef317096..d35236166 100644 --- a/skore/tests/unit/item/test_skrub_table_report_item.py +++ 
b/skore/tests/unit/item/test_skrub_table_report_item.py @@ -1,12 +1,12 @@ import pytest from pandas import DataFrame -from skore.item import ItemTypeError, SkrubTableReportItem +from skore.persistence.item import ItemTypeError, SkrubTableReportItem from skrub import TableReport class TestSkrubTableReportItem: def test_factory(self, monkeypatch, mock_nowstr, MockDatetime): - monkeypatch.setattr("skore.item.item.datetime", MockDatetime) + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) monkeypatch.setattr("secrets.token_hex", lambda: "azertyuiop") df = DataFrame(dict(a=[1, 2], b=["one", "two"], c=[11.1, 11.1])) diff --git a/skore/tests/unit/persistence/test_disk.py b/skore/tests/unit/persistence/test_disk.py index 77838978e..7c747c400 100644 --- a/skore/tests/unit/persistence/test_disk.py +++ b/skore/tests/unit/persistence/test_disk.py @@ -2,7 +2,7 @@ import shutil from pathlib import Path -from skore.persistence.disk_cache_storage import DiskCacheStorage +from skore.persistence.storage import DiskCacheStorage def test_disk_storage(tmp_path: Path): diff --git a/skore/tests/unit/persistence/test_memory.py b/skore/tests/unit/persistence/test_memory.py index f28dcc73d..53608496e 100644 --- a/skore/tests/unit/persistence/test_memory.py +++ b/skore/tests/unit/persistence/test_memory.py @@ -1,4 +1,4 @@ -from skore.persistence.in_memory_storage import InMemoryStorage +from skore.persistence.storage import InMemoryStorage def test_in_memory_storage(): diff --git a/skore/tests/unit/test_project.py b/skore/tests/unit/test_project.py index 88b256e54..662ed5e90 100644 --- a/skore/tests/unit/test_project.py +++ b/skore/tests/unit/test_project.py @@ -17,6 +17,7 @@ ProjectAlreadyExistsError, ProjectCreationError, ) +from skore.persistence.view.view import View from skore.project import ( Project, create, @@ -24,7 +25,6 @@ ) from skore.project.create import _validate_project_name from skore.project.load import ProjectLoadError -from skore.view.view import 
View def test_put_string_item(in_memory_project): @@ -262,17 +262,15 @@ def test_put_several_nested(in_memory_project): assert in_memory_project.get("a") == {"b": "baz"} -def test_put_several_error(in_memory_project): - """If some key-value pairs are wrong, add all that are valid and print a warning.""" - with pytest.raises(NotImplementedError): - in_memory_project.put( - { - "a": "foo", - "b": (lambda: "unsupported object"), - } - ) +def test_put_several_lambda(in_memory_project): + in_memory_project.put( + { + "a": "foo", + "b": (lambda: "unsupported object"), + } + ) - assert in_memory_project.list_item_keys() == ["a"] + assert in_memory_project.list_item_keys() == ["a", "b"] def test_put_key_is_a_tuple(in_memory_project): diff --git a/skore/tests/unit/view/test_view_repository.py b/skore/tests/unit/view/test_view_repository.py index ed46d271d..87c21cea1 100644 --- a/skore/tests/unit/view/test_view_repository.py +++ b/skore/tests/unit/view/test_view_repository.py @@ -1,7 +1,7 @@ import pytest -from skore.persistence.in_memory_storage import InMemoryStorage -from skore.view.view import View -from skore.view.view_repository import ViewRepository +from skore.persistence.repository import ViewRepository +from skore.persistence.storage import InMemoryStorage +from skore.persistence.view.view import View @pytest.fixture From 9a3698ca002365d992b07fceb2129ccbdd29644d Mon Sep 17 00:00:00 2001 From: Thomas S Date: Thu, 9 Jan 2025 10:48:45 +0100 Subject: [PATCH 07/22] Change the way a pickle item is exported to UI --- skore/src/skore/persistence/item/pickle_item.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py index 42a0ace5d..2f6d8ff55 100644 --- a/skore/src/skore/persistence/item/pickle_item.py +++ b/skore/src/skore/persistence/item/pickle_item.py @@ -25,3 +25,9 @@ def object(self) -> Any: @classmethod def factory(cls, object: Any) -> PickleItem: return 
cls(dumps(object)) + + def as_serializable_dict(self): + return super().as_serializable_dict() | { + "media_type": "text/markdown", + "value": repr(self.object), + } From 8643aa9b7d66e148441b235bb67f9ef4e7f40108 Mon Sep 17 00:00:00 2001 From: Thomas S Date: Fri, 10 Jan 2025 11:30:09 +0100 Subject: [PATCH 08/22] Add tests --- skore/src/skore/persistence/item/__init__.py | 6 +-- .../persistence/repository/item_repository.py | 26 +----------- skore/tests/integration/ui/test_ui.py | 29 ++++++++++++++ skore/tests/unit/item/test_pickle_item.py | 40 +++++++++++++++++++ skore/tests/unit/project/test_project.py | 9 +---- 5 files changed, 76 insertions(+), 34 deletions(-) create mode 100644 skore/tests/unit/item/test_pickle_item.py diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py index 7b2d2f002..79b0eb138 100644 --- a/skore/src/skore/persistence/item/__init__.py +++ b/skore/src/skore/persistence/item/__init__.py @@ -43,7 +43,7 @@ def object_to_item(object: Any) -> Item: # correct type. If not, they throw a ItemTypeError exception. 
return cls.factory(object) - return PickleItem(object) + return PickleItem.factory(object) def item_to_object(item: Item) -> Any: @@ -51,9 +51,9 @@ def item_to_object(item: Item) -> Any: return item.primitive elif isinstance(item, NumpyArrayItem): return item.array - elif isinstance(item, PandasDataFrameItem) or isinstance(item, PolarsDataFrameItem): + elif isinstance(item, (PandasDataFrameItem, PolarsDataFrameItem)): return item.dataframe - elif isinstance(item, PandasSeriesItem) or isinstance(item, PolarsSeriesItem): + elif isinstance(item, (PandasSeriesItem, PolarsSeriesItem)): return item.series elif isinstance(item, SklearnBaseEstimatorItem): return item.estimator diff --git a/skore/src/skore/persistence/repository/item_repository.py b/skore/src/skore/persistence/repository/item_repository.py index aea849f80..dd99258ef 100644 --- a/skore/src/skore/persistence/repository/item_repository.py +++ b/skore/src/skore/persistence/repository/item_repository.py @@ -8,17 +8,7 @@ from typing import TYPE_CHECKING -from skore.persistence.item import ( - CrossValidationItem, - MediaItem, - NumpyArrayItem, - PandasDataFrameItem, - PandasSeriesItem, - PolarsDataFrameItem, - PolarsSeriesItem, - PrimitiveItem, - SklearnBaseEstimatorItem, -) +import skore.persistence.item if TYPE_CHECKING: from skore.persistence.item import Item @@ -35,18 +25,6 @@ class ItemRepository: storage as a map from keys to *lists of* values. """ - ITEM_CLASS_NAME_TO_ITEM_CLASS = { - "MediaItem": MediaItem, - "NumpyArrayItem": NumpyArrayItem, - "PandasDataFrameItem": PandasDataFrameItem, - "PandasSeriesItem": PandasSeriesItem, - "PolarsDataFrameItem": PolarsDataFrameItem, - "PolarsSeriesItem": PolarsSeriesItem, - "PrimitiveItem": PrimitiveItem, - "CrossValidationItem": CrossValidationItem, - "SklearnBaseEstimatorItem": SklearnBaseEstimatorItem, - } - def __init__(self, storage: AbstractStorage): """ Initialize the ItemRepository with a storage system. 
@@ -68,7 +46,7 @@ def __deconstruct_item(item: Item) -> dict: @staticmethod def __construct_item(value) -> Item: item_class_name = value["item_class_name"] - item_class = ItemRepository.ITEM_CLASS_NAME_TO_ITEM_CLASS[item_class_name] + item_class = getattr(skore.persistence.item, item_class_name) item = value["item"] return item_class(**item) diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py index a4d641ee7..360761a6f 100644 --- a/skore/tests/integration/ui/test_ui.py +++ b/skore/tests/integration/ui/test_ui.py @@ -337,3 +337,32 @@ def now(*args, **kwargs): ("5", 5), ("4", 5), ] + + +def test_get_items_with_pickle_item( + monkeypatch, + MockDatetime, + mock_nowstr, + client, + in_memory_project, +): + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) + in_memory_project.put("pickle", object) + + response = client.get("/api/project/items") + + assert response.status_code == 200 + assert response.json() == { + "items": { + "pickle": [ + { + "created_at": mock_nowstr, + "updated_at": mock_nowstr, + "name": "pickle", + "media_type": "text/markdown", + "value": "", + }, + ], + }, + "views": {}, + } diff --git a/skore/tests/unit/item/test_pickle_item.py b/skore/tests/unit/item/test_pickle_item.py new file mode 100644 index 000000000..acbc3924e --- /dev/null +++ b/skore/tests/unit/item/test_pickle_item.py @@ -0,0 +1,40 @@ +import pickle + +import pytest +from skore.persistence.item import PickleItem + + +class TestPickleItem: + @pytest.fixture(autouse=True) + def monkeypatch_datetime(self, monkeypatch, MockDatetime): + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) + + @pytest.mark.parametrize("object", [0, 0.0, int, True, [0], {0: 0}]) + def test_factory(self, mock_nowstr, object): + item = PickleItem.factory(object) + + assert item.pickle_bytes == pickle.dumps(object) + assert item.created_at == mock_nowstr + assert item.updated_at == mock_nowstr + + def test_object(self, 
mock_nowstr): + item1 = PickleItem.factory(int) + item2 = PickleItem( + pickle_bytes=pickle.dumps(int), + created_at=mock_nowstr, + updated_at=mock_nowstr, + ) + + assert item1.object is int + assert item2.object is int + + def test_get_serializable_dict(self, mock_nowstr): + item = PickleItem.factory(int) + serializable = item.as_serializable_dict() + + assert serializable == { + "updated_at": mock_nowstr, + "created_at": mock_nowstr, + "media_type": "text/markdown", + "value": repr(int), + } diff --git a/skore/tests/unit/project/test_project.py b/skore/tests/unit/project/test_project.py index a1da72d71..13414b7e5 100644 --- a/skore/tests/unit/project/test_project.py +++ b/skore/tests/unit/project/test_project.py @@ -235,13 +235,8 @@ def test_put_several_nested(in_memory_project): assert in_memory_project.get("a") == {"b": "baz"} -def test_put_several_lambda(in_memory_project): - in_memory_project.put( - { - "a": "foo", - "b": (lambda: "unsupported object"), - } - ) +def test_put_several_complex(in_memory_project): + in_memory_project.put({"a": int, "b": float}) assert in_memory_project.list_item_keys() == ["a", "b"] From 80df8dbdf67fc488d7c90ca5ea2c0702ca535cda Mon Sep 17 00:00:00 2001 From: Thomas S Date: Fri, 10 Jan 2025 15:05:48 +0100 Subject: [PATCH 09/22] Add docstring --- skore/src/skore/persistence/item/__init__.py | 1 + .../src/skore/persistence/item/pickle_item.py | 41 +++++++++++++++++++ .../skore/persistence/repository/__init__.py | 2 + .../src/skore/persistence/storage/__init__.py | 2 + 4 files changed, 46 insertions(+) diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py index 79b0eb138..fca1e5c18 100644 --- a/skore/src/skore/persistence/item/__init__.py +++ b/skore/src/skore/persistence/item/__init__.py @@ -47,6 +47,7 @@ def object_to_item(object: Any) -> Item: def item_to_object(item: Item) -> Any: + """Transform an Item into its original object.""" if isinstance(item, PrimitiveItem): return 
item.primitive elif isinstance(item, NumpyArrayItem): diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py index 2f6d8ff55..30c627cde 100644 --- a/skore/src/skore/persistence/item/pickle_item.py +++ b/skore/src/skore/persistence/item/pickle_item.py @@ -1,3 +1,9 @@ +"""PickleItem. + +This module defines the PickleItem class, which is used to persist objects that cannot +be otherwise. +""" + from __future__ import annotations from functools import cached_property @@ -8,25 +14,60 @@ class PickleItem(Item): + """ + A class to represent any object item. + + This class is generally used to persist objects that cannot be otherwise. + It encapsulates the object with its pickle representaton, its creation and update + timestamps. + """ + def __init__( self, pickle_bytes: bytes, created_at: str | None = None, updated_at: str | None = None, ): + """ + Initialize a PickleItem. + + Parameters + ---------- + pickle_bytes : bytes + The raw bytes of the object pickle representation. + created_at : str + The creation timestamp in ISO format. + updated_at : str + The last update timestamp in ISO format. + """ super().__init__(created_at, updated_at) self.pickle_bytes = pickle_bytes @cached_property def object(self) -> Any: + """The object from the persistence.""" return loads(self.pickle_bytes) @classmethod def factory(cls, object: Any) -> PickleItem: + """ + Create a new PickleItem with any object. + + Parameters + ---------- + object: Any + The object to store. + + Returns + ------- + PickleItem + A new PickleItem instance. 
+ """ return cls(dumps(object)) def as_serializable_dict(self): + """Get a JSON serializable representation of the item.""" return super().as_serializable_dict() | { "media_type": "text/markdown", "value": repr(self.object), diff --git a/skore/src/skore/persistence/repository/__init__.py b/skore/src/skore/persistence/repository/__init__.py index c088fdacf..845bb362c 100644 --- a/skore/src/skore/persistence/repository/__init__.py +++ b/skore/src/skore/persistence/repository/__init__.py @@ -1,3 +1,5 @@ +"""Provide a set of classes responsible for manipulating items and views.""" + from .item_repository import ItemRepository from .view_repository import ViewRepository diff --git a/skore/src/skore/persistence/storage/__init__.py b/skore/src/skore/persistence/storage/__init__.py index 21b6a3b1a..75bb67428 100644 --- a/skore/src/skore/persistence/storage/__init__.py +++ b/skore/src/skore/persistence/storage/__init__.py @@ -1,3 +1,5 @@ +"""Provide a set of storage classes for storing and retrieving data.""" + from .abstract_storage import AbstractStorage from .disk_cache_storage import DiskCacheStorage from .in_memory_storage import InMemoryStorage From c8f3ecfa0b5dd7f26bf4e72954d61be319bf83e7 Mon Sep 17 00:00:00 2001 From: Thomas S Date: Fri, 10 Jan 2025 17:29:32 +0100 Subject: [PATCH 10/22] WIP - Change examples --- .../getting_started/plot_tracking_items.py | 2 +- .../plot_working_with_projects.py | 18 ++++++------------ .../model_evaluation/plot_cross_validate.py | 2 +- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/examples/getting_started/plot_tracking_items.py b/examples/getting_started/plot_tracking_items.py index 11bfe30aa..ec8e1a5fb 100644 --- a/examples/getting_started/plot_tracking_items.py +++ b/examples/getting_started/plot_tracking_items.py @@ -73,7 +73,7 @@ # We retrieve the history of the ``my_int`` item: # %% -item_histories = my_project.get_item_versions("my_int") +item_histories = my_project.get_item_versions("my_int") # TO CHANGE /!\ # 
%% # We can print the first history (first iteration) of this item: diff --git a/examples/getting_started/plot_working_with_projects.py b/examples/getting_started/plot_working_with_projects.py index 0b16b2074..3d672df1d 100644 --- a/examples/getting_started/plot_working_with_projects.py +++ b/examples/getting_started/plot_working_with_projects.py @@ -119,25 +119,19 @@ def my_func(x): ) # %% -# Moreover, we can also explicitly tell skore the media type of an object, for example -# in HTML: +# Moreover, we can also explicitly tell skore the way we want to display an object, for +# example in HTML: # %% -from skore.item import MediaItem -my_project.put_item( +my_project.put( "my_string_3", - MediaItem.factory( - "

Title

bold, italic, etc.

", media_type="text/html" - ), + "

Title

bold, italic, etc.

", + display_as="html", ) # %% -# .. note:: -# We used :func:`~skore.Project.put_item` instead of :func:`~skore.Project.put`. - -# %% -# Note that the media type is only used for the UI, and not in this notebook at hand: +# Note that the `display_as` is only used for the UI, and not in this notebook at hand: # %% my_project.get("my_string_3") diff --git a/examples/model_evaluation/plot_cross_validate.py b/examples/model_evaluation/plot_cross_validate.py index 9422e37b0..a47b6b74f 100644 --- a/examples/model_evaluation/plot_cross_validate.py +++ b/examples/model_evaluation/plot_cross_validate.py @@ -157,7 +157,7 @@ # %% # We can also access the plot after we have stored the ``CrossValidationReporter``: my_project.put("cross_validation_regression", reporter) -cv_item = my_project.get_item("cross_validation_regression") +cv_item = my_project.get_item("cross_validation_regression") # TO CHANGE /!\ cv_item.plots["Scores"] # %% From c4b2ee8cdac0bdf69ef814bd5fef4e3cf23b1c85 Mon Sep 17 00:00:00 2001 From: Thomas S Date: Mon, 13 Jan 2025 16:22:18 +0100 Subject: [PATCH 11/22] Rename CrossValidationItem to CrossValidationReporterItem --- skore/src/skore/persistence/item/__init__.py | 10 +++++----- ...ation_item.py => cross_validation_reporter_item.py} | 0 2 files changed, 5 insertions(+), 5 deletions(-) rename skore/src/skore/persistence/item/{cross_validation_item.py => cross_validation_reporter_item.py} (100%) diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py index fca1e5c18..bc3890ce5 100644 --- a/skore/src/skore/persistence/item/__init__.py +++ b/skore/src/skore/persistence/item/__init__.py @@ -6,7 +6,7 @@ from typing import Any from . 
import skrub_table_report_item as SkrubTableReportItem -from .cross_validation_item import CrossValidationItem +from .cross_validation_reporter_item import CrossValidationReporterItem from .item import Item, ItemTypeError from .media_item import MediaItem from .numpy_array_item import NumpyArrayItem @@ -31,7 +31,7 @@ def object_to_item(object: Any) -> Item: SklearnBaseEstimatorItem, MediaItem, SkrubTableReportItem, - CrossValidationItem, + CrossValidationReporterItem, ): with suppress(ImportError, ItemTypeError): # ImportError: @@ -58,8 +58,8 @@ def item_to_object(item: Item) -> Any: return item.series elif isinstance(item, SklearnBaseEstimatorItem): return item.estimator - elif isinstance(item, CrossValidationItem): - return item.cv_results_serialized + elif isinstance(item, CrossValidationReporterItem): + return item.reporter elif isinstance(item, MediaItem): return item.media_bytes elif isinstance(item, PickleItem): @@ -69,7 +69,7 @@ def item_to_object(item: Item) -> Any: __all__ = [ - "CrossValidationItem", + "CrossValidationReporterItem", "Item", "MediaItem", "NumpyArrayItem", diff --git a/skore/src/skore/persistence/item/cross_validation_item.py b/skore/src/skore/persistence/item/cross_validation_reporter_item.py similarity index 100% rename from skore/src/skore/persistence/item/cross_validation_item.py rename to skore/src/skore/persistence/item/cross_validation_reporter_item.py From 4d08a193321bbaa1e9d7a31e7868a76cbd24df5d Mon Sep 17 00:00:00 2001 From: Thomas S Date: Mon, 13 Jan 2025 17:06:18 +0100 Subject: [PATCH 12/22] CrossValidationReporterItem's factory now pickles the reporter --- .../item/cross_validation_reporter_item.py | 369 +++++++----------- .../sklearn/test_cross_validate.py | 4 +- skore/tests/integration/ui/test_ui.py | 5 +- ...=> test_cross_validation_reporter_item.py} | 58 ++- 4 files changed, 172 insertions(+), 264 deletions(-) rename skore/tests/unit/item/{test_cross_validation_item.py => test_cross_validation_reporter_item.py} (80%) diff 
--git a/skore/src/skore/persistence/item/cross_validation_reporter_item.py b/skore/src/skore/persistence/item/cross_validation_reporter_item.py index 23c81dab9..263e1528c 100644 --- a/skore/src/skore/persistence/item/cross_validation_reporter_item.py +++ b/skore/src/skore/persistence/item/cross_validation_reporter_item.py @@ -1,20 +1,21 @@ -"""CrossValidationItem class. +"""CrossValidationReporterItem. -This class represents the output of a cross-validation workflow. +This module defines the CrossValidationReporterItem class, which is used to persist +reporters of cross-validation. """ from __future__ import annotations import contextlib -import copy import dataclasses import hashlib import importlib import json +import pickle import re import statistics from functools import cached_property -from typing import TYPE_CHECKING, Any, Literal, TypedDict, Union +from typing import TYPE_CHECKING, Literal, Optional, TypedDict import numpy import plotly.graph_objects @@ -27,8 +28,6 @@ if TYPE_CHECKING: import sklearn.base - CVSplitter = Any - class EstimatorParamInfo(TypedDict): """Information about an estimator parameter.""" @@ -43,6 +42,12 @@ class EstimatorInfo(TypedDict): params: dict[str, EstimatorParamInfo] +HUMANIZED_PLOT_NAMES = { + "scores": "Scores", + "timing": "Timings", +} + + def _hash_numpy(arr: numpy.ndarray) -> str: """Compute a hash string from a numpy array. 
@@ -117,86 +122,122 @@ def _params_to_str(estimator_info) -> str: return "\n".join(params_list) -# Data used for training, passed as input to scikit-learn -Data = Any -# Target used for training, passed as input to scikit-learn -Target = Any +def _estimator_info(estimator: sklearn.base.BaseEstimator) -> EstimatorInfo: + estimator_params = ( + estimator.get_params() if hasattr(estimator, "get_params") else {} + ) + + name = estimator.__class__.__name__ + module = estimator.__module__ + + # Figure out the default parameters of the estimator, + # so that we can highlight the non-default ones in the UI + + # This is done by instantiating the class with no arguments and + # computing the diff between the default and ours + try: + estimator_module = importlib.import_module(module) + EstimatorClass = getattr(estimator_module, name) + default_estimator_params = EstimatorClass().get_params() + except Exception: + default_estimator_params = {} + + final_estimator_params: dict[str, EstimatorParamInfo] = {} + for k, v in estimator_params.items(): + param_is_default: bool = ( + k in default_estimator_params and default_estimator_params[k] == v + ) + final_estimator_params[str(k)] = { + "value": repr(v), + "default": param_is_default, + } + return { + "name": name, + "module": module, + "params": final_estimator_params, + } -class CrossValidationItem(Item): - """ - A class to represent the output of a cross-validation workflow. - This class encapsulates the output of the - :func:`sklearn.model_selection.cross_validate` function along with its creation and - update timestamps. 
- """ +class CrossValidationReporterItem(Item): + """Class to persist the reporter of cross-validation.""" def __init__( self, - cv_results_serialized: dict, - estimator_info: EstimatorInfo, - X_info: dict, - y_info: Union[dict, None], - plots_bytes: dict[str, bytes], - cv_info: dict, - created_at: Union[str, None] = None, - updated_at: Union[str, None] = None, + reporter_bytes: bytes, + created_at: Optional[str] = None, + updated_at: Optional[str] = None, ): """ - Initialize a CrossValidationItem. + Initialize a CrossValidationReporterItem. Parameters ---------- - cv_results_serialized : dict - The dict output of the :func:`sklearn.model_selection.cross_validate` - function, in a form suitable for serialization. - estimator_info : dict - The estimator that was cross-validated. - X_info : dict - A summary of the data, input of the - :func:`sklearn.model_selection.cross_validate` function. - y_info : dict - A summary of the target, input of the - :func:`sklearn.model_selection.cross_validate` function. - plots_bytes : dict[str, bytes] - A collection of plots of the cross-validation results, in the form of bytes. - cv_info: dict - A dict containing cross validation splitting strategy params. - created_at : str + reporter_bytes : bytes + The raw bytes of the reporter pickle representation. + created_at : str, optional The creation timestamp in ISO format. - updated_at : str + updated_at : str, optional The last update timestamp in ISO format. """ super().__init__(created_at, updated_at) - self.cv_results_serialized = cv_results_serialized - self.estimator_info = estimator_info - self.X_info = X_info - self.y_info = y_info - self.plots_bytes = plots_bytes - self.cv_info = cv_info + self.reporter_bytes = reporter_bytes - def as_serializable_dict(self): - """Get a serializable dict from the item. 
+ @classmethod + def factory(cls, reporter: CrossValidationReporter) -> CrossValidationReporterItem: + """ + Create a CrossValidationReporterItem instance from a CrossValidationReporter. - Derived class must call their super implementation - and merge the result with their output. + Parameters + ---------- + reporter : CrossValidationReporter + + Returns + ------- + CrossValidationReporterItem + A new CrossValidationReporterItem instance. """ + if not isinstance(reporter, CrossValidationReporter): + raise ItemTypeError(f"Type '{reporter.__class__}' is not supported.") + + instance = cls(pickle.dumps(reporter)) + + # add reporter as cached property + instance.reporter = reporter + + return instance + + @cached_property + def reporter(self) -> CrossValidationReporter: + """The CrossValidationReporter from the persistence.""" + return pickle.loads(self.reporter_bytes) + + def as_serializable_dict(self): + """Get a serializable dict from the item.""" # Get tabular results (the cv results in a dataframe-like structure) - cv_results = copy.deepcopy(self.cv_results_serialized) - cv_results.pop("indices", None) - - metrics_names = list(cv_results.keys()) - tabular_results = { - "name": "Cross validation results", - "columns": metrics_names, - "data": list(zip(*cv_results.values())), - "favorability": [_metric_favorability(m) for m in metrics_names], + cv_results = { + key: value.tolist() + for key, value in self.reporter.cv_results.items() + if ( + key != "estimator" + and key != "indices" + and isinstance(value, numpy.ndarray) + ) } + metrics_names = list(cv_results) + tabular_results = [ + { + "name": "Cross validation results", + "columns": metrics_names, + "data": list(zip(*cv_results.values())), + "favorability": [_metric_favorability(m) for m in metrics_names], + } + ] + # Get scalar results (summary statistics of the cv results) - mean_cv_results = [ + scalar_results = [ { "name": _metric_title(k), "value": statistics.mean(v), @@ -206,163 +247,22 @@ def 
as_serializable_dict(self): for k, v in cv_results.items() ] - scalar_results = mean_cv_results - - params_as_str = _params_to_str(self.estimator_info) - # If the estimator is from sklearn, make the class name a hyperlink # to the relevant docs - name = self.estimator_info["name"] - module = re.sub(r"\.\_.+", "", self.estimator_info["module"]) + estimator_info = _estimator_info(self.reporter.estimator) + name = estimator_info["name"] + module = re.sub(r"\.\_.+", "", estimator_info["module"]) if module.startswith("sklearn"): doc_url = f"https://scikit-learn.org/stable/modules/generated/{module}.{name}.html" doc_link = f'{name}' else: doc_link = f"`{name}`" + params_as_str = _params_to_str(estimator_info) estimator_params_as_str = f"{doc_link}\n{params_as_str}" - # Get cross-validation details - cv_params_as_str = ", ".join(f"{k}: *{v}*" for k, v in self.cv_info.items()) - - r = super().as_serializable_dict() - sections = [ - { - "title": "Model", - "icon": "icon-square-cursor", - "items": [ - { - "name": "Estimator parameters", - "description": "Core model configuration used for training", - "value": estimator_params_as_str, - }, - { - "name": "Cross-validation parameters", - "description": "Controls how data is split and validated", - "value": cv_params_as_str, - }, - ], - } - ] - value = { - "scalar_results": scalar_results, - "tabular_results": [tabular_results], - "plots": [ - { - "name": plot_name, - "value": json.loads(plot_bytes.decode("utf-8")), - } - for plot_name, plot_bytes in self.plots_bytes.items() - ], - "sections": sections, - } - r.update( - { - "media_type": "application/vnd.skore.cross_validation+json", - "value": value, - } - ) - return r - - @staticmethod - def _estimator_info(estimator: sklearn.base.BaseEstimator) -> EstimatorInfo: - estimator_params = ( - estimator.get_params() if hasattr(estimator, "get_params") else {} - ) - - name = estimator.__class__.__name__ - module = estimator.__module__ - - # Figure out the default parameters of the 
estimator, - # so that we can highlight the non-default ones in the UI - - # This is done by instantiating the class with no arguments and - # computing the diff between the default and ours - try: - estimator_module = importlib.import_module(module) - EstimatorClass = getattr(estimator_module, name) - default_estimator_params = EstimatorClass().get_params() - except Exception: - default_estimator_params = {} - - final_estimator_params: dict[str, EstimatorParamInfo] = {} - for k, v in estimator_params.items(): - param_is_default: bool = ( - k in default_estimator_params and default_estimator_params[k] == v - ) - final_estimator_params[str(k)] = { - "value": repr(v), - "default": param_is_default, - } - - return { - "name": name, - "module": module, - "params": final_estimator_params, - } - - @classmethod - def factory(cls, reporter: CrossValidationReporter) -> CrossValidationItem: - """ - Create a new CrossValidationItem instance from a CrossValidationReporter. - - Parameters - ---------- - reporter : CrossValidationReporter - - Returns - ------- - CrossValidationItem - A new CrossValidationItem instance. - """ - if not isinstance(reporter, CrossValidationReporter): - raise ItemTypeError( - f"Type '{reporter.__class__}' is not supported, " - f"only '{CrossValidationReporter.__name__}' is." 
- ) - - cv_results = reporter._cv_results - estimator = reporter.estimator - X = reporter.X - y = reporter.y - plots = reporter.plots - cv = reporter.cv - - cv_results_serialized = {} - for k, v in cv_results.items(): - if k == "estimator": - continue - if k == "indices": - cv_results_serialized["indices"] = { - "train": tuple(arr.tolist() for arr in v["train"]), - "test": tuple(arr.tolist() for arr in v["test"]), - } - if isinstance(v, numpy.ndarray): - cv_results_serialized[k] = v.tolist() - - estimator_info = CrossValidationItem._estimator_info(estimator) - - y_array = y if isinstance(y, numpy.ndarray) else numpy.array(y) - y_info = None if y is None else {"hash": _hash_numpy(y_array)} - - X_array = X if isinstance(X, numpy.ndarray) else numpy.array(X) - X_info = { - "nb_rows": X_array.shape[0], - "nb_cols": X_array.shape[1], - "hash": _hash_numpy(X_array), - } - - humanized_plot_names = { - "scores": "Scores", - "timing": "Timings", - } - plots_bytes = { - humanized_plot_names[plot_name]: ( - plotly.io.to_json(plot, engine="json").encode("utf-8") - ) - for plot_name, plot in dataclasses.asdict(plots).items() - } - + # + cv = self.reporter.cv cv_info: dict[str, str] = {} if isinstance(cv, int): cv_info["n_splits"] = repr(cv) @@ -376,19 +276,40 @@ def factory(cls, reporter: CrossValidationReporter) -> CrossValidationItem: attr = getattr(cv, attr_name) cv_info[attr_name] = repr(attr) - return cls( - cv_results_serialized=cv_results_serialized, - estimator_info=estimator_info, - X_info=X_info, - y_info=y_info, - plots_bytes=plots_bytes, - cv_info=cv_info, - ) + cv_params_as_str = ", ".join(f"{k}: *{v}*" for k, v in cv_info.items()) - @cached_property - def plots(self) -> dict: - """Various plots of the cross-validation results.""" - return { - name: plotly.io.from_json(plot_bytes.decode("utf-8")) - for name, plot_bytes in self.plots_bytes.items() + # + value = { + "scalar_results": scalar_results, + "tabular_results": tabular_results, + "plots": [ + { + "name": 
HUMANIZED_PLOT_NAMES[plot_name], + "value": json.loads(plotly.io.to_json(plot, engine="json")), + } + for plot_name, plot in dataclasses.asdict(self.reporter.plots).items() + ], + "sections": [ + { + "title": "Model", + "icon": "icon-square-cursor", + "items": [ + { + "name": "Estimator parameters", + "description": "Core model configuration used for training", + "value": estimator_params_as_str, + }, + { + "name": "Cross-validation parameters", + "description": "Controls how data is split and validated", + "value": cv_params_as_str, + }, + ], + } + ], + } + + return super().as_serializable_dict() | { + "media_type": "application/vnd.skore.cross_validation+json", + "value": value, } diff --git a/skore/tests/integration/sklearn/test_cross_validate.py b/skore/tests/integration/sklearn/test_cross_validate.py index e39f6ba58..00758f751 100644 --- a/skore/tests/integration/sklearn/test_cross_validate.py +++ b/skore/tests/integration/sklearn/test_cross_validate.py @@ -10,7 +10,7 @@ from sklearn.multiclass import OneVsOneClassifier from sklearn.svm import SVC from skore import CrossValidationReporter -from skore.persistence.item.cross_validation_item import CrossValidationItem +from skore.persistence.item import CrossValidationReporterItem from skore.sklearn.cross_validation.cross_validation_helpers import _get_scorers_to_add @@ -201,7 +201,7 @@ def test_cross_validation_reporter(in_memory_project, fixture_name, request): in_memory_project.put("cross-validation", reporter) retrieved_item = in_memory_project.item_repository.get_item("cross-validation") - assert isinstance(retrieved_item, CrossValidationItem) + assert isinstance(retrieved_item, CrossValidationReporterItem) @pytest.mark.parametrize( diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py index 360761a6f..44b15d7a3 100644 --- a/skore/tests/integration/ui/test_ui.py +++ b/skore/tests/integration/ui/test_ui.py @@ -167,7 +167,7 @@ def _fake_cross_validate(*args, **kwargs): 
monkeypatch.setattr("sklearn.model_selection.cross_validate", _fake_cross_validate) -def test_serialize_cross_validation_item( +def test_serialize_cross_validation_reporter_item( client, in_memory_project, monkeypatch, @@ -176,9 +176,6 @@ def test_serialize_cross_validation_item( fake_cross_validate, ): monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) - monkeypatch.setattr( - "skore.persistence.item.cross_validation_item.CrossValidationItem.plots", {} - ) monkeypatch.setattr( "skore.sklearn.cross_validation.cross_validation_reporter.plot_cross_validation_compare_scores", lambda _: {}, diff --git a/skore/tests/unit/item/test_cross_validation_item.py b/skore/tests/unit/item/test_cross_validation_reporter_item.py similarity index 80% rename from skore/tests/unit/item/test_cross_validation_item.py rename to skore/tests/unit/item/test_cross_validation_reporter_item.py index 909eff872..973fc933d 100644 --- a/skore/tests/unit/item/test_cross_validation_item.py +++ b/skore/tests/unit/item/test_cross_validation_reporter_item.py @@ -1,13 +1,13 @@ from dataclasses import dataclass +from pickle import dumps import numpy import plotly.graph_objects import pytest from sklearn.model_selection import StratifiedKFold from skore.persistence.item import ItemTypeError -from skore.persistence.item.cross_validation_item import ( - CrossValidationItem, - _hash_numpy, +from skore.persistence.item.cross_validation_reporter_item import ( + CrossValidationReporterItem, _metric_favorability, ) from skore.sklearn.cross_validation import CrossValidationReporter @@ -27,7 +27,7 @@ class FakeEstimatorNoGetParams: @dataclass class FakeCrossValidationReporter(CrossValidationReporter): - _cv_results = { + cv_results = { "test_score": numpy.array([1, 2, 3]), "estimator": [FakeEstimator(), FakeEstimator(), FakeEstimator()], "fit_time": [1, 2, 3], @@ -44,7 +44,7 @@ class FakeCrossValidationReporter(CrossValidationReporter): @dataclass class 
FakeCrossValidationReporterNoGetParams(CrossValidationReporter): - _cv_results = { + cv_results = { "test_score": numpy.array([1, 2, 3]), "estimator": [ FakeEstimatorNoGetParams(), @@ -63,14 +63,14 @@ class FakeCrossValidationReporterNoGetParams(CrossValidationReporter): cv = StratifiedKFold(n_splits=5) -class TestCrossValidationItem: +class TestCrossValidationReporterItem: @pytest.fixture(autouse=True) def monkeypatch_datetime(self, monkeypatch, MockDatetime): monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) def test_factory_exception(self): with pytest.raises(ItemTypeError): - CrossValidationItem.factory(None) + CrossValidationReporterItem.factory(None) @pytest.mark.parametrize( "reporter", @@ -82,42 +82,32 @@ def test_factory_exception(self): ], ) def test_factory(self, mock_nowstr, reporter): - item = CrossValidationItem.factory(reporter) - - assert item.cv_results_serialized == {"test_score": [1, 2, 3]} - assert item.estimator_info == { - "name": reporter.estimator.__class__.__name__, - "params": ( - {} - if isinstance(reporter.estimator, FakeEstimatorNoGetParams) - else {"alpha": {"value": "3", "default": True}} - ), - "module": "tests.unit.item.test_cross_validation_item", - } - assert item.X_info == { - "nb_cols": 1, - "nb_rows": 1, - "hash": _hash_numpy(FakeCrossValidationReporter.X), - } - assert item.y_info == {"hash": _hash_numpy(FakeCrossValidationReporter.y)} - assert item.cv_info == { - "n_splits": "5", - "random_state": "None", - "shuffle": "False", - } - assert isinstance(item.plots_bytes, dict) - assert isinstance(item.plots, dict) + item = CrossValidationReporterItem.factory(reporter) + + assert item.reporter_bytes == dumps(reporter) assert item.created_at == mock_nowstr assert item.updated_at == mock_nowstr + def test_reporter(self, mock_nowstr): + reporter = FakeCrossValidationReporter() + item1 = CrossValidationReporterItem.factory(reporter) + item2 = CrossValidationReporterItem( + reporter_bytes=dumps(reporter), + 
created_at=mock_nowstr, + updated_at=mock_nowstr, + ) + + assert item1.reporter == reporter + assert item2.reporter == reporter + def test_get_serializable_dict(self, monkeypatch, mock_nowstr): monkeypatch.setattr( - "skore.persistence.item.cross_validation_item.CrossValidationReporter", + "skore.persistence.item.cross_validation_reporter_item.CrossValidationReporter", FakeCrossValidationReporter, ) reporter = FakeCrossValidationReporter() - item = CrossValidationItem.factory(reporter) + item = CrossValidationReporterItem.factory(reporter) serializable = item.as_serializable_dict() assert serializable["updated_at"] == mock_nowstr From 43a296bd956d5841695bb6293f39f21f45a45e8a Mon Sep 17 00:00:00 2001 From: Thomas S Date: Tue, 14 Jan 2025 15:34:55 +0100 Subject: [PATCH 13/22] [skip ci] Use `reporter._cv_results` instead of `reporter.cv_results` --- .../persistence/item/cross_validation_reporter_item.py | 8 ++------ .../unit/item/test_cross_validation_reporter_item.py | 4 ++-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/skore/src/skore/persistence/item/cross_validation_reporter_item.py b/skore/src/skore/persistence/item/cross_validation_reporter_item.py index 263e1528c..258a30be2 100644 --- a/skore/src/skore/persistence/item/cross_validation_reporter_item.py +++ b/skore/src/skore/persistence/item/cross_validation_reporter_item.py @@ -218,12 +218,8 @@ def as_serializable_dict(self): # Get tabular results (the cv results in a dataframe-like structure) cv_results = { key: value.tolist() - for key, value in self.reporter.cv_results.items() - if ( - key != "estimator" - and key != "indices" - and isinstance(value, numpy.ndarray) - ) + for key, value in self.reporter._cv_results.items() + if key not in ("estimator", "indices") and isinstance(value, numpy.ndarray) } metrics_names = list(cv_results) diff --git a/skore/tests/unit/item/test_cross_validation_reporter_item.py b/skore/tests/unit/item/test_cross_validation_reporter_item.py index 
973fc933d..0979d5fc3 100644 --- a/skore/tests/unit/item/test_cross_validation_reporter_item.py +++ b/skore/tests/unit/item/test_cross_validation_reporter_item.py @@ -27,7 +27,7 @@ class FakeEstimatorNoGetParams: @dataclass class FakeCrossValidationReporter(CrossValidationReporter): - cv_results = { + _cv_results = { "test_score": numpy.array([1, 2, 3]), "estimator": [FakeEstimator(), FakeEstimator(), FakeEstimator()], "fit_time": [1, 2, 3], @@ -44,7 +44,7 @@ class FakeCrossValidationReporter(CrossValidationReporter): @dataclass class FakeCrossValidationReporterNoGetParams(CrossValidationReporter): - cv_results = { + _cv_results = { "test_score": numpy.array([1, 2, 3]), "estimator": [ FakeEstimatorNoGetParams(), From 459e5488917bd833bce32c76bf78abbb3432d4f7 Mon Sep 17 00:00:00 2001 From: Thomas S Date: Tue, 14 Jan 2025 17:45:21 +0100 Subject: [PATCH 14/22] Add `display_as` parameter --- skore/src/skore/persistence/item/__init__.py | 79 +++++++++++++------ .../src/skore/persistence/item/media_item.py | 11 ++- skore/src/skore/project/project.py | 32 +++++--- skore/tests/unit/project/test_display_as.py | 35 ++++++++ 4 files changed, 120 insertions(+), 37 deletions(-) create mode 100644 skore/tests/unit/project/test_display_as.py diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py index bc3890ce5..aa9ea9319 100644 --- a/skore/src/skore/persistence/item/__init__.py +++ b/skore/src/skore/persistence/item/__init__.py @@ -3,12 +3,12 @@ from __future__ import annotations from contextlib import suppress -from typing import Any +from typing import Any, Literal, Optional from . 
import skrub_table_report_item as SkrubTableReportItem from .cross_validation_reporter_item import CrossValidationReporterItem from .item import Item, ItemTypeError -from .media_item import MediaItem +from .media_item import MediaItem, MediaType from .numpy_array_item import NumpyArrayItem from .pandas_dataframe_item import PandasDataFrameItem from .pandas_series_item import PandasSeriesItem @@ -19,31 +19,60 @@ from .sklearn_base_estimator_item import SklearnBaseEstimatorItem -def object_to_item(object: Any) -> Item: +def object_to_item( + object: Any, + /, + *, + note: Optional[str] = None, + display_as: Optional[Literal["HTML", "MARKDOWN", "SVG"]] = None, +) -> Item: """Transform an object into an Item.""" - for cls in ( - PrimitiveItem, - PandasDataFrameItem, - PandasSeriesItem, - PolarsDataFrameItem, - PolarsSeriesItem, - NumpyArrayItem, - SklearnBaseEstimatorItem, - MediaItem, - SkrubTableReportItem, - CrossValidationReporterItem, - ): - with suppress(ImportError, ItemTypeError): - # ImportError: - # The factories are responsible to import third-party libraries in a - # lazy way. If library is missing, an ImportError exception will - # automatically be thrown. - # ItemTypeError: - # The factories are responsible for checking that parameters are of the - # correct type. If not, they throw a ItemTypeError exception. 
- return cls.factory(object) + if display_as is not None: + if not isinstance(object, str): + raise TypeError("`object` must be a str if `display_as` is specified") - return PickleItem.factory(object) + if display_as not in MediaType.__members__: + raise ValueError(f"`display_as` must be in {list(MediaType.__members__)}") + + item = MediaItem.factory_str( + media=object, + media_type=MediaType[display_as].value, + ) + else: + for cls in ( + PrimitiveItem, + PandasDataFrameItem, + PandasSeriesItem, + PolarsDataFrameItem, + PolarsSeriesItem, + NumpyArrayItem, + SklearnBaseEstimatorItem, + MediaItem, + SkrubTableReportItem, + CrossValidationReporterItem, + ): + with suppress(ImportError, ItemTypeError): + # ImportError: + # The factories are responsible to import third-party libraries in a + # lazy way. If library is missing, an ImportError exception will + # automatically be thrown. + # ItemTypeError: + # The factories are responsible for checking that parameters are of + # the correct type. If not, they throw a ItemTypeError exception. + item = cls.factory(object) + break + else: + item = PickleItem.factory(object) + + if not isinstance(note, (type(None), str)): + raise TypeError(f"`note` must be a string (found '{type(note)}')") + + # Since the item classes are now private, and to avoid having to pass the `note` + # parameter in the factories of each item class, we define the content of the + # `note` attribute dynamically. 
+ item.note = note + + return item def item_to_object(item: Item) -> Any: diff --git a/skore/src/skore/persistence/item/media_item.py b/skore/src/skore/persistence/item/media_item.py index 3736a2345..e57434743 100644 --- a/skore/src/skore/persistence/item/media_item.py +++ b/skore/src/skore/persistence/item/media_item.py @@ -6,6 +6,7 @@ from __future__ import annotations import base64 +from enum import Enum from io import BytesIO from typing import TYPE_CHECKING, Any, Union @@ -25,6 +26,14 @@ def lazy_is_instance(object: Any, cls_fullname: str) -> bool: } +class MediaType(Enum): + """Enum used to map aliases and media types.""" + + HTML = "text/html" + MARKDOWN = "text/markdown" + SVG = "image/svg+xml" + + class MediaItem(Item): """ A class to represent a media item. @@ -165,7 +174,7 @@ def factory_str(cls, media: str, media_type: str = "text/markdown") -> MediaItem media : str The string content to store. media_type : str, optional - The MIME type of the media content, by default "text/html". + The MIME type of the media content, by default "text/markdown". Returns ------- diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py index cb335feb5..365a7c97e 100644 --- a/skore/src/skore/project/project.py +++ b/skore/src/skore/project/project.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union from skore.persistence.item import item_to_object, object_to_item @@ -73,7 +73,14 @@ def __init__( self.item_repository = item_repository self.view_repository = view_repository - def put(self, key: str, value: Any, *, note: Optional[str] = None): + def put( + self, + key: str, + value: Any, + *, + note: Optional[str] = None, + display_as: Optional[Literal["HTML", "MARKDOWN", "SVG"]] = None, + ): """Add a key-value pair to the Project. 
If an item with the same key already exists, its value is replaced by the new @@ -85,8 +92,11 @@ def put(self, key: str, value: Any, *, note: Optional[str] = None): The key to associate with ``value`` in the Project. value : Any The value to associate with ``key`` in the Project. - note : str or None, optional + note : str, optional A note to attach with the item. + display_as : str, optional + Used in combination with a string value, it customizes the way the value is + displayed in the interface. Raises ------ @@ -99,14 +109,14 @@ def put(self, key: str, value: Any, *, note: Optional[str] = None): if not isinstance(key, str): raise TypeError(f"Key must be a string (found '{type(key)}')") - item = object_to_item(value) - - if note is not None: - if not isinstance(note, str): - raise TypeError(f"Note must be a string (found '{type(note)}')") - item.note = note - - self.item_repository.put_item(key, item) + self.item_repository.put_item( + key, + object_to_item( + value, + note=note, + display_as=display_as, + ), + ) def get(self, key: str) -> Any: """Get the value corresponding to ``key`` from the Project. 
diff --git a/skore/tests/unit/project/test_display_as.py b/skore/tests/unit/project/test_display_as.py new file mode 100644 index 000000000..50833bc5a --- /dev/null +++ b/skore/tests/unit/project/test_display_as.py @@ -0,0 +1,35 @@ +import pytest +from skore.persistence.item import MediaItem, PrimitiveItem + + +@pytest.fixture(autouse=True) +def monkeypatch_datetime(monkeypatch, MockDatetime): + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) + + +def test_str_without_display_as(in_memory_project, mock_nowstr): + in_memory_project.put("key", "") + + item = in_memory_project.item_repository.get_item("key") + + assert isinstance(item, PrimitiveItem) + assert item.primitive == "" + + +def test_str_with_display_as(in_memory_project, mock_nowstr): + in_memory_project.put("key", "", display_as="MARKDOWN") + + item = in_memory_project.item_repository.get_item("key") + + assert isinstance(item, MediaItem) + assert item.media_bytes == b"" + assert item.media_encoding == "utf-8" + assert item.media_type == "text/markdown" + + +def test_exception(in_memory_project, mock_nowstr): + with pytest.raises(TypeError): + in_memory_project.put("key", 1, display_as="MARKDOWN") + + with pytest.raises(ValueError): + in_memory_project.put("key", "", display_as="") From daebfe6b1f590d5f580f2e2fcb7ab678a6bedc30 Mon Sep 17 00:00:00 2001 From: Thomas S Date: Wed, 15 Jan 2025 10:22:25 +0100 Subject: [PATCH 15/22] Update examples --- examples/model_evaluation/plot_cross_validate.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/model_evaluation/plot_cross_validate.py b/examples/model_evaluation/plot_cross_validate.py index a47b6b74f..cb0ab5e3c 100644 --- a/examples/model_evaluation/plot_cross_validate.py +++ b/examples/model_evaluation/plot_cross_validate.py @@ -155,15 +155,16 @@ reporter.plots.scores # %% -# We can also access the plot after we have stored the ``CrossValidationReporter``: 
-my_project.put("cross_validation_regression", reporter) -cv_item = my_project.get_item("cross_validation_regression") # TO CHANGE /!\ -cv_item.plots["Scores"] +# We can put the reporter in the project, and retrieve it as is: +my_project.put("cross_validation_reporter", reporter) + +reporter = my_project.get("cross_validation_reporter") +reporter.plots.scores # %% # .. note:: # -# If we put a cross-validation item in a skore project, we get some nice +# If we put a cross-validation reporter in a skore project, we get some nice # information in the UI: # # .. image:: https://media.githubusercontent.com/media/probabl-ai/skore/main/sphinx/_static/images/2024_12_12_skore_demo_comp.gif From 3084bb039ca17d92f24b7ad32c23543181b1ff60 Mon Sep 17 00:00:00 2001 From: Thomas S Date: Wed, 15 Jan 2025 12:48:59 +0100 Subject: [PATCH 16/22] [skip ci] refactorize `get_item_versions` to be item agnostic --- .../getting_started/plot_tracking_items.py | 23 +++---- skore/src/skore/project/project.py | 67 ++++++++++--------- skore/src/skore/ui/project_routes.py | 13 ++-- skore/tests/integration/ui/test_ui.py | 2 +- skore/tests/unit/project/test_project.py | 65 +++++++++++------- 5 files changed, 93 insertions(+), 77 deletions(-) diff --git a/examples/getting_started/plot_tracking_items.py b/examples/getting_started/plot_tracking_items.py index ec8e1a5fb..9eefc83e2 100644 --- a/examples/getting_started/plot_tracking_items.py +++ b/examples/getting_started/plot_tracking_items.py @@ -73,17 +73,14 @@ # We retrieve the history of the ``my_int`` item: # %% -item_histories = my_project.get_item_versions("my_int") # TO CHANGE /!\ +history = my_project.get("my_int", latest=False, metadata=True) # %% # We can print the first history (first iteration) of this item: # %% -passed_item = item_histories[0] -print(passed_item) -print(passed_item.primitive) -print(passed_item.created_at) -print(passed_item.updated_at) + +print(history[0]) # %% # Let us construct a dataframe with the values and last 
updated times: @@ -92,13 +89,13 @@ import numpy as np import pandas as pd -list_primitive, list_created_at, list_updated_at = zip( - *[(elem.primitive, elem.created_at, elem.updated_at) for elem in item_histories] +list_value, list_created_at, list_updated_at = zip( + *[(version["value"], history[0]["date"], version["date"]) for version in history] ) df_track = pd.DataFrame( { - "primitive": list_primitive, + "value": list_value, "created_at": list_created_at, "updated_at": list_updated_at, } @@ -111,9 +108,9 @@ # :language: python # # Notice that the ``created_at`` dates are the same for all iterations because they -# correspond to the same item, but the ``updated_at`` dates are spaced by 0.1 second -# (approximately) as we used :python:`time.sleep(0.1)` between each -# :func:`~skore.Project.put`. +# correspond to the date of the first version of the item, but the ``updated_at`` dates +# are spaced by 0.1 second (approximately) as we used :python:`time.sleep(0.1)` between +# each :func:`~skore.Project.put`. # %% # We can now track the value of the item over time: @@ -124,7 +121,7 @@ fig = px.line( df_track, x="iteration_number", - y="primitive", + y="value", hover_data=df_track.columns, markers=True, ) diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py index 365a7c97e..6271d8785 100644 --- a/skore/src/skore/project/project.py +++ b/skore/src/skore/project/project.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: from skore.persistence import ( - Item, ItemRepository, View, ViewRepository, @@ -118,55 +117,57 @@ def put( ), ) - def get(self, key: str) -> Any: - """Get the value corresponding to ``key`` from the Project. - - Parameters - ---------- - key : str - The key corresponding to the item to get. - - Raises - ------ - KeyError - If the key does not correspond to any item. 
- """ - return item_to_object(self.item_repository.get_item(key)) - - def get_item_versions(self, key: str) -> list[Item]: - """ - Get all the versions of an item associated with ``key`` from the Project. - - The list is ordered from oldest to newest :func:`~skore.Project.put` date. + def get(self, key, *, latest=True, metadata=False): + """Get the value associated to ``key`` from the Project. Parameters ---------- key : str The key corresponding to the item to get. + latest : boolean, optional + Get the latest value or all the values associated to ``key``, default True. + metadata : boolean, optional + Get the metadata in addition to the value, default False. Returns ------- - list[Item] - The list of items corresponding to ``key``. + value : any + Value associated to ``key``, when latest=True and metadata=False. + value_and_metadata : dict + Value associated to ``key`` with its metadata, when latest=True and metadata=True. + list_of_values : list[any] + Values associated to ``key``, ordered by date, when latest=False. + list_of_values_and_metadata : list[dict] + Values associated to ``key`` with their metadata, ordered by date, when + latest=False and metadata=True. Raises ------ KeyError - If the key does not correspond to any item. + If the key is not in the project. """ - return self.item_repository.get_item_versions(key) + if not metadata: - def list_item_keys(self) -> list[str]: - """List all item keys in the Project. + def dto(item): + return item_to_object(item) - Returns - ------- - list[str] - The list of item keys. The list is empty if there is no item. 
- """ + else: + + def dto(item): + return { + "value": item_to_object(item), + "date": item.updated_at, + "note": item.note, + } + + if latest: + return dto(self.item_repository.get_item(key)) + return list(map(dto, self.item_repository.get_item_versions(key))) + + def keys(self) -> list[str]: return self.item_repository.keys() - def delete_item(self, key: str): + def delete(self, key: str): """Delete the item corresponding to ``key`` from the Project. Parameters diff --git a/skore/src/skore/ui/project_routes.py b/skore/src/skore/ui/project_routes.py index 79eff5ce6..e3ac728cb 100644 --- a/skore/src/skore/ui/project_routes.py +++ b/skore/src/skore/ui/project_routes.py @@ -49,12 +49,15 @@ def __item_as_serializable(name: str, item: Item) -> SerializableItem: def __project_as_serializable(project: Project) -> SerializableProject: items = { key: [ - __item_as_serializable(key, item) for item in project.get_item_versions(key) + __item_as_serializable(key, item) + for item in project.item_repository.get_item_versions(key) ] - for key in project.list_item_keys() + for key in project.item_repository.keys() } - views = {key: project.get_view(key).layout for key in project.list_view_keys()} + views = { + key: project.get_view(key).layout for key in project.view_repository.keys() + } return SerializableProject( items=items, @@ -112,8 +115,8 @@ async def get_activity( return sorted( ( __item_as_serializable(key, version) - for key in project.list_item_keys() - for version in project.get_item_versions(key) + for key in project.item_repository.keys() + for version in project.item_repository.get_item_versions(key) if datetime.fromisoformat(version.updated_at) > after ), key=operator.attrgetter("updated_at"), diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py index 44b15d7a3..9d9a481bd 100644 --- a/skore/tests/integration/ui/test_ui.py +++ b/skore/tests/integration/ui/test_ui.py @@ -33,7 +33,7 @@ def test_get_items(client, 
in_memory_project): in_memory_project.put("test", "version_1") in_memory_project.put("test", "version_2") - items = in_memory_project.get_item_versions("test") + items = in_memory_project.item_repository.get_item_versions("test") response = client.get("/api/project/items") assert response.status_code == 200 diff --git a/skore/tests/unit/project/test_project.py b/skore/tests/unit/project/test_project.py index f5a842011..48b657534 100644 --- a/skore/tests/unit/project/test_project.py +++ b/skore/tests/unit/project/test_project.py @@ -19,6 +19,11 @@ from skore.project.create import _create, _validate_project_name +@pytest.fixture(autouse=True) +def monkeypatch_datetime(monkeypatch, MockDatetime): + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) + + def test_put_string_item(in_memory_project): in_memory_project.put("string_item", "Hello, World!") assert in_memory_project.get("string_item") == "Hello, World!" @@ -145,19 +150,19 @@ def test_put(in_memory_project): in_memory_project.put("key3", 3) in_memory_project.put("key4", 4) - assert in_memory_project.list_item_keys() == ["key1", "key2", "key3", "key4"] + assert in_memory_project.keys() == ["key1", "key2", "key3", "key4"] def test_put_kwargs(in_memory_project): in_memory_project.put(key="key1", value=1) - assert in_memory_project.list_item_keys() == ["key1"] + assert in_memory_project.keys() == ["key1"] def test_put_wrong_key_type(in_memory_project): with pytest.raises(TypeError): in_memory_project.put(key=2, value=1) - assert in_memory_project.list_item_keys() == [] + assert in_memory_project.keys() == [] def test_put_twice(in_memory_project): @@ -167,39 +172,49 @@ def test_put_twice(in_memory_project): assert in_memory_project.get("key2") == 5 -def test_get(in_memory_project): - in_memory_project.put("key1", 1) - assert in_memory_project.get("key1") == 1 - - with pytest.raises(KeyError): - in_memory_project.get("key2") - - -def test_get_item_versions(in_memory_project): - 
in_memory_project.put("key", 1) - in_memory_project.put("key", 2) +def test_get(in_memory_project, mock_nowstr): + in_memory_project.put("key", 1, note="1") + in_memory_project.put("key", 2, note="2") - items = in_memory_project.get_item_versions("key") + assert in_memory_project.get("key") == 2 + assert in_memory_project.get("key", latest=True, metadata=False) == 2 + assert in_memory_project.get("key", latest=False, metadata=False) == [1, 2] + assert in_memory_project.get("key", latest=False, metadata=True) == [ + { + "value": 1, + "date": mock_nowstr, + "note": "1", + }, + { + "value": 2, + "date": mock_nowstr, + "note": "2", + }, + ] + assert in_memory_project.get("key", latest=True, metadata=True) == { + "value": 2, + "date": mock_nowstr, + "note": "2", + } - assert len(items) == 2 - assert items[0].primitive == 1 - assert items[1].primitive == 2 + with pytest.raises(KeyError): + in_memory_project.get("") def test_delete(in_memory_project): in_memory_project.put("key1", 1) - in_memory_project.delete_item("key1") + in_memory_project.delete("key1") - assert in_memory_project.list_item_keys() == [] + assert in_memory_project.keys() == [] with pytest.raises(KeyError): - in_memory_project.delete_item("key2") + in_memory_project.delete("key2") def test_keys(in_memory_project): in_memory_project.put("key1", 1) in_memory_project.put("key2", 2) - assert in_memory_project.list_item_keys() == ["key1", "key2"] + assert in_memory_project.keys() == ["key1", "key2"] def test_view(in_memory_project): @@ -222,7 +237,7 @@ def test_put_several_complex(in_memory_project): in_memory_project.put("a", int) in_memory_project.put("b", float) - assert in_memory_project.list_item_keys() == ["a", "b"] + assert in_memory_project.keys() == ["a", "b"] def test_put_key_is_a_tuple(in_memory_project): @@ -230,7 +245,7 @@ def test_put_key_is_a_tuple(in_memory_project): with pytest.raises(TypeError): in_memory_project.put(("a", "foo"), ("b", "bar")) - assert in_memory_project.list_item_keys() == 
[] + assert in_memory_project.keys() == [] def test_put_key_is_a_set(in_memory_project): @@ -238,7 +253,7 @@ def test_put_key_is_a_set(in_memory_project): with pytest.raises(TypeError): in_memory_project.put(set(), "hello") - assert in_memory_project.list_item_keys() == [] + assert in_memory_project.keys() == [] def test_put_wrong_key_and_value_raise(in_memory_project): From 602b8941b2d248d3c0f758f4ef0108d338f5b4fa Mon Sep 17 00:00:00 2001 From: Thomas S Date: Wed, 15 Jan 2025 13:03:17 +0100 Subject: [PATCH 17/22] [skip ci] Fix examples and sphinx --- examples/getting_started/plot_working_with_projects.py | 8 ++++---- sphinx/api.rst | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/getting_started/plot_working_with_projects.py b/examples/getting_started/plot_working_with_projects.py index 4897b4738..ef65603de 100644 --- a/examples/getting_started/plot_working_with_projects.py +++ b/examples/getting_started/plot_working_with_projects.py @@ -71,20 +71,20 @@ # see :ref:`example_tracking_items`. # %% -# By using the :func:`~skore.Project.delete_item` method, we can also delete an object +# By using the :func:`~skore.Project.delete` method, we can also delete an object # so that our skore UI does not become cluttered: # %% my_project.put("my_int_2", 10) # %% -my_project.delete_item("my_int_2") +my_project.delete("my_int_2") # %% # We can display all the keys in our project: # %% -my_project.list_item_keys() +my_project.keys() # %% # Storing strings and texts @@ -127,7 +127,7 @@ def my_func(x): my_project.put( "my_string_3", "

Title

bold, italic, etc.

", - display_as="html", + display_as="HTML", ) # %% diff --git a/sphinx/api.rst b/sphinx/api.rst index b5d5527d4..09a3bfd6a 100644 --- a/sphinx/api.rst +++ b/sphinx/api.rst @@ -20,7 +20,6 @@ These functions and classes are meant for managing a Project. :caption: Managing a project Project - item.primitive_item.PrimitiveItem open Get assistance when developing ML/DS projects @@ -35,7 +34,6 @@ These functions and classes enhance scikit-learn's ones. train_test_split CrossValidationReporter - item.cross_validation_item.CrossValidationItem Report for a single estimator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 71aec33a9c85dc78df37f8d24c3e920cdc1e878d Mon Sep 17 00:00:00 2001 From: Thomas S Date: Wed, 15 Jan 2025 14:59:28 +0100 Subject: [PATCH 18/22] [skip ci] Hide view API --- skore/src/skore/project/create.py | 2 +- skore/src/skore/project/project.py | 50 ------------------------ skore/src/skore/ui/project_routes.py | 7 ++-- skore/tests/integration/ui/test_ui.py | 2 +- skore/tests/unit/project/test_project.py | 17 -------- 5 files changed, 6 insertions(+), 72 deletions(-) diff --git a/skore/src/skore/project/create.py b/skore/src/skore/project/create.py index e5a7e92de..fa6eac7f2 100644 --- a/skore/src/skore/project/create.py +++ b/skore/src/skore/project/create.py @@ -143,7 +143,7 @@ def _create( ) from e p = _load(project_directory) - p.put_view("default", View(layout=[])) + p.view_repository.put_view("default", View(layout=[])) console.rule("[bold cyan]skore[/bold cyan]") console.print(f"Project file '{project_directory}' was successfully created.") diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py index 6271d8785..acba73822 100644 --- a/skore/src/skore/project/project.py +++ b/skore/src/skore/project/project.py @@ -10,7 +10,6 @@ if TYPE_CHECKING: from skore.persistence import ( ItemRepository, - View, ViewRepository, ) @@ -182,55 +181,6 @@ def delete(self, key: str): """ self.item_repository.delete_item(key) - def put_view(self, key: 
str, view: View): - """Add a view to the Project.""" - self.view_repository.put_view(key, view) - - def get_view(self, key: str) -> View: - """Get the view corresponding to ``key`` from the Project. - - Parameters - ---------- - key : str - The key of the item to get. - - Returns - ------- - View - The view corresponding to ``key``. - - Raises - ------ - KeyError - If the key does not correspond to any view. - """ - return self.view_repository.get_view(key) - - def delete_view(self, key: str): - """Delete the view corresponding to ``key`` from the Project. - - Parameters - ---------- - key : str - The key corresponding to the view to delete. - - Raises - ------ - KeyError - If the key does not correspond to any view. - """ - return self.view_repository.delete_view(key) - - def list_view_keys(self) -> list[str]: - """List all view keys in the Project. - - Returns - ------- - list[str] - The list of view keys. The list is empty if there is no view. - """ - return self.view_repository.keys() - def set_note(self, key: str, message: str, *, version=-1): """Attach a note to key ``key``. 
diff --git a/skore/src/skore/ui/project_routes.py b/skore/src/skore/ui/project_routes.py index e3ac728cb..5606bbd39 100644 --- a/skore/src/skore/ui/project_routes.py +++ b/skore/src/skore/ui/project_routes.py @@ -56,7 +56,8 @@ def __project_as_serializable(project: Project) -> SerializableProject: } views = { - key: project.get_view(key).layout for key in project.view_repository.keys() + key: project.view_repository.get_view(key).layout + for key in project.view_repository.keys() } return SerializableProject( @@ -81,7 +82,7 @@ async def put_view(request: Request, key: str, layout: Layout): project: Project = request.app.state.project view = View(layout=layout) - project.put_view(key, view) + project.view_repository.put_view(key, view) return __project_as_serializable(project) @@ -92,7 +93,7 @@ async def delete_view(request: Request, key: str): project: Project = request.app.state.project try: - project.delete_view(key) + project.view_repository.delete_view(key) except KeyError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="View not found" diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py index 9d9a481bd..c5170a958 100644 --- a/skore/tests/integration/ui/test_ui.py +++ b/skore/tests/integration/ui/test_ui.py @@ -60,7 +60,7 @@ def test_put_view_layout(client): def test_delete_view(client, in_memory_project): - in_memory_project.put_view("hello", View(layout=[])) + in_memory_project.view_repository.put_view("hello", View(layout=[])) response = client.delete("/api/project/views?key=hello") assert response.status_code == 202 diff --git a/skore/tests/unit/project/test_project.py b/skore/tests/unit/project/test_project.py index 48b657534..cf2f7225e 100644 --- a/skore/tests/unit/project/test_project.py +++ b/skore/tests/unit/project/test_project.py @@ -15,7 +15,6 @@ InvalidProjectNameError, ProjectCreationError, ) -from skore.persistence.view.view import View from skore.project.create import _create, 
_validate_project_name @@ -217,22 +216,6 @@ def test_keys(in_memory_project): assert in_memory_project.keys() == ["key1", "key2"] -def test_view(in_memory_project): - layout = ["key1", "key2"] - - view = View(layout=layout) - - in_memory_project.put_view("view", view) - assert in_memory_project.get_view("view") == view - - -def test_list_view_keys(in_memory_project): - view = View(layout=[]) - - in_memory_project.put_view("view", view) - assert in_memory_project.list_view_keys() == ["view"] - - def test_put_several_complex(in_memory_project): in_memory_project.put("a", int) in_memory_project.put("b", float) From a85d2f18b06130c95e3245f2c6dc9f9e66ed206c Mon Sep 17 00:00:00 2001 From: Thomas S Date: Wed, 15 Jan 2025 15:10:04 +0100 Subject: [PATCH 19/22] [skip ci] Fix linter --- .../persistence/repository/item_repository.py | 14 ++++++++++- .../persistence/repository/view_repository.py | 16 +++++++++++-- skore/src/skore/project/project.py | 23 ++++++++++++++++++- skore/src/skore/ui/project_routes.py | 6 ++--- 4 files changed, 52 insertions(+), 7 deletions(-) diff --git a/skore/src/skore/persistence/repository/item_repository.py b/skore/src/skore/persistence/repository/item_repository.py index 601c7b20b..97e29e174 100644 --- a/skore/src/skore/persistence/repository/item_repository.py +++ b/skore/src/skore/persistence/repository/item_repository.py @@ -6,6 +6,7 @@ from __future__ import annotations +from collections.abc import Iterator from typing import TYPE_CHECKING, Union import skore.persistence.item @@ -129,10 +130,21 @@ def keys(self) -> list[str]: Returns ------- list[str] - A list of all keys in the storage. + A list of all keys. """ return list(self.storage.keys()) + def __iter__(self) -> Iterator[str]: + """ + Yield the keys of items stored in the repository. + + Returns + ------- + Iterator[str] + An iterator yielding all keys. + """ + yield from self.storage + def set_item_note(self, key: str, message: str, *, version=-1): """Attach a note to key ``key``. 
diff --git a/skore/src/skore/persistence/repository/view_repository.py b/skore/src/skore/persistence/repository/view_repository.py index 8ec7dcfc9..02f42baa8 100644 --- a/skore/src/skore/persistence/repository/view_repository.py +++ b/skore/src/skore/persistence/repository/view_repository.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Iterator from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -65,11 +66,22 @@ def delete_view(self, key: str): def keys(self) -> list[str]: """ - Get all keys of items stored in the repository. + Get all keys of views stored in the repository. Returns ------- list[str] - A list of all keys in the storage. + A list of all keys. """ return list(self.storage.keys()) + + def __iter__(self) -> Iterator[str]: + """ + Yield the keys of views stored in the repository. + + Returns + ------- + Iterator[str] + An iterator yielding all keys. + """ + yield from self.storage diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py index acba73822..d485e3f72 100644 --- a/skore/src/skore/project/project.py +++ b/skore/src/skore/project/project.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from collections.abc import Iterator from typing import TYPE_CHECKING, Any, Literal, Optional, Union from skore.persistence.item import item_to_object, object_to_item @@ -133,7 +134,8 @@ def get(self, key, *, latest=True, metadata=False): value : any Value associated to ``key``, when latest=True and metadata=False. value_and_metadata : dict - Value associated to ``key`` with its metadata, when latest=True and metadata=True. + Value associated to ``key`` with its metadata, when latest=True and + metadata=True. list_of_values : list[any] Values associated to ``key``, ordered by date, when latest=False. 
list_of_values_and_metadata : list[dict] @@ -164,8 +166,27 @@ def dto(item): return list(map(dto, self.item_repository.get_item_versions(key))) def keys(self) -> list[str]: + """ + Get all keys of items stored in the project. + + Returns + ------- + list[str] + A list of all keys. + """ return self.item_repository.keys() + def __iter__(self) -> Iterator[str]: + """ + Yield the keys of items stored in the project. + + Returns + ------- + Iterator[str] + An iterator yielding all keys. + """ + yield from self.item_repository + def delete(self, key: str): """Delete the item corresponding to ``key`` from the Project. diff --git a/skore/src/skore/ui/project_routes.py b/skore/src/skore/ui/project_routes.py index 5606bbd39..04d1e88e1 100644 --- a/skore/src/skore/ui/project_routes.py +++ b/skore/src/skore/ui/project_routes.py @@ -52,12 +52,12 @@ def __project_as_serializable(project: Project) -> SerializableProject: __item_as_serializable(key, item) for item in project.item_repository.get_item_versions(key) ] - for key in project.item_repository.keys() + for key in project.item_repository } views = { key: project.view_repository.get_view(key).layout - for key in project.view_repository.keys() + for key in project.view_repository } return SerializableProject( @@ -116,7 +116,7 @@ async def get_activity( return sorted( ( __item_as_serializable(key, version) - for key in project.item_repository.keys() + for key in project.item_repository for version in project.item_repository.get_item_versions(key) if datetime.fromisoformat(version.updated_at) > after ), From fe9c497f613ae7403d16c3f311ba384871bb4f2e Mon Sep 17 00:00:00 2001 From: Thomas S Date: Thu, 16 Jan 2025 15:41:32 +0100 Subject: [PATCH 20/22] [skip ci] Leave pillow from MediaItem to PillowImageItem --- skore/src/skore/persistence/item/__init__.py | 5 + .../item/cross_validation_reporter_item.py | 10 +- .../src/skore/persistence/item/media_item.py | 54 ++-------- .../src/skore/persistence/item/pickle_item.py | 3 +- 
.../persistence/item/pillow_image_item.py | 102 ++++++++++++++++++ skore/tests/integration/ui/test_ui.py | 48 +++++++-- skore/tests/unit/item/test_media_item.py | 18 ---- .../tests/unit/item/test_pillow_image_item.py | 56 ++++++++++ skore/tests/unit/project/test_project.py | 18 ++-- 9 files changed, 222 insertions(+), 92 deletions(-) create mode 100644 skore/src/skore/persistence/item/pillow_image_item.py create mode 100644 skore/tests/unit/item/test_pillow_image_item.py diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py index aa9ea9319..de24446be 100644 --- a/skore/src/skore/persistence/item/__init__.py +++ b/skore/src/skore/persistence/item/__init__.py @@ -13,6 +13,7 @@ from .pandas_dataframe_item import PandasDataFrameItem from .pandas_series_item import PandasSeriesItem from .pickle_item import PickleItem +from .pillow_image_item import PillowImageItem from .polars_dataframe_item import PolarsDataFrameItem from .polars_series_item import PolarsSeriesItem from .primitive_item import PrimitiveItem @@ -50,6 +51,7 @@ def object_to_item( MediaItem, SkrubTableReportItem, CrossValidationReporterItem, + PillowImageItem, ): with suppress(ImportError, ItemTypeError): # ImportError: @@ -91,6 +93,8 @@ def item_to_object(item: Item) -> Any: return item.reporter elif isinstance(item, MediaItem): return item.media_bytes + elif isinstance(item, PillowImageItem): + return item.image elif isinstance(item, PickleItem): return item.object else: @@ -105,6 +109,7 @@ def item_to_object(item: Item) -> Any: "PandasDataFrameItem", "PandasSeriesItem", "PickleItem", + "PillowImageItem", "PolarsDataFrameItem", "PolarsSeriesItem", "PrimitiveItem", diff --git a/skore/src/skore/persistence/item/cross_validation_reporter_item.py b/skore/src/skore/persistence/item/cross_validation_reporter_item.py index ecad01237..775133e69 100644 --- a/skore/src/skore/persistence/item/cross_validation_reporter_item.py +++ 
b/skore/src/skore/persistence/item/cross_validation_reporter_item.py @@ -14,7 +14,6 @@ import pickle import re import statistics -from functools import cached_property from typing import TYPE_CHECKING, Literal, Optional, TypedDict import numpy @@ -204,14 +203,9 @@ def factory(cls, reporter: CrossValidationReporter) -> CrossValidationReporterIt if not isinstance(reporter, CrossValidationReporter): raise ItemTypeError(f"Type '{reporter.__class__}' is not supported.") - instance = cls(pickle.dumps(reporter)) + return cls(pickle.dumps(reporter)) - # add reporter as cached property - instance.reporter = reporter - - return instance - - @cached_property + @property def reporter(self) -> CrossValidationReporter: """The CrossValidationReporter from the persistence.""" return pickle.loads(self.reporter_bytes) diff --git a/skore/src/skore/persistence/item/media_item.py b/skore/src/skore/persistence/item/media_item.py index e57434743..e2a5cfb0a 100644 --- a/skore/src/skore/persistence/item/media_item.py +++ b/skore/src/skore/persistence/item/media_item.py @@ -1,6 +1,6 @@ """MediaItem. -This module defines the MediaItem class, which represents media items. +This module defines the MediaItem class, used to persist media. """ from __future__ import annotations @@ -15,12 +15,11 @@ if TYPE_CHECKING: from altair.vegalite.v5.schema.core import TopLevelSpec as Altair from matplotlib.figure import Figure as Matplotlib - from PIL.Image import Image as Pillow from plotly.basedatatypes import BaseFigure as Plotly def lazy_is_instance(object: Any, cls_fullname: str) -> bool: - """Return True if object is an instance of a class named `cls_fullname`.""" + """Return True if object is an instance of `cls_fullname`.""" return cls_fullname in { f"{cls.__module__}.{cls.__name__}" for cls in object.__class__.__mro__ } @@ -65,8 +64,8 @@ def __init__( The creation timestamp in ISO format. updated_at : str, optional The last update timestamp in ISO format. 
- note : Union[str, None] - An optional note. + note : str, optional + A note. """ super().__init__(created_at, updated_at, note) @@ -75,12 +74,7 @@ def __init__( self.media_type = media_type def as_serializable_dict(self): - """Get a serializable dict from the item. - - Derived class must call their super implementation - and merge the result with their output. - """ - d = super().as_serializable_dict() + """Return item as a JSONable dict to export to frontend.""" if "text" in self.media_type: value = self.media_bytes.decode(encoding=self.media_encoding) media_type = f"{self.media_type}" @@ -88,13 +82,10 @@ def as_serializable_dict(self): value = base64.b64encode(self.media_bytes).decode() media_type = f"{self.media_type};base64" - d.update( - { - "media_type": media_type, - "value": value, - } - ) - return d + return super().as_serializable_dict() | { + "media_type": media_type, + "value": value, + } @classmethod def factory(cls, media, *args, **kwargs): @@ -127,8 +118,6 @@ def factory(cls, media, *args, **kwargs): return cls.factory_altair(media, *args, **kwargs) if lazy_is_instance(media, "matplotlib.figure.Figure"): return cls.factory_matplotlib(media, *args, **kwargs) - if lazy_is_instance(media, "PIL.Image.Image"): - return cls.factory_pillow(media, *args, **kwargs) if lazy_is_instance(media, "plotly.basedatatypes.BaseFigure"): return cls.factory_plotly(media, *args, **kwargs) @@ -237,31 +226,6 @@ def factory_matplotlib(cls, media: Matplotlib) -> MediaItem: media_type="image/svg+xml", ) - @classmethod - def factory_pillow(cls, media: Pillow) -> MediaItem: - """ - Create a new MediaItem instance from a Pillow image. - - Parameters - ---------- - media : Pillow - The Pillow image to store. - - Returns - ------- - MediaItem - A new MediaItem instance. 
- """ - with BytesIO() as stream: - media.save(stream, format="png") - media_bytes = stream.getvalue() - - return cls( - media_bytes=media_bytes, - media_encoding="utf-8", - media_type="image/png", - ) - @classmethod def factory_plotly(cls, media: Plotly) -> MediaItem: """ diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py index d2678f545..a80fea3de 100644 --- a/skore/src/skore/persistence/item/pickle_item.py +++ b/skore/src/skore/persistence/item/pickle_item.py @@ -6,7 +6,6 @@ from __future__ import annotations -from functools import cached_property from pickle import dumps, loads from typing import Any, Optional @@ -47,7 +46,7 @@ def __init__( self.pickle_bytes = pickle_bytes - @cached_property + @property def object(self) -> Any: """The object from the persistence.""" return loads(self.pickle_bytes) diff --git a/skore/src/skore/persistence/item/pillow_image_item.py b/skore/src/skore/persistence/item/pillow_image_item.py new file mode 100644 index 000000000..c67cde339 --- /dev/null +++ b/skore/src/skore/persistence/item/pillow_image_item.py @@ -0,0 +1,102 @@ +"""PillowImageItem. + +This module defines the PillowImageItem class, used to persist Pillow images. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from .item import Item, ItemTypeError +from .media_item import lazy_is_instance + +if TYPE_CHECKING: + import PIL.Image + + +class PillowImageItem(Item): + """A class used to persist a Pillow image.""" + + def __init__( + self, + image_bytes: bytes, + image_mode: str, + image_size: tuple[int], + created_at: Optional[str] = None, + updated_at: Optional[str] = None, + note: Optional[str] = None, + ): + """ + Initialize a PillowImageItem. + + Parameters + ---------- + image_bytes : bytes + The raw bytes of the Pillow image. + image_mode : str + The image mode. + image_size : tuple[int] + The image size. 
+ created_at : str, optional + The creation timestamp in ISO format. + updated_at : str, optional + The last update timestamp in ISO format. + note : str, optional + A note. + """ + super().__init__(created_at, updated_at, note) + + self.image_bytes = image_bytes + self.image_mode = image_mode + self.image_size = image_size + + @classmethod + def factory(cls, image: PIL.Image.Image) -> PillowImageItem: + """ + Create a new PillowImageItem instance from a Pillow image. + + Parameters + ---------- + image : PIL.Image.Image + The Pillow image to store. + + Returns + ------- + PillowImageItem + A new PillowImageItem instance. + """ + if not lazy_is_instance(image, "PIL.Image.Image"): + raise ItemTypeError(f"Type '{image.__class__}' is not supported.") + + return cls( + image_bytes=image.tobytes(), + image_mode=image.mode, + image_size=image.size, + ) + + @property + def image(self) -> PIL.Image.Image: + """The image from the persistence.""" + import PIL.Image + + return PIL.Image.frombytes( + mode=self.image_mode, + size=self.image_size, + data=self.image_bytes, + ) + + def as_serializable_dict(self): + """Return item as a JSONable dict to export to frontend.""" + import base64 + import io + + with io.BytesIO() as stream: + self.image.save(stream, format="png") + + png_bytes = stream.getvalue() + png_bytes_b64 = base64.b64encode(png_bytes).decode() + + return super().as_serializable_dict() | { + "media_type": "image/png;base64", + "value": png_bytes_b64, + } diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py index c5170a958..b9470a50d 100644 --- a/skore/tests/integration/ui/test_ui.py +++ b/skore/tests/integration/ui/test_ui.py @@ -1,4 +1,6 @@ +import base64 import datetime +import io import json import numpy @@ -20,6 +22,11 @@ def client(in_memory_project): return TestClient(app=create_app(project=in_memory_project)) +@pytest.fixture +def monkeypatch_datetime(monkeypatch, MockDatetime): + 
monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) + + def test_app_state(client): assert client.app.state.project is not None @@ -130,19 +137,42 @@ def test_serialize_sklearn_estimator(client, in_memory_project): assert project["items"]["estimator"][0]["value"] is not None -def test_serialize_media_item(client, in_memory_project): - imarray = numpy.random.rand(100, 100, 3) * 255 - img = Image.fromarray(imarray.astype("uint8")).convert("RGBA") - in_memory_project.put("img", img) +def test_serialize_pillow_item( + client, + in_memory_project, + monkeypatch_datetime, + mock_nowstr, +): + image_array = numpy.random.rand(100, 100, 3) * 255 + image = Image.fromarray(image_array.astype("uint8")).convert("RGBA") + + with io.BytesIO() as stream: + image.save(stream, format="png") - html = "

éપUœALDXIWDŸΩΩ

" - in_memory_project.put("html", html) + png_bytes = stream.getvalue() + png_bytes_b64 = base64.b64encode(png_bytes).decode() + in_memory_project.put("image", image) response = client.get("/api/project/items") + assert response.status_code == 200 - project = response.json() - assert "image" in project["items"]["img"][0]["media_type"] - assert project["items"]["html"][0]["value"] == html + assert response.json() == { + "views": {}, + "items": { + "image": [ + { + "name": "image", + "media_type": "image/png;base64", + "value": png_bytes_b64, + "updated_at": mock_nowstr, + "created_at": mock_nowstr, + } + ] + }, + } + + +def test_serialize_media_item(client, in_memory_project): ... @pytest.fixture diff --git a/skore/tests/unit/item/test_media_item.py b/skore/tests/unit/item/test_media_item.py index 320da8d97..0794b50bc 100644 --- a/skore/tests/unit/item/test_media_item.py +++ b/skore/tests/unit/item/test_media_item.py @@ -1,8 +1,5 @@ -import io - import altair import matplotlib.pyplot -import PIL as pillow import plotly.graph_objects as go import pytest from skore.persistence.item import ItemTypeError, MediaItem @@ -62,21 +59,6 @@ def test_factory_matplotlib(self, mock_nowstr): assert item.created_at == mock_nowstr assert item.updated_at == mock_nowstr - def test_factory_pillow(self, mock_nowstr): - image = pillow.Image.new("RGB", (100, 100), color="red") - - with io.BytesIO() as stream: - image.save(stream, format="png") - image_bytes = stream.getvalue() - - item = MediaItem.factory(image) - - assert item.media_bytes == image_bytes - assert item.media_encoding == "utf-8" - assert item.media_type == "image/png" - assert item.created_at == mock_nowstr - assert item.updated_at == mock_nowstr - def test_factory_plotly(self, mock_nowstr): figure = go.Figure(data=[go.Bar(x=[1, 2, 3], y=[1, 3, 2])]) figure_bytes = figure.to_json().encode("utf-8") diff --git a/skore/tests/unit/item/test_pillow_image_item.py b/skore/tests/unit/item/test_pillow_image_item.py new file mode 
100644 index 000000000..3e43fa94b --- /dev/null +++ b/skore/tests/unit/item/test_pillow_image_item.py @@ -0,0 +1,56 @@ +import base64 +import io + +import PIL.Image +import pytest +from skore.persistence.item import ItemTypeError, PillowImageItem + + +class TestPillowImageItem: + @pytest.fixture(autouse=True) + def monkeypatch_datetime(self, monkeypatch, MockDatetime): + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) + + def test_factory(self, mock_nowstr): + image = PIL.Image.new("RGB", (100, 100), color="red") + item = PillowImageItem.factory(image) + + assert item.image_bytes == image.tobytes() + assert item.image_mode == image.mode + assert item.image_size == image.size + assert item.created_at == mock_nowstr + assert item.updated_at == mock_nowstr + + def test_factory_exception(self): + with pytest.raises(ItemTypeError): + PillowImageItem.factory(None) + + def test_image(self): + image = PIL.Image.new("RGB", (100, 100), color="red") + item1 = PillowImageItem.factory(image) + item2 = PillowImageItem( + image_bytes=image.tobytes(), + image_mode=image.mode, + image_size=image.size, + ) + + assert item1.image == image + assert item2.image == image + + def test_as_serializable_dict(self, mock_nowstr): + image = PIL.Image.new("RGB", (100, 100), color="red") + item = PillowImageItem.factory(image) + + with io.BytesIO() as stream: + image.save(stream, format="png") + + png_bytes = stream.getvalue() + png_bytes_b64 = base64.b64encode(png_bytes).decode() + + assert item.as_serializable_dict() == { + "updated_at": mock_nowstr, + "created_at": mock_nowstr, + "note": None, + "media_type": "image/png;base64", + "value": png_bytes_b64, + } diff --git a/skore/tests/unit/project/test_project.py b/skore/tests/unit/project/test_project.py index cf2f7225e..eb9790169 100644 --- a/skore/tests/unit/project/test_project.py +++ b/skore/tests/unit/project/test_project.py @@ -1,5 +1,3 @@ -from io import BytesIO - import altair import numpy import numpy.testing 
@@ -124,14 +122,14 @@ def test_put_vega_chart(in_memory_project): def test_put_pil_image(in_memory_project): - # Add a PIL Image - pil_image = Image.new("RGB", (100, 100), color="red") - with BytesIO() as output: - # FIXME: Not JPEG! - pil_image.save(output, format="jpeg") - - in_memory_project.put("pil_image", pil_image) # MediaItem (PNG) - assert isinstance(in_memory_project.get("pil_image"), bytes) + image1 = Image.new("RGB", (100, 100), color="red") + image2 = Image.new("RGBA", (150, 150), color="blue") + + in_memory_project.put("image1", image1) + in_memory_project.put("image2", image2) + + assert in_memory_project.get("image1") == image1 + assert in_memory_project.get("image2") == image2 def test_put_rf_model(in_memory_project, monkeypatch): From 4550767b1b8bd367caf2c3d7f456dc34895bdcbd Mon Sep 17 00:00:00 2001 From: Thomas S Date: Thu, 16 Jan 2025 15:59:02 +0100 Subject: [PATCH 21/22] [skip ci] Allow to pass created_at/updated_at/note via factory to constructor in PillowImageItem and PickleItem --- skore/src/skore/persistence/item/pickle_item.py | 4 ++-- skore/src/skore/persistence/item/pillow_image_item.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/skore/src/skore/persistence/item/pickle_item.py b/skore/src/skore/persistence/item/pickle_item.py index a80fea3de..1074f474c 100644 --- a/skore/src/skore/persistence/item/pickle_item.py +++ b/skore/src/skore/persistence/item/pickle_item.py @@ -52,7 +52,7 @@ def object(self) -> Any: return loads(self.pickle_bytes) @classmethod - def factory(cls, object: Any) -> PickleItem: + def factory(cls, object: Any, /, **kwargs) -> PickleItem: """ Create a new PickleItem with any object. @@ -66,7 +66,7 @@ def factory(cls, object: Any) -> PickleItem: PickleItem A new PickleItem instance. 
""" - return cls(dumps(object)) + return cls(dumps(object), **kwargs) def as_serializable_dict(self): """Get a JSON serializable representation of the item.""" diff --git a/skore/src/skore/persistence/item/pillow_image_item.py b/skore/src/skore/persistence/item/pillow_image_item.py index c67cde339..20f53a5f8 100644 --- a/skore/src/skore/persistence/item/pillow_image_item.py +++ b/skore/src/skore/persistence/item/pillow_image_item.py @@ -51,7 +51,7 @@ def __init__( self.image_size = image_size @classmethod - def factory(cls, image: PIL.Image.Image) -> PillowImageItem: + def factory(cls, image: PIL.Image.Image, /, **kwargs) -> PillowImageItem: """ Create a new PillowImageItem instance from a Pillow image. @@ -72,6 +72,7 @@ def factory(cls, image: PIL.Image.Image) -> PillowImageItem: image_bytes=image.tobytes(), image_mode=image.mode, image_size=image.size, + **kwargs, ) @property From 42d9b86169faa132b02047cb2294c5d6f2b33257 Mon Sep 17 00:00:00 2001 From: Thomas S Date: Thu, 16 Jan 2025 17:11:18 +0100 Subject: [PATCH 22/22] [skip ci] Leave plotly from MediaItem to PlotlyFigureItem --- skore/src/skore/persistence/item/__init__.py | 7 ++ .../src/skore/persistence/item/media_item.py | 28 ------- .../persistence/item/plotly_figure_item.py | 84 +++++++++++++++++++ skore/tests/integration/ui/test_ui.py | 32 +++++++ skore/tests/unit/item/test_media_item.py | 13 --- .../unit/item/test_plotly_figure_item.py | 51 +++++++++++ skore/tests/unit/project/test_project.py | 12 ++- 7 files changed, 185 insertions(+), 42 deletions(-) create mode 100644 skore/src/skore/persistence/item/plotly_figure_item.py create mode 100644 skore/tests/unit/item/test_plotly_figure_item.py diff --git a/skore/src/skore/persistence/item/__init__.py b/skore/src/skore/persistence/item/__init__.py index de24446be..3ed8abcd5 100644 --- a/skore/src/skore/persistence/item/__init__.py +++ b/skore/src/skore/persistence/item/__init__.py @@ -14,6 +14,7 @@ from .pandas_series_item import PandasSeriesItem from 
.pickle_item import PickleItem from .pillow_image_item import PillowImageItem +from .plotly_figure_item import PlotlyFigureItem from .polars_dataframe_item import PolarsDataFrameItem from .polars_series_item import PolarsSeriesItem from .primitive_item import PrimitiveItem @@ -52,6 +53,7 @@ def object_to_item( SkrubTableReportItem, CrossValidationReporterItem, PillowImageItem, + PlotlyFigureItem, ): with suppress(ImportError, ItemTypeError): # ImportError: @@ -74,6 +76,8 @@ def object_to_item( # `note` attribute dynamically. item.note = note + # -> to change in each class + return item @@ -95,6 +99,8 @@ def item_to_object(item: Item) -> Any: return item.media_bytes elif isinstance(item, PillowImageItem): return item.image + elif isinstance(item, PlotlyFigureItem): + return item.figure elif isinstance(item, PickleItem): return item.object else: @@ -110,6 +116,7 @@ def item_to_object(item: Item) -> Any: "PandasSeriesItem", "PickleItem", "PillowImageItem", + "PlotlyFigureItem", "PolarsDataFrameItem", "PolarsSeriesItem", "PrimitiveItem", diff --git a/skore/src/skore/persistence/item/media_item.py b/skore/src/skore/persistence/item/media_item.py index e2a5cfb0a..0ebb8858f 100644 --- a/skore/src/skore/persistence/item/media_item.py +++ b/skore/src/skore/persistence/item/media_item.py @@ -15,7 +15,6 @@ if TYPE_CHECKING: from altair.vegalite.v5.schema.core import TopLevelSpec as Altair from matplotlib.figure import Figure as Matplotlib - from plotly.basedatatypes import BaseFigure as Plotly def lazy_is_instance(object: Any, cls_fullname: str) -> bool: @@ -118,8 +117,6 @@ def factory(cls, media, *args, **kwargs): return cls.factory_altair(media, *args, **kwargs) if lazy_is_instance(media, "matplotlib.figure.Figure"): return cls.factory_matplotlib(media, *args, **kwargs) - if lazy_is_instance(media, "plotly.basedatatypes.BaseFigure"): - return cls.factory_plotly(media, *args, **kwargs) raise ItemTypeError(f"Type '{media.__class__}' is not supported.") @@ -225,28 +222,3 @@ 
def factory_matplotlib(cls, media: Matplotlib) -> MediaItem: media_encoding="utf-8", media_type="image/svg+xml", ) - - @classmethod - def factory_plotly(cls, media: Plotly) -> MediaItem: - """ - Create a new MediaItem instance from a Plotly figure. - - Parameters - ---------- - media : Plotly - The Plotly figure to store. - - Returns - ------- - MediaItem - A new MediaItem instance. - """ - import plotly.io - - media_bytes = plotly.io.to_json(media, engine="json").encode("utf-8") - - return cls( - media_bytes=media_bytes, - media_encoding="utf-8", - media_type="application/vnd.plotly.v1+json", - ) diff --git a/skore/src/skore/persistence/item/plotly_figure_item.py b/skore/src/skore/persistence/item/plotly_figure_item.py new file mode 100644 index 000000000..a705ca48a --- /dev/null +++ b/skore/src/skore/persistence/item/plotly_figure_item.py @@ -0,0 +1,84 @@ +"""PlotlyFigureItem. + +This module defines the PlotlyFigureItem class, used to persist Plotly figures. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from .item import Item, ItemTypeError +from .media_item import lazy_is_instance + +if TYPE_CHECKING: + import plotly.basedatatypes as plotly + + +class PlotlyFigureItem(Item): + """A class used to persist a Plotly figure.""" + + def __init__( + self, + figure_str: str, + created_at: Optional[str] = None, + updated_at: Optional[str] = None, + note: Optional[str] = None, + ): + """ + Initialize a PlotlyFigureItem. + + Parameters + ---------- + figure_str : str + The JSON str of the Plotly figure. + created_at : str, optional + The creation timestamp in ISO format. + updated_at : str, optional + The last update timestamp in ISO format. + note : str, optional + A note. + """ + super().__init__(created_at, updated_at, note) + + self.figure_str = figure_str + + @classmethod + def factory(cls, figure: plotly.BaseFigure, /, **kwargs) -> PlotlyFigureItem: + """ + Create a new PlotlyFigureItem instance from a Plotly figure.
+ + Parameters + ---------- + figure : plotly.basedatatypes.BaseFigure + The Plotly figure to store. + + Returns + ------- + PlotlyFigureItem + A new PlotlyFigureItem instance. + """ + if not lazy_is_instance(figure, "plotly.basedatatypes.BaseFigure"): + raise ItemTypeError(f"Type '{figure.__class__}' is not supported.") + + import plotly.io + + return cls(plotly.io.to_json(figure, engine="json"), **kwargs) + + @property + def figure(self) -> plotly.BaseFigure: + """The figure from the persistence.""" + import plotly.io + + return plotly.io.from_json(self.figure_str) + + def as_serializable_dict(self): + """Return item as a JSONable dict to export to frontend.""" + import base64 + + figure_bytes = self.figure_str.encode("utf-8") + figure_bytes_b64 = base64.b64encode(figure_bytes).decode() + + return super().as_serializable_dict() | { + "media_type": "application/vnd.plotly.v1+json;base64", + "value": figure_bytes_b64, + } diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py index b9470a50d..d0a02d57d 100644 --- a/skore/tests/integration/ui/test_ui.py +++ b/skore/tests/integration/ui/test_ui.py @@ -172,6 +172,38 @@ def test_serialize_pillow_item( } +def test_serialize_plotly_item( + client, + in_memory_project, + monkeypatch_datetime, + mock_nowstr, +): + bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2]) + figure = plotly.graph_objects.Figure(data=[bar]) + figure_str = plotly.io.to_json(figure, engine="json") + figure_bytes = figure_str.encode("utf-8") + figure_bytes_b64 = base64.b64encode(figure_bytes).decode() + + in_memory_project.put("figure", figure) + response = client.get("/api/project/items") + + assert response.status_code == 200 + assert response.json() == { + "views": {}, + "items": { + "figure": [ + { + "name": "figure", + "media_type": "application/vnd.plotly.v1+json;base64", + "value": figure_bytes_b64, + "updated_at": mock_nowstr, + "created_at": mock_nowstr, + } + ] + }, + } + + def 
test_serialize_media_item(client, in_memory_project): ... diff --git a/skore/tests/unit/item/test_media_item.py b/skore/tests/unit/item/test_media_item.py index 0794b50bc..0db89c082 100644 --- a/skore/tests/unit/item/test_media_item.py +++ b/skore/tests/unit/item/test_media_item.py @@ -1,6 +1,5 @@ import altair import matplotlib.pyplot -import plotly.graph_objects as go import pytest from skore.persistence.item import ItemTypeError, MediaItem @@ -59,18 +58,6 @@ def test_factory_matplotlib(self, mock_nowstr): assert item.created_at == mock_nowstr assert item.updated_at == mock_nowstr - def test_factory_plotly(self, mock_nowstr): - figure = go.Figure(data=[go.Bar(x=[1, 2, 3], y=[1, 3, 2])]) - figure_bytes = figure.to_json().encode("utf-8") - - item = MediaItem.factory(figure) - - assert item.media_bytes == figure_bytes - assert item.media_encoding == "utf-8" - assert item.media_type == "application/vnd.plotly.v1+json" - assert item.created_at == mock_nowstr - assert item.updated_at == mock_nowstr - def test_get_serializable_dict(self, mock_nowstr): item = MediaItem.factory("") diff --git a/skore/tests/unit/item/test_plotly_figure_item.py b/skore/tests/unit/item/test_plotly_figure_item.py new file mode 100644 index 000000000..9d1f8f18e --- /dev/null +++ b/skore/tests/unit/item/test_plotly_figure_item.py @@ -0,0 +1,51 @@ +import base64 + +import plotly.graph_objects +import plotly.io +import pytest +from skore.persistence.item import ItemTypeError, PlotlyFigureItem + + +class TestPlotlyFigureItem: + @pytest.fixture(autouse=True) + def monkeypatch_datetime(self, monkeypatch, MockDatetime): + monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) + + def test_factory(self, mock_nowstr): + bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2]) + figure = plotly.graph_objects.Figure(data=[bar]) + item = PlotlyFigureItem.factory(figure) + + assert item.figure_str == plotly.io.to_json(figure, engine="json") + assert item.created_at == mock_nowstr + 
assert item.updated_at == mock_nowstr + + def test_factory_exception(self): + with pytest.raises(ItemTypeError): + PlotlyFigureItem.factory(None) + + def test_figure(self): + bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2]) + figure = plotly.graph_objects.Figure(data=[bar]) + item1 = PlotlyFigureItem.factory(figure) + item2 = PlotlyFigureItem(plotly.io.to_json(figure, engine="json")) + + assert item1.figure == figure + assert item2.figure == figure + + def test_as_serializable_dict(self, mock_nowstr): + bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2]) + figure = plotly.graph_objects.Figure(data=[bar]) + figure_str = plotly.io.to_json(figure, engine="json") + figure_bytes = figure_str.encode("utf-8") + figure_bytes_b64 = base64.b64encode(figure_bytes).decode() + + item = PlotlyFigureItem.factory(figure) + + assert item.as_serializable_dict() == { + "updated_at": mock_nowstr, + "created_at": mock_nowstr, + "note": None, + "media_type": "application/vnd.plotly.v1+json;base64", + "value": figure_bytes_b64, + } diff --git a/skore/tests/unit/project/test_project.py b/skore/tests/unit/project/test_project.py index eb9790169..c37a8fdf0 100644 --- a/skore/tests/unit/project/test_project.py +++ b/skore/tests/unit/project/test_project.py @@ -3,6 +3,7 @@ import numpy.testing import pandas import pandas.testing +import plotly import polars import polars.testing import pytest @@ -121,7 +122,7 @@ def test_put_vega_chart(in_memory_project): assert isinstance(in_memory_project.get("vega_chart"), bytes) -def test_put_pil_image(in_memory_project): +def test_put_pillow_image(in_memory_project): image1 = Image.new("RGB", (100, 100), color="red") image2 = Image.new("RGBA", (150, 150), color="blue") @@ -132,6 +133,15 @@ def test_put_pil_image(in_memory_project): assert in_memory_project.get("image2") == image2 +def test_put_plotly_figure(in_memory_project): + bar = plotly.graph_objects.Bar(x=[1, 2, 3], y=[1, 3, 2]) + figure = plotly.graph_objects.Figure(data=[bar]) + + 
in_memory_project.put("figure", figure) + + assert in_memory_project.get("figure") == figure + + def test_put_rf_model(in_memory_project, monkeypatch): # Add a scikit-learn model monkeypatch.setattr("sklearn.utils.estimator_html_repr", lambda _: "")