Skip to content

Commit

Permalink
[skl] Enable cat feature without specifying tree method. (#9353)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jul 3, 2023
1 parent 39390cc commit e964654
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
3 changes: 1 addition & 2 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,8 +930,7 @@ def _duplicated(parameter: str) -> None:
callbacks = self.callbacks if self.callbacks is not None else callbacks

tree_method = params.get("tree_method", None)
- cat_support = {"gpu_hist", "approx", "hist"}
- if self.enable_categorical and tree_method not in cat_support:
+ if self.enable_categorical and tree_method == "exact":
raise ValueError(
"Experimental support for categorical data is not implemented for"
" current tree method yet."
Expand Down
3 changes: 1 addition & 2 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,7 +1390,6 @@ def test_categorical():
X, y = tm.make_categorical(n_samples=32, n_features=2, n_categories=3, onehot=False)
ft = ["c"] * X.shape[1]
reg = xgb.XGBRegressor(
- tree_method="hist",
feature_types=ft,
max_cat_to_onehot=1,
enable_categorical=True,
Expand All @@ -1409,7 +1408,7 @@ def test_categorical():
onehot, y = tm.make_categorical(
n_samples=32, n_features=2, n_categories=3, onehot=True
)
- reg = xgb.XGBRegressor(tree_method="hist")
+ reg = xgb.XGBRegressor()
reg.fit(onehot, y, eval_set=[(onehot, y)])
from_enc = reg.evals_result()["validation_0"]["rmse"]
predt_enc = reg.predict(onehot)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_distributed/test_with_dask/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def test_dask_sparse(client: "Client") -> None:


def run_categorical(client: "Client", tree_method: str, X, X_onehot, y) -> None:
- parameters = {"tree_method": tree_method, "max_cat_to_onehot": 9999} # force onehot
+ parameters = {"tree_method": tree_method, "max_cat_to_onehot": 9999}  # force onehot
rounds = 10
m = xgb.dask.DaskDMatrix(client, X_onehot, y, enable_categorical=True)
by_etl_results = xgb.dask.train(
Expand Down Expand Up @@ -364,9 +364,9 @@ def check_model_output(model: xgb.dask.Booster) -> None:
check_model_output(reg.get_booster())

reg = xgb.dask.DaskXGBRegressor(
- enable_categorical=True, n_estimators=10
+ enable_categorical=True, n_estimators=10, tree_method="exact"
)
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError, match="categorical data"):
reg.fit(X, y)
# check partition based
reg = xgb.dask.DaskXGBRegressor(
Expand Down

0 comments on commit e964654

Please sign in to comment.