Skip to content

Commit

Permalink
[xgb] Set default missing value to NaN
Browse files Browse the repository at this point in the history
Upgrade XGBoost to 1.5.2

Change-Id: Ie47428f3c0a3dd5fadcb25ea05e655073a40f81f
  • Loading branch information
frankfliu committed Apr 14, 2022
1 parent 6c4dabe commit 4fff540
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public class XgbNDManager extends BaseNDManager {

private static final XgbNDManager SYSTEM_MANAGER = new SystemManager();

private float missingValue = Float.NaN;

private XgbNDManager(NDManager parent, Device device) {
super(parent, device);
}
Expand All @@ -40,6 +42,10 @@ static XgbNDManager getSystemManager() {
return SYSTEM_MANAGER;
}

public void setMissingValue(float missingValue) {
this.missingValue = missingValue;
}

/** {@inheritDoc} */
@Override
public ByteBuffer allocateDirect(int capacity) {
Expand Down Expand Up @@ -95,7 +101,7 @@ public NDArray create(Buffer data, Shape shape, DataType dataType) {

if (data.isDirect() && data instanceof ByteBuffer) {
// TODO: allow user to set missing value
long handle = JniUtils.createDMatrix(data, shape, 0.0f);
long handle = JniUtils.createDMatrix(data, shape, missingValue);
return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE);
}

Expand All @@ -109,7 +115,7 @@ public NDArray create(Buffer data, Shape shape, DataType dataType) {
ByteBuffer buf = allocateDirect(size);
buf.asFloatBuffer().put((FloatBuffer) data);
buf.rewind();
long handle = JniUtils.createDMatrix(buf, shape, 0.0f);
long handle = JniUtils.createDMatrix(buf, shape, missingValue);
return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public void testNDArray() {

try (XgbNDManager manager =
(XgbNDManager) XgbNDManager.getSystemManager().newSubManager()) {
manager.setMissingValue(Float.NaN);
NDArray zeros = manager.zeros(new Shape(1, 2));
Assert.expectThrows(UnsupportedOperationException.class, zeros::toFloatArray);

Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ paddlepaddle_version=2.0.2
sentencepiece_version=0.1.95
tokenizers_version=0.11.0
fasttext_version=0.9.2
xgboost_version=1.4.1
xgboost_version=1.5.2

commons_cli_version=1.5.0
commons_compress_version=1.21
Expand Down

0 comments on commit 4fff540

Please sign in to comment.