Skip to content

Commit

Permalink
Refactor gpu function
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Jun 27, 2024
1 parent 60da109 commit 6797cd2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions bluecast/general_utils/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,27 @@ def check_gpu_support() -> Dict[str, str]:
d_train = xgb.DMatrix(data, label=label)

params_list = [
{"device": "cuda", "tree_method": "gpu_hist", "predictor": "gpu_predictor"},
{"device": "cuda", "tree_method": "gpu_hist"},
{"device": "cuda"},
{"tree_method": "gpu_hist"},
]

for params in params_list:
try:
xgb.train(params, d_train, num_boost_round=2)
logging.info("Xgboost is using GPU with parameters: %s", params)
return params
booster = xgb.train(params, d_train, num_boost_round=2)
if "gpu" in booster.attributes() or "cuda" in booster.attributes():
logging.info("Xgboost is using GPU with parameters: %s", params)
return params
else:
logging.warning(
"GPU settings applied but no GPU detected in booster attributes: %s",
params,
)
except xgb.core.XGBoostError as e:
logging.warning("Failed with params %s. Error: %s", params, str(e))

# If no GPU parameters work, fall back to CPU
params = {"tree_method": "exact", "device": "cpu"}
params = {"tree_method": "hist", "device": "cpu"}
logging.info("No GPU detected. Xgboost will use CPU with parameters: %s", params)
return params

Expand Down
Binary file modified dist/bluecast-1.4.2-py3-none-any.whl
Binary file not shown.
Binary file modified dist/bluecast-1.4.2.tar.gz
Binary file not shown.

0 comments on commit 6797cd2

Please sign in to comment.