Skip to content

Commit

Permalink
Grid search njobs (#46)
Browse files Browse the repository at this point in the history
* Add n_jobs as a parameter

* Bump version
  • Loading branch information
lgmoneda authored Jan 3, 2022
1 parent 43f49b4 commit 16e057e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "time-robust-forest"
version = "0.1.12"
version = "0.1.13"
description = "Explores time information to train a robust random forest"
readme = "README.md"
authors = [
Expand Down
7 changes: 4 additions & 3 deletions time_robust_forest/hyper_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def env_wise_score(estimator, X, y, scorer, env, env_column):
return evaluation


def grid_search(X, y, model, param_grid, env_cvs, scorer):
def grid_search(X, y, model, param_grid, env_cvs, scorer, n_jobs=-1):
"""
FIt the grid search and return it.
"""
Expand All @@ -85,7 +85,7 @@ def grid_search(X, y, model, param_grid, env_cvs, scorer):
param_grid=param_grid,
cv=env_cvs,
scoring=scorer,
n_jobs=-1,
n_jobs=n_jobs,
verbose=0,
refit=False,
)
Expand All @@ -103,6 +103,7 @@ def env_wise_hyper_opt(
cv=5,
scorer=make_scorer(roc_auc_score, needs_proba=True),
ret_results=False,
n_jobs=-1,
):
"""
Optimize the hyper parmaters of a model considering the leave one env out
Expand All @@ -119,7 +120,7 @@ def env_wise_hyper_opt(
for env in envs
}

grid_cv = grid_search(X, y, model, param_grid, env_cvs, scoring_fs)
grid_cv = grid_search(X, y, model, param_grid, env_cvs, scoring_fs, n_jobs)

results_df = extract_results_from_grid_cv(grid_cv.cv_results_, cv, envs)

Expand Down

0 comments on commit 16e057e

Please sign in to comment.