From 16e18f4fff99e960d3993b73b1be6fc6c61576e0 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Mon, 5 Aug 2024 10:28:33 +0000 Subject: [PATCH] adding score function to lss models --- mambular/models/sklearn_base_lss.py | 30 +++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py index 4e0d6e4..ad7100f 100644 --- a/mambular/models/sklearn_base_lss.py +++ b/mambular/models/sklearn_base_lss.py @@ -467,7 +467,9 @@ def predict(self, X, raw=False): # Perform inference with torch.no_grad(): - predictions = self.task_model(num_features=num_tensors, cat_features=cat_tensors) + predictions = self.task_model( + num_features=num_tensors, cat_features=cat_tensors + ) if not raw: return self.task_model.family(predictions).cpu().numpy() @@ -506,7 +508,9 @@ def evaluate(self, X, y_true, metrics=None, distribution_family=None): """ # Infer distribution family from model settings if not provided if distribution_family is None: - distribution_family = getattr(self.task_model, "distribution_family", "normal") + distribution_family = getattr( + self.task_model, "distribution_family", "normal" + ) # Setup default metrics if none are provided if metrics is None: @@ -559,3 +563,25 @@ def get_default_metrics(self, distribution_family): "categorical": {"Accuracy": accuracy_score}, } return default_metrics.get(distribution_family, {}) + + def score(self, X, y, metric="NLL"): + """ + Calculate the score of the model using the specified metric. + + Parameters + ---------- + X : array-like or pd.DataFrame of shape (n_samples, n_features) + The input samples to predict. + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The true target values against which to evaluate the predictions. + metric : str, default="NLL" + So far, only negative log-likelihood is supported + + Returns + ------- + score : float + The score calculated using the specified metric. + """ + predictions = self.predict(X) + score = self.task_model.family.evaluate_nll(y, predictions) + return score