Skip to content

Commit

Permalink
fixed set and get_params functinoality
Browse files Browse the repository at this point in the history
  • Loading branch information
AnFreTh committed Jul 25, 2024
1 parent b16af74 commit bd941d2
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 62 deletions.
69 changes: 50 additions & 19 deletions mambular/models/sklearn_base_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..preprocessing import Preprocessor
import numpy as np
from lightning.pytorch.callbacks import ModelSummary
from sklearn.metrics import log_loss


class SklearnBaseClassifier(BaseEstimator):
Expand Down Expand Up @@ -49,23 +50,22 @@ def __init__(self, model, config, **kwargs):

def get_params(self, deep=True):
"""
Get parameters for this estimator. Overrides the BaseEstimator method.
Get parameters for this estimator.
Parameters
----------
deep : bool, default=True
If True, returns the parameters for this estimator and contained sub-objects that are estimators.
If True, will return the parameters for this estimator and contained subobjects that are estimators.
Returns
-------
params : dict
Parameter names mapped to their values.
"""
params = self.config_kwargs # Parameters used to initialize DefaultConfig
params = {}
params.update(self.config_kwargs)

# If deep=True, include parameters from nested components like preprocessor
if deep:
# Assuming Preprocessor has a get_params method
preprocessor_params = {
"preprocessor__" + key: value
for key, value in self.preprocessor.get_params().items()
Expand All @@ -76,35 +76,36 @@ def get_params(self, deep=True):

def set_params(self, **parameters):
"""
Set the parameters of this estimator. Overrides the BaseEstimator method.
Set the parameters of this estimator.
Parameters
----------
**parameters : dict
Estimator parameters to be set.
Estimator parameters.
Returns
-------
self : object
The instance with updated parameters.
Estimator instance.
"""
# Update config_kwargs with provided parameters
valid_config_keys = self.config_kwargs.keys()
config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys}
self.config_kwargs.update(config_updates)

# Update the config object
for key, value in config_updates.items():
setattr(self.config, key, value)

# Handle preprocessor parameters (prefixed with 'preprocessor__')
config_params = {
k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
}
preprocessor_params = {
k.split("__")[1]: v
for k, v in parameters.items()
if k.startswith("preprocessor__")
}

if config_params:
self.config_kwargs.update(config_params)
if self.config is not None:
for key, value in config_params.items():
setattr(self.config, key, value)
else:
self.config = self.config_class(**self.config_kwargs)

if preprocessor_params:
# Assuming Preprocessor has a set_params method
self.preprocessor.set_params(**preprocessor_params)

return self
Expand Down Expand Up @@ -559,3 +560,33 @@ def evaluate(self, X, y_true, metrics=None):
scores[metric_name] = metric_func(y_true, predictions)

return scores

def score(self, X, y, metric=(log_loss, True)):
"""
Calculate the score of the model using the specified metric.
Parameters
----------
X : array-like or pd.DataFrame of shape (n_samples, n_features)
The input samples to predict.
y : array-like of shape (n_samples,)
The true class labels against which to evaluate the predictions.
metric : tuple, default=(log_loss, True)
A tuple containing the metric function and a boolean indicating whether the metric requires probability scores (True) or class labels (False).
Returns
-------
score : float
The score calculated using the specified metric.
"""
metric_func, use_proba = metric

if not isinstance(X, pd.DataFrame):
X = pd.DataFrame(X)

if use_proba:
probabilities = self.predict_proba(X)
return metric_func(y, probabilities)
else:
predictions = self.predict(X)
return metric_func(y, predictions)
38 changes: 19 additions & 19 deletions mambular/models/sklearn_base_lss.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,22 @@ def __init__(self, model, config, **kwargs):

def get_params(self, deep=True):
"""
Get parameters for this estimator. Overrides the BaseEstimator method.
Get parameters for this estimator.
Parameters
----------
deep : bool, default=True
If True, returns the parameters for this estimator and contained sub-objects that are estimators.
If True, will return the parameters for this estimator and contained subobjects that are estimators.
Returns
-------
params : dict
Parameter names mapped to their values.
"""
params = self.config_kwargs # Parameters used to initialize DefaultConfig
params = {}
params.update(self.config_kwargs)

# If deep=True, include parameters from nested components like preprocessor
if deep:
# Assuming Preprocessor has a get_params method
preprocessor_params = {
"preprocessor__" + key: value
for key, value in self.preprocessor.get_params().items()
Expand All @@ -98,35 +97,36 @@ def get_params(self, deep=True):

def set_params(self, **parameters):
"""
Set the parameters of this estimator. Overrides the BaseEstimator method.
Set the parameters of this estimator.
Parameters
----------
**parameters : dict
Estimator parameters to be set.
Estimator parameters.
Returns
-------
self : object
The instance with updated parameters.
Estimator instance.
"""
# Update config_kwargs with provided parameters
valid_config_keys = self.config_kwargs.keys()
config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys}
self.config_kwargs.update(config_updates)

