Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

81 postpredictionwrapper handle set params without wrapped estimator #82

29 changes: 12 additions & 17 deletions molpipeline/post_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import Self

from numpy import typing as npt
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.base import BaseEstimator, TransformerMixin

from molpipeline.abstract_pipeline_elements.core import ABCPipelineElement
from molpipeline.error_handling import FilterReinserter
Expand Down Expand Up @@ -194,15 +194,10 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
dict[str, Any]
Parameters.
"""
param_dict = {"wrapped_estimator": self.wrapped_estimator}
if deep:
param_dict = {
"wrapped_estimator": clone(self.wrapped_estimator),
}
else:
param_dict = {
"wrapped_estimator": self.wrapped_estimator,
}
param_dict.update(self.wrapped_estimator.get_params(deep=deep))
for key, value in self.wrapped_estimator.get_params(deep=deep).items():
param_dict[f"wrapped_estimator__{key}"] = value
return param_dict

def set_params(self, **params: Any) -> Self:
Expand All @@ -219,12 +214,12 @@ def set_params(self, **params: Any) -> Self:
Parameters.
"""
param_copy = dict(params)
wrapped_estimator = param_copy.pop("wrapped_estimator")
if wrapped_estimator:
self.wrapped_estimator = wrapped_estimator
if param_copy:
if isinstance(self.wrapped_estimator, ABCPipelineElement):
self.wrapped_estimator.set_params(**param_copy)
else:
self.wrapped_estimator.set_params(**param_copy)
if "wrapped_estimator" in param_copy:
self.wrapped_estimator = param_copy.pop("wrapped_estimator")
wrapped_estimator_params = {}
for key, value in param_copy.items():
estimator, _, param = key.partition("__")
if estimator == "wrapped_estimator":
wrapped_estimator_params[param] = value
self.wrapped_estimator.set_params(**wrapped_estimator_params)
return self
Loading