Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Jan 13, 2024
1 parent 9f7b9a9 commit 2a46704
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions bluecast/tests/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,32 +314,29 @@ def transform(
) # due to custom model and fit method


@pytest.fixture
def bluecast_instance():
# Create a fixture to instantiate the BlueCast class with default values for testing
return BlueCast(class_problem="binary", target_column="target")


def test_enable_feature_selection_warning(bluecast_instance, capsys):
def test_enable_feature_selection_warning(capsys):
# Test if a warning is raised when feature selection is disabled
df = pd.DataFrame({"feature1": [1, 2, 3], "target": [0, 1, 0]})
bluecast_instance = BlueCast(class_problem="binary", target_column="target")
bluecast_instance.initial_checks(df)
captured = capsys.readouterr()
assert "Feature selection is disabled." in captured.err


def test_hypertuning_cv_folds_warning(bluecast_instance, capsys):
def test_hypertuning_cv_folds_warning(capsys):
# Test if a warning is raised when hypertuning_cv_folds is set to 1
df = pd.DataFrame({"feature1": [1, 2, 3], "target": [0, 1, 0]})
bluecast_instance = BlueCast(class_problem="binary", target_column="target")
bluecast_instance.conf_training.hypertuning_cv_folds = 1
bluecast_instance.initial_checks(df)
captured = capsys.readouterr()
assert "Cross validation is disabled." in captured.err


def test_missing_feature_selector_warning(bluecast_instance, capsys):
def test_missing_feature_selector_warning(capsys):
# Test if a warning is raised when feature selection is enabled but no feature selector is provided
df = pd.DataFrame({"feature1": [1, 2, 3], "target": [0, 1, 0]})
bluecast_instance = BlueCast(class_problem="binary", target_column="target")
bluecast_instance.conf_training.enable_feature_selection = True
bluecast_instance.initial_checks(df)
captured = capsys.readouterr()
Expand All @@ -349,18 +346,20 @@ def test_missing_feature_selector_warning(bluecast_instance, capsys):
)


def test_missing_xgboost_tune_params_config_warning(bluecast_instance, capsys):
def test_missing_xgboost_tune_params_config_warning(capsys):
# Test if a warning is raised when XgboostTuneParamsConfig is not provided
df = pd.DataFrame({"feature1": [1, 2, 3], "target": [0, 1, 0]})
bluecast_instance = BlueCast(class_problem="binary", target_column="target")
bluecast_instance.conf_xgboost = None
bluecast_instance.initial_checks(df)
captured = capsys.readouterr()
assert "No XgboostTuneParamsConfig has been provided." in captured.err


def test_min_features_to_select_warning(bluecast_instance, capsys):
def test_min_features_to_select_warning(capsys):
# Test if a warning is raised when min_features_to_select is greater than or equal to the number of features
df = pd.DataFrame({"feature1": [1, 2, 3], "target": [0, 1, 0]})
bluecast_instance = BlueCast(class_problem="binary", target_column="target")
bluecast_instance.conf_training.enable_feature_selection = True
bluecast_instance.conf_training.min_features_to_select = 3
bluecast_instance.initial_checks(df)
Expand All @@ -371,9 +370,10 @@ def test_min_features_to_select_warning(bluecast_instance, capsys):
)


def test_shap_values_and_ml_algorithm_warning(bluecast_instance, capsys):
def test_shap_values_and_ml_algorithm_warning(capsys):
# Test if a warning is raised when calculate_shap_values is True and cat_encoding_via_ml_algorithm is True
df = pd.DataFrame({"feature1": [1, 2, 3], "target": [0, 1, 0]})
bluecast_instance = BlueCast(class_problem="binary", target_column="target")
bluecast_instance.conf_training.calculate_shap_values = True
bluecast_instance.conf_training.cat_encoding_via_ml_algorithm = True
bluecast_instance.initial_checks(df)
Expand All @@ -384,9 +384,10 @@ def test_shap_values_and_ml_algorithm_warning(bluecast_instance, capsys):
)


def test_cat_encoding_via_ml_algorithm_and_ml_model_warning(bluecast_instance, capsys):
def test_cat_encoding_via_ml_algorithm_and_ml_model_warning(capsys):
# Test if a warning is raised when cat_encoding_via_ml_algorithm is True and ml_model is provided
df = pd.DataFrame({"feature1": [1, 2, 3], "target": [0, 1, 0]})
bluecast_instance = BlueCast(class_problem="binary", target_column="target")
bluecast_instance.conf_training.cat_encoding_via_ml_algorithm = True
bluecast_instance.ml_model = (
bluecast_instance.ml_model
Expand All @@ -399,9 +400,10 @@ def test_cat_encoding_via_ml_algorithm_and_ml_model_warning(bluecast_instance, c
)


def test_precise_cv_tuning_warnings(bluecast_instance, capsys):
def test_precise_cv_tuning_warnings(capsys):
# Test if warnings are raised for precise_cv_tuning conditions
df = pd.DataFrame({"feature1": [1, 2, 3], "target": [0, 1, 0]})
bluecast_instance = BlueCast(class_problem="binary", target_column="target")
bluecast_instance.conf_training.precise_cv_tuning = True
bluecast_instance.initial_checks(df)
captured = capsys.readouterr()
Expand All @@ -416,7 +418,7 @@ def test_precise_cv_tuning_warnings(bluecast_instance, capsys):
)


def test_class_problem_mismatch_warnings(bluecast_instance, capsys):
def test_class_problem_mismatch_warnings(capsys):
# Test if warnings are raised for class problem mismatch
df_binary = pd.DataFrame({"feature1": [1, 2, 3], "target": [0, 1, 0]})
df_multiclass = pd.DataFrame({"feature1": [1, 2, 3], "target": [0, 1, 2]})
Expand Down

0 comments on commit 2a46704

Please sign in to comment.