# Update the config object
for key, value in config_updates.items():
setattr(self.config, key, value)

# Handle preprocessor parameters (prefixed with 'preprocessor__')
config_params = {
k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
}
preprocessor_params = {
k.split("__")[1]: v
for k, v in parameters.items()
if k.startswith("preprocessor__")
}

if config_params:
self.config_kwargs.update(config_params)
if self.config is not None:
for key, value in config_params.items():
setattr(self.config, key, value)
else:
self.config = self.config_class(**self.config_kwargs)

if preprocessor_params:
# Assuming Preprocessor has a set_params method
self.preprocessor.set_params(**preprocessor_params)

return self
Expand Down
70 changes: 46 additions & 24 deletions mambular/models/sklearn_base_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
from lightning.pytorch.callbacks import ModelSummary
from dataclasses import asdict, is_dataclass


class SklearnBaseRegressor(BaseEstimator):
def __init__(self, model, config, **kwargs):
preprocessor_arg_names = [
self.preprocessor_arg_names = [
"n_bins",
"numerical_preprocessing",
"use_decision_tree_bins",
Expand All @@ -26,16 +27,18 @@ def __init__(self, model, config, **kwargs):
]

self.config_kwargs = {
k: v for k, v in kwargs.items() if k not in preprocessor_arg_names
k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names
}
self.config = config(**self.config_kwargs)

preprocessor_kwargs = {
k: v for k, v in kwargs.items() if k in preprocessor_arg_names
k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names
}

self.preprocessor = Preprocessor(**preprocessor_kwargs)
self.base_model = model
self.model = None
self.built = False

# Raise a warning if task is set to 'classification'
if preprocessor_kwargs.get("task") == "classification":
Expand All @@ -44,27 +47,24 @@ def __init__(self, model, config, **kwargs):
UserWarning,
)

self.base_model = model

def get_params(self, deep=True):
"""
Get parameters for this estimator. Overrides the BaseEstimator method.
Get parameters for this estimator.
Parameters
----------
deep : bool, default=True
If True, returns the parameters for this estimator and contained sub-objects that are estimators.
If True, will return the parameters for this estimator and contained subobjects that are estimators.
Returns
-------
params : dict
Parameter names mapped to their values.
"""
params = self.config_kwargs # Parameters used to initialize DefaultConfig
params = {}
params.update(self.config_kwargs)

# If deep=True, include parameters from nested components like preprocessor
if deep:
# Assuming Preprocessor has a get_params method
preprocessor_params = {
"preprocessor__" + key: value
for key, value in self.preprocessor.get_params().items()
Expand All @@ -75,35 +75,36 @@ def get_params(self, deep=True):

def set_params(self, **parameters):
"""
Set the parameters of this estimator. Overrides the BaseEstimator method.
Set the parameters of this estimator.
Parameters
----------
**parameters : dict
Estimator parameters to be set.
Estimator parameters.
Returns
-------
self : object
The instance with updated parameters.
Estimator instance.
"""
# Update config_kwargs with provided parameters
valid_config_keys = self.config_kwargs.keys()
config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys}
self.config_kwargs.update(config_updates)

# Update the config object
for key, value in config_updates.items():
setattr(self.config, key, value)

# Handle preprocessor parameters (prefixed with 'preprocessor__')
config_params = {
k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
}
preprocessor_params = {
k.split("__")[1]: v
for k, v in parameters.items()
if k.startswith("preprocessor__")
}

if config_params:
self.config_kwargs.update(config_params)
if self.config is not None:
for key, value in config_params.items():
setattr(self.config, key, value)
else:
self.config = self.config_class(**self.config_kwargs)

if preprocessor_params:
# Assuming Preprocessor has a set_params method
self.preprocessor.set_params(**preprocessor_params)

return self
Expand Down Expand Up @@ -471,3 +472,24 @@ def evaluate(self, X, y_true, metrics=None):
scores[metric_name] = metric_func(y_true, predictions)

return scores

def score(self, X, y, metric=mean_squared_error):
"""
Calculate the score of the model using the specified metric.
Parameters
----------
X : array-like or pd.DataFrame of shape (n_samples, n_features)
The input samples to predict.
y : array-like of shape (n_samples,) or (n_samples, n_outputs)
The true target values against which to evaluate the predictions.
metric : callable, default=mean_squared_error
The metric function to use for evaluation. Must be a callable with the signature `metric(y_true, y_pred)`.
Returns
-------
score : float
The score calculated using the specified metric.
"""
predictions = self.predict(X)
return metric(y, predictions)

0 comments on commit bd941d2

Please sign in to comment.