Skip to content

Commit

Permalink
Update saving out of fold data
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Jul 13, 2024
1 parent 83c36bb commit 60d0cfd
Show file tree
Hide file tree
Showing 6 changed files with 669 additions and 334 deletions.
2 changes: 1 addition & 1 deletion bluecast/blueprints/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def fit_eval(
if len(self.experiment_tracker.experiment_id) == 0:
self.experiment_tracker.experiment_id.append(0)

save_out_of_fold_data(df_eval, y_probs, self.conf_training)
save_out_of_fold_data(df_eval, y_probs, self.class_problem, self.conf_training)

# enrich experiment tracker
for metric, higher_is_better in zip(
Expand Down
2 changes: 1 addition & 1 deletion bluecast/blueprints/cast_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def fit_eval(
if len(self.experiment_tracker.experiment_id) == 0:
self.experiment_tracker.experiment_id.append(0)

save_out_of_fold_data(df_eval, y_preds, self.conf_training)
save_out_of_fold_data(df_eval, y_preds, self.class_problem, self.conf_training)

# enrich experiment tracker
for metric, higher_is_better in zip(
Expand Down
17 changes: 14 additions & 3 deletions bluecast/general_utils/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import warnings
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Literal, Optional, Union

import dill as pickle
import numpy as np
Expand Down Expand Up @@ -135,17 +135,28 @@ def log_sampling(nb_rows: int, alpha: float = 2.0) -> int:
def save_out_of_fold_data(
oof_data: pd.DataFrame,
y_hat: Union[pd.Series, np.ndarray],
class_problem: Literal["binary", "multiclass", "regression"],
training_config: TrainingConfig,
) -> None:
"""Save out of fold data.
:param oof_data: Data to save.
:param y_hat: Predictions. Will be appended to oof_data and saved together.
:param y_hat: Predictions. Will be appended to oof_data and saved together. When class_problem is "binary", only the
target class score is expected.
:param class_problem: Takes a string containing the class problem type. Either "binary", "multiclass" or
"regression".
:param training_config: Training configuration.
"""
logging.info("Start saving out of fold data.")
oof_data_copy = oof_data.copy()
oof_data_copy["preditions"] = y_hat

if class_problem == "binary":
oof_data_copy["predictions_class_1"] = y_hat
elif class_problem == "multiclass":
for cls_idx in range(y_hat.shape[1]):
oof_data_copy[f"predictions_class_{cls_idx}"] = y_hat[:, cls_idx]
else:
oof_data_copy["predictions"] = y_hat

if isinstance(training_config.out_of_fold_dataset_store_path, str):
oof_data_copy.to_parquet(
Expand Down
65 changes: 64 additions & 1 deletion bluecast/tests/test_save_out_of_fold_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import tempfile

import numpy as np
import pandas as pd
import pytest

from bluecast.config.training_config import TrainingConfig
from bluecast.general_utils.general_utils import save_out_of_fold_data
Expand All @@ -20,7 +22,7 @@ def test_save_out_of_fold_data():
)

# Call the function to save out of fold data
save_out_of_fold_data(oof_data, y_hat, training_config)
save_out_of_fold_data(oof_data, y_hat, "regression", training_config)

# Construct the expected file path
expected_file_path = os.path.join(
Expand All @@ -37,3 +39,64 @@ def test_save_out_of_fold_data():
expected_oof_data["preditions"] = y_hat

pd.testing.assert_frame_equal(saved_oof_data, expected_oof_data)


@pytest.fixture
def sample_oof_data():
return pd.DataFrame({"feature1": [1, 2, 3], "feature2": [4, 5, 6]})


@pytest.fixture
def training_config():
return TrainingConfig(out_of_fold_dataset_store_path="test_path/")


def test_save_out_of_fold_data_binary(sample_oof_data, training_config, mocker):
y_hat = pd.Series([0.1, 0.4, 0.8])
mock_to_parquet = mocker.patch("pandas.DataFrame.to_parquet")

save_out_of_fold_data(sample_oof_data, y_hat, "binary", training_config)

expected_data = sample_oof_data.copy()
expected_data["predictions_class_1"] = y_hat

mock_to_parquet.assert_called_once_with("test_path/oof_data_33.parquet")
pd.testing.assert_frame_equal(mock_to_parquet.call_args[0][0], expected_data)


def test_save_out_of_fold_data_multiclass(sample_oof_data, training_config, mocker):
y_hat = np.array([[0.1, 0.7, 0.2], [0.2, 0.5, 0.3], [0.3, 0.4, 0.3]])
mock_to_parquet = mocker.patch("pandas.DataFrame.to_parquet")

save_out_of_fold_data(sample_oof_data, y_hat, "multiclass", training_config)

expected_data = sample_oof_data.copy()
expected_data["predictions_class_0"] = y_hat[:, 0]
expected_data["predictions_class_1"] = y_hat[:, 1]
expected_data["predictions_class_2"] = y_hat[:, 2]

mock_to_parquet.assert_called_once_with("test_path/oof_data_33.parquet")
pd.testing.assert_frame_equal(mock_to_parquet.call_args[0][0], expected_data)


def test_save_out_of_fold_data_regression(sample_oof_data, training_config, mocker):
y_hat = pd.Series([10.5, 20.3, 30.2])
mock_to_parquet = mocker.patch("pandas.DataFrame.to_parquet")

save_out_of_fold_data(sample_oof_data, y_hat, "regression", training_config)

expected_data = sample_oof_data.copy()
expected_data["predictions"] = y_hat

mock_to_parquet.assert_called_once_with("test_path/oof_data_33.parquet")
pd.testing.assert_frame_equal(mock_to_parquet.call_args[0][0], expected_data)


def test_save_out_of_fold_data_no_store_path(sample_oof_data, mocker):
y_hat = pd.Series([0.1, 0.4, 0.8])
training_config = TrainingConfig(out_of_fold_dataset_store_path=None)
mock_to_parquet = mocker.patch("pandas.DataFrame.to_parquet")

save_out_of_fold_data(sample_oof_data, y_hat, "binary", training_config)

mock_to_parquet.assert_not_called()
Loading

0 comments on commit 60d0cfd

Please sign in to comment.