diff --git a/bluecast/general_utils/general_utils.py b/bluecast/general_utils/general_utils.py index e0517e90..0a8bebfb 100644 --- a/bluecast/general_utils/general_utils.py +++ b/bluecast/general_utils/general_utils.py @@ -33,7 +33,7 @@ def check_gpu_support() -> Dict[str, str]: d_train = xgb.DMatrix(data, label=label) params_list = [ - {"device": "cuda", "tree_method": "gpu_hist"}, + {"tree_method": "gpu_hist"}, ] for params in params_list: @@ -67,7 +67,7 @@ def check_gpu_support() -> Dict[str, str]: logger.warning("Failed with params %s. Error: %s", params, str(e)) # If no GPU parameters work, fall back to CPU - params = {"tree_method": "hist", "device": "cpu"} + params = {"tree_method": "hist"} logger.info("No GPU detected. Xgboost will use CPU with parameters: %s", params) return params diff --git a/bluecast/ml_modelling/parameter_tuning_utils.py b/bluecast/ml_modelling/parameter_tuning_utils.py index 03a88940..7c10927c 100644 --- a/bluecast/ml_modelling/parameter_tuning_utils.py +++ b/bluecast/ml_modelling/parameter_tuning_utils.py @@ -24,12 +24,12 @@ def update_params_based_on_tree_method( param["tree_method"] = trial.suggest_categorical( "tree_method", xgboost_params.tree_method ) - if param["tree_method"] in ["hist", "approx"]: + if param["tree_method"] in ["hist", "approx", "gpu_hist"]: param["max_bin"] = trial.suggest_int( "max_bin", xgboost_params.max_bin_min, xgboost_params.max_bin_max ) - if param.get("device", "cpu") == "cpu": + if param.get("device", None) == "cpu": del param["device"] param["booster"] = trial.suggest_categorical("booster", xgboost_params.booster) @@ -58,7 +58,7 @@ def get_params_based_on_device( """Get parameters based on available or chosen device.""" if conf_training.autotune_on_device in ["auto"]: train_on = check_gpu_support() - conf_params_xgboost.params["device"] = train_on["device"] + conf_params_xgboost.params["device"] = train_on.get("device", None) if "exact" in conf_xgboost.tree_method and conf_params_xgboost.params[ "device" ] in ["gpu", "cuda"]: diff --git a/dist/bluecast-1.6.0-py3-none-any.whl b/dist/bluecast-1.6.0-py3-none-any.whl index 08154fda..932d3915 100644 Binary files a/dist/bluecast-1.6.0-py3-none-any.whl and b/dist/bluecast-1.6.0-py3-none-any.whl differ diff --git a/dist/bluecast-1.6.0.tar.gz b/dist/bluecast-1.6.0.tar.gz index d768e784..20965239 100644 Binary files a/dist/bluecast-1.6.0.tar.gz and b/dist/bluecast-1.6.0.tar.gz differ