Skip to content

Commit

Permalink
adding score function to lss models
Browse files Browse the repository at this point in the history
  • Loading branch information
AnFreTh committed Aug 5, 2024
1 parent a996e6e commit 16e18f4
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions mambular/models/sklearn_base_lss.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,9 @@ def predict(self, X, raw=False):

# Perform inference
with torch.no_grad():
predictions = self.task_model(num_features=num_tensors, cat_features=cat_tensors)
predictions = self.task_model(
num_features=num_tensors, cat_features=cat_tensors
)

if not raw:
return self.task_model.family(predictions).cpu().numpy()
Expand Down Expand Up @@ -506,7 +508,9 @@ def evaluate(self, X, y_true, metrics=None, distribution_family=None):
"""
# Infer distribution family from model settings if not provided
if distribution_family is None:
distribution_family = getattr(self.task_model, "distribution_family", "normal")
distribution_family = getattr(
self.task_model, "distribution_family", "normal"
)

# Setup default metrics if none are provided
if metrics is None:
Expand Down Expand Up @@ -559,3 +563,25 @@ def get_default_metrics(self, distribution_family):
"categorical": {"Accuracy": accuracy_score},
}
return default_metrics.get(distribution_family, {})

def score(self, X, y, metric="NLL"):
"""
Calculate the score of the model using the specified metric.
Parameters
----------
X : array-like or pd.DataFrame of shape (n_samples, n_features)
The input samples to predict.
y : array-like of shape (n_samples,) or (n_samples, n_outputs)
The true target values against which to evaluate the predictions.
metric : str, default="NLL"
So far, only negative log-likelihood is supported
Returns
-------
score : float
The score calculated using the specified metric.
"""
predictions = self.predict(X)
score = self.task_model.family.evaluate_nll(y, predictions)
return score

0 comments on commit 16e18f4

Please sign in to comment.