From 5688604b7284daa4e6f393b090fef8fe4fa4b760 Mon Sep 17 00:00:00 2001 From: Luis Moneda Date: Wed, 22 Dec 2021 11:10:51 -0300 Subject: [PATCH] Stratify env wise hyper opt (#44) * Replace KFold by StratifiedKFold * Bump version --- pyproject.toml | 2 +- time_robust_forest/hyper_opt.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 448df7f..7918357 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "time-robust-forest" -version = "0.1.10" +version = "0.1.11" description = "Explores time information to train a robust random forest" readme = "README.md" authors = [ diff --git a/time_robust_forest/hyper_opt.py b/time_robust_forest/hyper_opt.py index 3e6e418..493eebc 100644 --- a/time_robust_forest/hyper_opt.py +++ b/time_robust_forest/hyper_opt.py @@ -1,6 +1,6 @@ import pandas as pd from sklearn.metrics import make_scorer, roc_auc_score -from sklearn.model_selection import GridSearchCV, KFold +from sklearn.model_selection import GridSearchCV, StratifiedKFold def extract_results_from_grid_cv(cv_results, kfolds, envs): @@ -52,8 +52,8 @@ def leave_one_env_out_cv(data, env_column="period", cv=5): """ envs = data[env_column].unique() cv_sets = [] - kfolds = KFold(n_splits=cv) - for train_idx, test_idx in kfolds.split(data): + kfolds = StratifiedKFold(n_splits=cv) + for train_idx, test_idx in kfolds.split(data, data[env_column]): for env in envs: all_env_elements = data[data[env_column] == env].index test_env_idx = [i for i in test_idx if i in all_env_elements]