Update unit tests.
ThomasMeissnerDS committed Jul 11, 2024
1 parent d49100d commit 180d024
Showing 1 changed file with 37 additions and 0 deletions.
bluecast/tests/test_eval_metrics.py
@@ -7,6 +7,7 @@
    mean_squared_error,
    root_mean_squared_error,
)
from unittest.mock import patch

from bluecast.evaluation.eval_metrics import (
    ClassificationEvalWrapper,
@@ -110,3 +111,39 @@ def test_regression_run_witouth_args(sample_data_regression):
    score = wrapper.regression_eval_func_wrapper(y_true, y_hat)
    expected_score = mean_squared_error(y_true, y_hat)
    assert score == expected_score


def test_root_mean_squared_error_import():
    # Plant a dummy sys.modules entry to mimic an sklearn version where
    # root_mean_squared_error is unavailable, then verify that aliasing
    # mean_squared_error under that name resolves to the patched object.
    with patch.dict('sys.modules', {'sklearn.metrics.root_mean_squared_error': None}):
        with patch('sklearn.metrics.mean_squared_error') as mock_mse:
            from sklearn.metrics import mean_squared_error as root_mean_squared_error

            assert root_mean_squared_error is mock_mse


def test_root_mean_squared_error_direct_import():
    # Even with a placeholder entry planted in sys.modules, a direct import
    # should resolve to the real, callable metric on sklearn versions that provide it.
    with patch.dict('sys.modules', {'sklearn.metrics.root_mean_squared_error': object()}):
        from sklearn.metrics import root_mean_squared_error

        assert root_mean_squared_error is not None
        assert callable(root_mean_squared_error)
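

For context, these two tests appear to exercise an import fallback in bluecast/evaluation/eval_metrics.py: scikit-learn only added root_mean_squared_error in version 1.4, so older versions need a substitute. A minimal sketch of such a guard follows; this is an assumption for illustration, not a line from this diff:

# Hypothetical import guard the tests above appear to target (assumption):
try:
    from sklearn.metrics import root_mean_squared_error
except ImportError:
    # Older scikit-learn: reuse mean_squared_error under the same name,
    # matching the aliasing asserted in test_root_mean_squared_error_import.
    from sklearn.metrics import mean_squared_error as root_mean_squared_error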


def test_classification_eval_func_wrapper_invalid_eval_against():
    # Instantiate the class with an invalid eval_against value
    def dummy_metric_func(y_true, y_pred, **kwargs):
        # Dummy metric function for testing purposes
        return 0.5

    class TestClass(ClassificationEvalWrapper):
        def __init__(self, eval_against, metric_func):
            super().__init__(eval_against=eval_against, metric_func=metric_func)

    test_obj = TestClass(eval_against="probas_all_classes", metric_func=dummy_metric_func)
    test_obj.eval_against = "invalid_value"

    # Dummy data for testing
    y_true = [1, 0, 1]
    y_probs = [[0.4, 0.6], [0.7, 0.3], [0.2, 0.8]]

    # Use pytest to check if ValueError is raised
    with pytest.raises(ValueError, match=r"Unknown value for eval_against: invalid_value\. Possible values are 'probas' or 'classes'"):
        test_obj.classification_eval_func_wrapper(y_true, y_probs)
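

For reference, a minimal sketch of the eval_against dispatch this test exercises; the real ClassificationEvalWrapper in bluecast/evaluation/eval_metrics.py likely supports additional modes (the test itself constructs one with "probas_all_classes"), so treat this as an illustration only:

class _SketchWrapper:
    # Hypothetical simplified stand-in (assumption, not the library's code).
    def __init__(self, eval_against, metric_func):
        self.eval_against = eval_against
        self.metric_func = metric_func

    def classification_eval_func_wrapper(self, y_true, y_probs):
        if self.eval_against == "probas":
            return self.metric_func(y_true, y_probs)
        elif self.eval_against == "classes":
            # Reduce probabilities to hard class labels (binary threshold).
            predicted = [int(p[1] > 0.5) for p in y_probs]
            return self.metric_func(y_true, predicted)
        else:
            raise ValueError(
                f"Unknown value for eval_against: {self.eval_against}. "
                "Possible values are 'probas' or 'classes'"
            )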
