Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

81 postpredictionwrapper handle set params without wrapped estimator #82

29 changes: 12 additions & 17 deletions molpipeline/post_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import Self

from numpy import typing as npt
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.base import BaseEstimator, TransformerMixin

from molpipeline.abstract_pipeline_elements.core import ABCPipelineElement
from molpipeline.error_handling import FilterReinserter
Expand Down Expand Up @@ -194,15 +194,10 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
dict[str, Any]
Parameters.
"""
param_dict = {"wrapped_estimator": self.wrapped_estimator}
if deep:
param_dict = {
"wrapped_estimator": clone(self.wrapped_estimator),
}
else:
param_dict = {
"wrapped_estimator": self.wrapped_estimator,
}
param_dict.update(self.wrapped_estimator.get_params(deep=deep))
for key, value in self.wrapped_estimator.get_params(deep=deep).items():
param_dict[f"wrapped_estimator__{key}"] = value
return param_dict

def set_params(self, **params: Any) -> Self:
Expand All @@ -219,12 +214,12 @@ def set_params(self, **params: Any) -> Self:
Parameters.
"""
param_copy = dict(params)
wrapped_estimator = param_copy.pop("wrapped_estimator")
if wrapped_estimator:
self.wrapped_estimator = wrapped_estimator
if param_copy:
if isinstance(self.wrapped_estimator, ABCPipelineElement):
self.wrapped_estimator.set_params(**param_copy)
else:
self.wrapped_estimator.set_params(**param_copy)
if "wrapped_estimator" in param_copy:
self.wrapped_estimator = param_copy.pop("wrapped_estimator")
wrapped_estimator_params = {}
for key, value in param_copy.items():
estimator, _, param = key.partition("__")
if estimator == "wrapped_estimator":
wrapped_estimator_params[param] = value
self.wrapped_estimator.set_params(**wrapped_estimator_params)
return self
82 changes: 82 additions & 0 deletions tests/test_elements/test_post_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Test the module post_prediction.py."""

import unittest

import numpy as np
from sklearn.base import clone
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier

from molpipeline.post_prediction import PostPredictionWrapper


class TestPostPredictionWrapper(unittest.TestCase):
"""Test the PostPredictionWrapper class."""

def test_get_params(self) -> None:
"""Test get_params method."""
rf = RandomForestClassifier()
rf_params = rf.get_params(deep=True)

ppw = PostPredictionWrapper(rf)
ppw_params = ppw.get_params(deep=True)

wrapped_params = {}
for key, value in ppw_params.items():
first, _, rest = key.partition("__")
if first == "wrapped_estimator":
if rest == "":
self.assertEqual(rf, value)
c-w-feldmann marked this conversation as resolved.
Show resolved Hide resolved
else:
wrapped_params[rest] = value

self.assertDictEqual(rf_params, wrapped_params)

def test_set_params(self) -> None:
"""Test set_params method."""
rf = RandomForestClassifier()
ppw = PostPredictionWrapper(rf)

ppw.set_params(wrapped_estimator__n_estimators=10)
if not isinstance(ppw.wrapped_estimator, RandomForestClassifier):
raise TypeError("Wrapped estimator is not a RandomForestClassifier.")
c-w-feldmann marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(ppw.wrapped_estimator.n_estimators, 10)

ppw_params = ppw.get_params(deep=True)
self.assertEqual(ppw_params["wrapped_estimator__n_estimators"], 10)

def test_fit_transform(self) -> None:
"""Test fit method."""
rng = np.random.default_rng(20240918)
features = rng.random((100, 10))
c-w-feldmann marked this conversation as resolved.
Show resolved Hide resolved

pca = PCA(n_components=3)
pca.fit(features)
pca_transformed = pca.transform(features)

ppw = PostPredictionWrapper(clone(pca))
ppw.fit(features)
ppw_transformed = ppw.transform(features)

self.assertEqual(pca_transformed.shape, ppw_transformed.shape)
self.assertTrue(np.allclose(pca_transformed, ppw_transformed))

def test_inverse_transform(self) -> None:
"""Test inverse_transform method."""
rng = np.random.default_rng(20240918)
features = rng.random((5, 10))

pca = PCA(n_components=3)
pca.fit(features)
pca_transformed = pca.transform(features)
pca_inverse = pca.inverse_transform(pca_transformed)

ppw = PostPredictionWrapper(clone(pca))
ppw.fit(features)
ppw_transformed = ppw.transform(features)
ppw_inverse = ppw.inverse_transform(ppw_transformed)

self.assertEqual(features.shape, ppw_inverse.shape)
self.assertEqual(pca_inverse.shape, ppw_inverse.shape)

self.assertTrue(np.allclose(pca_inverse, ppw_inverse))
Loading