Skip to content

Commit

Permalink
PostpredictionWrapper: update get and set params (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann authored Sep 27, 2024
1 parent 22efd0d commit d67807f
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 17 deletions.
29 changes: 12 additions & 17 deletions molpipeline/post_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import Self

from numpy import typing as npt
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.base import BaseEstimator, TransformerMixin

from molpipeline.abstract_pipeline_elements.core import ABCPipelineElement
from molpipeline.error_handling import FilterReinserter
Expand Down Expand Up @@ -194,15 +194,10 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
dict[str, Any]
Parameters.
"""
param_dict = {"wrapped_estimator": self.wrapped_estimator}
if deep:
param_dict = {
"wrapped_estimator": clone(self.wrapped_estimator),
}
else:
param_dict = {
"wrapped_estimator": self.wrapped_estimator,
}
param_dict.update(self.wrapped_estimator.get_params(deep=deep))
for key, value in self.wrapped_estimator.get_params(deep=deep).items():
param_dict[f"wrapped_estimator__{key}"] = value
return param_dict

def set_params(self, **params: Any) -> Self:
Expand All @@ -219,12 +214,12 @@ def set_params(self, **params: Any) -> Self:
Parameters.
"""
param_copy = dict(params)
wrapped_estimator = param_copy.pop("wrapped_estimator")
if wrapped_estimator:
self.wrapped_estimator = wrapped_estimator
if param_copy:
if isinstance(self.wrapped_estimator, ABCPipelineElement):
self.wrapped_estimator.set_params(**param_copy)
else:
self.wrapped_estimator.set_params(**param_copy)
if "wrapped_estimator" in param_copy:
self.wrapped_estimator = param_copy.pop("wrapped_estimator")
wrapped_estimator_params = {}
for key, value in param_copy.items():
estimator, _, param = key.partition("__")
if estimator == "wrapped_estimator":
wrapped_estimator_params[param] = value
self.wrapped_estimator.set_params(**wrapped_estimator_params)
return self
83 changes: 83 additions & 0 deletions tests/test_elements/test_post_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Test the module post_prediction.py."""

import unittest

import numpy as np
from sklearn.base import clone
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier

from molpipeline.post_prediction import PostPredictionWrapper


class TestPostPredictionWrapper(unittest.TestCase):
"""Test the PostPredictionWrapper class."""

def test_get_params(self) -> None:
"""Test get_params method."""
rf = RandomForestClassifier()
rf_params = rf.get_params(deep=True)

ppw = PostPredictionWrapper(rf)
ppw_params = ppw.get_params(deep=True)

wrapped_params = {}
for key, value in ppw_params.items():
first, _, rest = key.partition("__")
if first == "wrapped_estimator":
if rest == "":
self.assertIs(rf, value)
else:
wrapped_params[rest] = value

self.assertDictEqual(rf_params, wrapped_params)

def test_set_params(self) -> None:
"""Test set_params method."""
rf = RandomForestClassifier()
ppw = PostPredictionWrapper(rf)

ppw.set_params(wrapped_estimator__n_estimators=10)
self.assertIsInstance(ppw.wrapped_estimator, RandomForestClassifier)
if not isinstance(ppw.wrapped_estimator, RandomForestClassifier):
raise TypeError("Wrapped estimator is not a RandomForestClassifier.")
self.assertEqual(ppw.wrapped_estimator.n_estimators, 10)

ppw_params = ppw.get_params(deep=True)
self.assertEqual(ppw_params["wrapped_estimator__n_estimators"], 10)

def test_fit_transform(self) -> None:
"""Test fit method."""
rng = np.random.default_rng(20240918)
features = rng.random((10, 5))

pca = PCA(n_components=3)
pca.fit(features)
pca_transformed = pca.transform(features)

ppw = PostPredictionWrapper(clone(pca))
ppw.fit(features)
ppw_transformed = ppw.transform(features)

self.assertEqual(pca_transformed.shape, ppw_transformed.shape)
self.assertTrue(np.allclose(pca_transformed, ppw_transformed))

def test_inverse_transform(self) -> None:
"""Test inverse_transform method."""
rng = np.random.default_rng(20240918)
features = rng.random((10, 5))

pca = PCA(n_components=3)
pca.fit(features)
pca_transformed = pca.transform(features)
pca_inverse = pca.inverse_transform(pca_transformed)

ppw = PostPredictionWrapper(clone(pca))
ppw.fit(features)
ppw_transformed = ppw.transform(features)
ppw_inverse = ppw.inverse_transform(ppw_transformed)

self.assertEqual(features.shape, ppw_inverse.shape)
self.assertEqual(pca_inverse.shape, ppw_inverse.shape)

self.assertTrue(np.allclose(pca_inverse, ppw_inverse))

0 comments on commit d67807f

Please sign in to comment.