From e18185f71a1fb2299f03caf0a88a020d1185103b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?In=C3=AAs=20Silva?= Date: Thu, 1 Feb 2024 14:27:38 +0000 Subject: [PATCH] Replaced classifier instanciation with util function --- .../flow/methods/preprocessing/massaging.py | 32 +++---------------- 1 file changed, 5 insertions(+), 27 deletions(-) diff --git a/src/aequitas/flow/methods/preprocessing/massaging.py b/src/aequitas/flow/methods/preprocessing/massaging.py index 78785670..a5c09525 100644 --- a/src/aequitas/flow/methods/preprocessing/massaging.py +++ b/src/aequitas/flow/methods/preprocessing/massaging.py @@ -2,20 +2,17 @@ import pandas as pd import math -import inspect from ...utils import create_logger -from ...utils.imports import import_object +from ...utils.imports import instantiate_object from .preprocessing import PreProcessing class Massaging(PreProcessing): def __init__( self, - classifier: Union[ - str, Callable - ] = "sklearn.naive_bayes.GaussianNB", - **classifier_args + classifier: Union[str, Callable] = "sklearn.naive_bayes.GaussianNB", + **classifier_args, ): """ Instantiates a Massaging preprocessing method. @@ -25,27 +22,8 @@ def __init__( self.logger = create_logger("methods.preprocessing.Massaging") self.logger.info("Instantiating a Massaging preprocessing method.") - if isinstance(classifier, str): - classifier = import_object(classifier) - signature = inspect.signature(classifier) - if ( - signature.parameters[list(signature.parameters.keys())[-1]].kind - == inspect.Parameter.VAR_KEYWORD - ): - args = ( - classifier_args # Estimator takes **kwargs, so all args are valid - ) - else: - args = { - arg: value - for arg, value in classifier_args.items() - if arg in signature.parameters - } - self.classifier = classifier(**args) - self.logger.info( - f"Created base estimator {self.classifier} with params {args}, " - F"discarded args:{list(set(classifier_args.keys()) - set(args.keys()))}" - ) + self.classifier = instantiate_object(classifier, **classifier_args) + self.logger.info(f"Created base estimator {self.classifier}") def _rank( self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]