Skip to content

Commit

Permalink
Stratify env wise hyper opt (#44)
Browse files Browse the repository at this point in the history
* Replace KFold by StratifiedKFold

* Bump version
  • Loading branch information
lgmoneda authored Dec 22, 2021
1 parent 3a6913e commit 5688604
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "time-robust-forest"
version = "0.1.10"
version = "0.1.11"
description = "Explores time information to train a robust random forest"
readme = "README.md"
authors = [
Expand Down
6 changes: 3 additions & 3 deletions time_robust_forest/hyper_opt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pandas as pd
from sklearn.metrics import make_scorer, roc_auc_score
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.model_selection import GridSearchCV, StratifiedKFold


def extract_results_from_grid_cv(cv_results, kfolds, envs):
Expand Down Expand Up @@ -52,8 +52,8 @@ def leave_one_env_out_cv(data, env_column="period", cv=5):
"""
envs = data[env_column].unique()
cv_sets = []
kfolds = KFold(n_splits=cv)
for train_idx, test_idx in kfolds.split(data):
kfolds = StratifiedKFold(n_splits=cv)
for train_idx, test_idx in kfolds.split(data, data[env_column]):
for env in envs:
all_env_elements = data[data[env_column] == env].index
test_env_idx = [i for i in test_idx if i in all_env_elements]
Expand Down

0 comments on commit 5688604

Please sign in to comment.