Skip to content

Commit

Permalink
add tests for sign
Browse files Browse the repository at this point in the history
  • Loading branch information
JenniferHem committed Jun 13, 2024
1 parent 26ab0d2 commit eda6602
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/test_metrics/test_ignore_error_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import unittest

import numpy as np
from sklearn import linear_model

from molpipeline.metrics import ignored_value_scorer
from sklearn.metrics import get_scorer


class IgnoreErrorScorerTest(unittest.TestCase):
Expand Down Expand Up @@ -45,3 +47,39 @@ def test_filter_none_with_nan(self) -> None:
ba_score._score_func(y_true, y_pred), # pylint: disable=protected-access
1.0,
)

def test_correct_init_mse(self) -> None:
"""Test that initialization is correct as we access via protected vars."""
x_train = np.array([0.1,0.2, 0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]).reshape(-1, 1)
y_train = np.array([0.1,0.3, 0.3,0.4,0.5,0.5,0.7,0.88,0.9,1])
regr = linear_model.LinearRegression()
regr.fit(x_train, y_train)
cix_scorer = ignored_value_scorer("neg_mean_squared_error", None)
scikit_scorer = get_scorer("neg_mean_squared_error")
self.assertEqual(
cix_scorer(regr, x_train,y_train), scikit_scorer(regr, x_train,y_train)
)

def test_correct_init_rmse(self) -> None:
"""Test that initialization is correct as we access via protected vars."""
x_train = np.array([0.1,0.2, 0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]).reshape(-1, 1)
y_train = np.array([0.1,0.3, 0.3,0.4,0.5,0.5,0.7,0.88,0.9,1])
regr = linear_model.LinearRegression()
regr.fit(x_train, y_train)
cix_scorer = ignored_value_scorer("neg_root_mean_squared_error", None)
scikit_scorer = get_scorer("neg_root_mean_squared_error")
self.assertEqual(
cix_scorer(regr, x_train,y_train), scikit_scorer(regr, x_train,y_train)
)

def test_correct_init_inheritance(self) -> None:
"""Test that initialization is correct if we pass an initialized scorer."""
x_train = np.array([0.1,0.2, 0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]).reshape(-1, 1)
y_train = np.array([0.1,0.3, 0.3,0.4,0.5,0.5,0.7,0.88,0.9,1])
regr = linear_model.LinearRegression()
regr.fit(x_train, y_train)
scikit_scorer = get_scorer("neg_root_mean_squared_error")
cix_scorer = ignored_value_scorer(get_scorer("neg_root_mean_squared_error"), None)
self.assertEqual(
cix_scorer(regr, x_train,y_train), scikit_scorer(regr, x_train,y_train)
)

0 comments on commit eda6602

Please sign in to comment.