diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java index c6443e2b086..f8df605be8c 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java @@ -32,6 +32,8 @@ public class XgbNDManager extends BaseNDManager { private static final XgbNDManager SYSTEM_MANAGER = new SystemManager(); + private float missingValue = Float.NaN; + private XgbNDManager(NDManager parent, Device device) { super(parent, device); } @@ -40,6 +42,10 @@ static XgbNDManager getSystemManager() { return SYSTEM_MANAGER; } + public void setMissingValue(float missingValue) { + this.missingValue = missingValue; + } + /** {@inheritDoc} */ @Override public ByteBuffer allocateDirect(int capacity) { @@ -95,7 +101,7 @@ public NDArray create(Buffer data, Shape shape, DataType dataType) { if (data.isDirect() && data instanceof ByteBuffer) { // TODO: allow user to set missing value - long handle = JniUtils.createDMatrix(data, shape, 0.0f); + long handle = JniUtils.createDMatrix(data, shape, missingValue); return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE); } @@ -109,7 +115,7 @@ public NDArray create(Buffer data, Shape shape, DataType dataType) { ByteBuffer buf = allocateDirect(size); buf.asFloatBuffer().put((FloatBuffer) data); buf.rewind(); - long handle = JniUtils.createDMatrix(buf, shape, 0.0f); + long handle = JniUtils.createDMatrix(buf, shape, missingValue); return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE); } diff --git a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java index 97ccb34608d..4310b342643 100644 --- a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java +++ b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java @@ -68,6 +68,7 @@ public void testNDArray() { try (XgbNDManager manager = (XgbNDManager) XgbNDManager.getSystemManager().newSubManager()) { + manager.setMissingValue(Float.NaN); NDArray zeros = manager.zeros(new Shape(1, 2)); Assert.expectThrows(UnsupportedOperationException.class, zeros::toFloatArray); diff --git a/gradle.properties b/gradle.properties index e5805ecb58a..a2efedbbe74 100644 --- a/gradle.properties +++ b/gradle.properties @@ -19,7 +19,7 @@ paddlepaddle_version=2.0.2 sentencepiece_version=0.1.95 tokenizers_version=0.11.0 fasttext_version=0.9.2 -xgboost_version=1.4.1 +xgboost_version=1.5.2 commons_cli_version=1.5.0 commons_compress_version=1.21