Skip to content

Commit

Permalink
Add fixes to feature selection
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Dec 2, 2023
1 parent 78fa8db commit e32fd91
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 8 deletions.
1 change: 1 addition & 0 deletions bluecast/blueprints/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def fit(self, df: pd.DataFrame, target_col: str) -> None:
self.custom_feature_selector = RFECVSelector(
random_state=self.conf_training.global_random_state,
min_features_to_select=self.conf_training.min_features_to_select,
class_problem=self.class_problem,
)

if self.conf_training.enable_feature_selection:
Expand Down
7 changes: 1 addition & 6 deletions bluecast/blueprints/cast_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

from bluecast.config.training_config import TrainingConfig, XgboostFinalParamConfig
from bluecast.config.training_config import (
Expand Down Expand Up @@ -285,11 +284,7 @@ def fit(self, df: pd.DataFrame, target_col: str) -> None:
self.custom_feature_selector = RFECVSelector(
random_state=self.conf_training.global_random_state,
min_features_to_select=self.conf_training.min_features_to_select,
stratifier=KFold(
n_splits=5,
shuffle=True,
random_state=self.conf_training.global_random_state,
),
class_problem=self.class_problem,
)

if self.conf_training.enable_feature_selection:
Expand Down
4 changes: 2 additions & 2 deletions bluecast/preprocessing/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ class RFECVSelector(CustomPreprocessing):
"""

def __init__(
self, random_state: int = 0, min_features_to_select: int = 5, stratifier=None
self, random_state: int = 0, min_features_to_select: int = 5, class_problem=None
):
super().__init__()
self.selected_features = None
self.random_state = random_state
if not stratifier:
if class_problem in ["regression"]:
stratifier = KFold(5, random_state=random_state, shuffle=True)
model = xgb.XGBRegressor()
scorer = make_scorer(mean_squared_error)
Expand Down
Binary file modified dist/bluecast-0.80-py3-none-any.whl
Binary file not shown.
Binary file modified dist/bluecast-0.80.tar.gz
Binary file not shown.

0 comments on commit e32fd91

Please sign in to comment.