From d677123243f693b6665213ad92de8497f313c87c Mon Sep 17 00:00:00 2001 From: Lukas Weidenholzer Date: Fri, 14 Jul 2023 15:48:09 +0200 Subject: [PATCH] move xgb import to occur at runtime --- .../process_implementations/ml/random_forest.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/openeo_processes_dask/process_implementations/ml/random_forest.py b/openeo_processes_dask/process_implementations/ml/random_forest.py index 2de31ea1..266a180c 100644 --- a/openeo_processes_dask/process_implementations/ml/random_forest.py +++ b/openeo_processes_dask/process_implementations/ml/random_forest.py @@ -6,7 +6,7 @@ import geopandas as gpd import numpy as np import xarray as xr -import xgboost as xgb +from xgboost.core import Booster from openeo_processes_dask.process_implementations.cubes.experimental import ( load_vector_cube, @@ -27,7 +27,9 @@ def fit_regr_random_forest( predictors_vars: Optional[list[str]] = None, target_var: str = None, **kwargs, -) -> xgb.core.Booster: +) -> Booster: + import xgboost as xgb + params = { "learning_rate": 1, "max_depth": 5, @@ -72,9 +74,11 @@ def fit_regr_random_forest( def predict_random_forest( data: RasterCube, - model: xgb.Booster, + model: Booster, axis: int = -1, ) -> RasterCube: + import xgboost as xgb + n_features = len(model.feature_names) if n_features != data.shape[axis]: raise Exception(