Record the skops untrusted types on SklearnBaseEstimatorItem.

The factory now calls `skops.io.get_untrusted_types` on the dumped estimator
and stores the result as `estimator_skops_untrusted_types`; the `estimator`
property passes that list as `trusted=` to `skops.io.loads`. The constructor
parameters are reordered (html repr first) and annotated. Tests are updated
and a new test covers an estimator subclass defined in the test module.

NOTE(review): the list handed to `trusted=` is taken from the payload itself,
so every type present in the stored bytes is accepted at load time. Confirm
this is acceptable for wherever items are persisted and read back from.

diff --git a/src/skore/item/sklearn_base_estimator_item.py b/src/skore/item/sklearn_base_estimator_item.py
index b81959b08..6100fb456 100644
--- a/src/skore/item/sklearn_base_estimator_item.py
+++ b/src/skore/item/sklearn_base_estimator_item.py
@@ -25,8 +25,9 @@ class SklearnBaseEstimatorItem(Item):
 
     def __init__(
         self,
-        estimator_skops,
-        estimator_html_repr,
+        estimator_html_repr: str,
+        estimator_skops: bytes,
+        estimator_skops_untrusted_types: list[str],
         created_at: str | None = None,
         updated_at: str | None = None,
     ):
@@ -35,10 +36,12 @@ def __init__(
 
         Parameters
         ----------
-        estimator_skops : Any
-            The skops representation of the scikit-learn estimator.
         estimator_html_repr : str
             The HTML representation of the scikit-learn estimator.
+        estimator_skops : bytes
+            The skops representation of the scikit-learn estimator.
+        estimator_skops_untrusted_types : list[str]
+            The list of untrusted types in the skops representation.
         created_at : str, optional
             The creation timestamp in ISO format.
         updated_at : str, optional
@@ -46,8 +49,9 @@ def __init__(
         """
         super().__init__(created_at, updated_at)
 
-        self.estimator_skops = estimator_skops
         self.estimator_html_repr = estimator_html_repr
+        self.estimator_skops = estimator_skops
+        self.estimator_skops_untrusted_types = estimator_skops_untrusted_types
 
     @cached_property
     def estimator(self) -> sklearn.base.BaseEstimator:
@@ -61,7 +65,9 @@ def estimator(self) -> sklearn.base.BaseEstimator:
         """
         import skops.io
 
-        return skops.io.loads(self.estimator_skops)
+        return skops.io.loads(
+            self.estimator_skops, trusted=self.estimator_skops_untrusted_types
+        )
 
     @classmethod
     def factory(cls, estimator: sklearn.base.BaseEstimator) -> SklearnBaseEstimatorItem:
@@ -85,9 +91,16 @@ def factory(cls, estimator: sklearn.base.BaseEstimator) -> SklearnBaseEstimatorI
         if not isinstance(estimator, sklearn.base.BaseEstimator):
             raise TypeError(f"Type '{estimator.__class__}' is not supported.")
 
+        estimator_html_repr = sklearn.utils.estimator_html_repr(estimator)
+        estimator_skops = skops.io.dumps(estimator)
+        estimator_skops_untrusted_types = skops.io.get_untrusted_types(
+            data=estimator_skops
+        )
+
         instance = cls(
-            estimator_skops=skops.io.dumps(estimator),
-            estimator_html_repr=sklearn.utils.estimator_html_repr(estimator),
+            estimator_html_repr=estimator_html_repr,
+            estimator_skops=estimator_skops,
+            estimator_skops_untrusted_types=estimator_skops_untrusted_types,
         )
 
         # add estimator as cached property
diff --git a/tests/unit/item/test_sklearn_base_estimator_item.py b/tests/unit/item/test_sklearn_base_estimator_item.py
index 8003430bd..1eca76231 100644
--- a/tests/unit/item/test_sklearn_base_estimator_item.py
+++ b/tests/unit/item/test_sklearn_base_estimator_item.py
@@ -4,6 +4,10 @@
 from skore.item import SklearnBaseEstimatorItem
 
 
+class Estimator(sklearn.svm.SVC):
+    pass
+
+
 class TestSklearnBaseEstimatorItem:
     @pytest.fixture(autouse=True)
     def monkeypatch_datetime(self, monkeypatch, MockDatetime):
@@ -11,19 +15,26 @@ def monkeypatch_datetime(self, monkeypatch, MockDatetime):
 
     @pytest.mark.order(0)
     def test_factory(self, monkeypatch, mock_nowstr):
-        monkeypatch.setattr("skops.io.dumps", lambda _: "")
-        monkeypatch.setattr(
-            "sklearn.utils.estimator_html_repr", lambda _: ""
-        )
-
         estimator = sklearn.svm.SVC()
-        estimator_skops = ""
         estimator_html_repr = ""
+        estimator_skops = ""
+        estimator_skops_untrusted_types = ""
+
+        monkeypatch.setattr(
+            "sklearn.utils.estimator_html_repr",
+            lambda *args, **kwargs: estimator_html_repr,
+        )
+        monkeypatch.setattr("skops.io.dumps", lambda *args, **kwargs: estimator_skops)
+        monkeypatch.setattr(
+            "skops.io.get_untrusted_types",
+            lambda *args, **kwargs: estimator_skops_untrusted_types,
+        )
 
         item = SklearnBaseEstimatorItem.factory(estimator)
 
-        assert item.estimator_skops == estimator_skops
         assert item.estimator_html_repr == estimator_html_repr
+        assert item.estimator_skops == estimator_skops
+        assert item.estimator_skops_untrusted_types == estimator_skops_untrusted_types
         assert item.created_at == mock_nowstr
         assert item.updated_at == mock_nowstr
 
@@ -31,15 +42,47 @@ def test_factory(self, monkeypatch, mock_nowstr):
     def test_estimator(self, mock_nowstr):
         estimator = sklearn.svm.SVC()
         estimator_skops = skops.io.dumps(estimator)
-        estimator_html_repr = ""
+        estimator_skops_untrusted_types = skops.io.get_untrusted_types(
+            data=estimator_skops
+        )
 
         item1 = SklearnBaseEstimatorItem.factory(estimator)
         item2 = SklearnBaseEstimatorItem(
+            estimator_html_repr=None,
             estimator_skops=estimator_skops,
-            estimator_html_repr=estimator_html_repr,
+            estimator_skops_untrusted_types=estimator_skops_untrusted_types,
             created_at=mock_nowstr,
             updated_at=mock_nowstr,
         )
 
         assert isinstance(item1.estimator, sklearn.svm.SVC)
         assert isinstance(item2.estimator, sklearn.svm.SVC)
+
+    @pytest.mark.order(1)
+    def test_estimator_untrusted(self, mock_nowstr):
+        estimator = Estimator()
+        estimator_skops = skops.io.dumps(estimator)
+        estimator_skops_untrusted_types = skops.io.get_untrusted_types(
+            data=estimator_skops
+        )
+
+        if not estimator_skops_untrusted_types:
+            pytest.skip(
+                """
+                This test is only intended to exhaustively test an untrusted estimator.
+                The untrusted Estimator class seems to be trusted by default.
+                Something changed in `skops`.
+                """
+            )
+
+        item1 = SklearnBaseEstimatorItem.factory(estimator)
+        item2 = SklearnBaseEstimatorItem(
+            estimator_html_repr=None,
+            estimator_skops=estimator_skops,
+            estimator_skops_untrusted_types=estimator_skops_untrusted_types,
+            created_at=mock_nowstr,
+            updated_at=mock_nowstr,
+        )
+
+        assert isinstance(item1.estimator, Estimator)
+        assert isinstance(item2.estimator, Estimator)