Skip to content

Commit

Permalink
Refactor NDArrayAdapter
Browse files Browse the repository at this point in the history
1. Remove default implemenation from NDArrayAdapter for required API
2. Throw Unsupported exception for OrtNDArray
3. Fix XGBoost NDArray creation bug

Change-Id: I9f19d064245e47caf0a8718a0b706d6fe8934164
  • Loading branch information
frankfliu committed Jul 18, 2021
1 parent 7b31d94 commit 0fa6e7a
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 86 deletions.
61 changes: 0 additions & 61 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import java.nio.Buffer;
import java.nio.ByteBuffer;

/**
* A base implementation of the {@link NDArray} that does nothing. This can be used for overriding
Expand All @@ -33,66 +32,12 @@ public interface NDArrayAdapter extends NDArray {
String UNSUPPORTED_MSG =
"This NDArray implementation does not currently support this operation";

/** {@inheritDoc} */
@Override
default NDManager getManager() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default String getName() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default void setName(String name) {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default String getUid() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default DataType getDataType() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default Device getDevice() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default Shape getShape() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default SparseFormat getSparseFormat() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default void attach(NDManager manager) {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default void detach() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default NDArray toDevice(Device device, boolean copy) {
Expand Down Expand Up @@ -135,12 +80,6 @@ default String[] toStringArray() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default ByteBuffer toByteBuffer() {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
default void set(Buffer data) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,6 @@ public void detach() {
manager = DlrNDManager.getSystemManager();
}

/** {@inheritDoc} */
@Override
public NDArray stopGradient() {
throw new UnsupportedOperationException("Not supported for DLR");
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
Expand Down
1 change: 1 addition & 0 deletions integration/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies {
runtimeOnly "ai.djl.pytorch:pytorch-native-auto:${pytorch_version}-SNAPSHOT"
runtimeOnly project(":tensorflow:tensorflow-model-zoo")
runtimeOnly "ai.djl.tensorflow:tensorflow-native-auto:${tensorflow_version}"
runtimeOnly project(":ml:xgboost")

if (System.getProperty("ai.djl.default_engine") == "OnnxRuntime") {
// onnxruntime requires user install libgomp.so.1 manually, exclude from default dependency
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ public void testBlockFactoryLoadingFromZip()
TranslateException {
Path savedDir = Paths.get("build/testBlockFactory");
Utils.deleteQuietly(savedDir);
Path zipPath = prepareModel(savedDir);
Path zipPath;
try {
zipPath = prepareModel(savedDir);
} catch (ModelNotFoundException e) {
throw new UnsupportedOperationException(
"No test model for engine: " + Engine.getInstance().getEngineName(), e);
}
// load model from here
Criteria<NDList, NDList> criteria =
Criteria.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ public void l1DecayTest() {
NDArray parameters1 = manager.create(new float[] {-1, -2, 3, 4, 5}); // 15
NDArray parameters2 = manager.create(new float[] {-1, -1, -1, -1, -1}); // 5
// Not used
NDArray pred = manager.create(new float[] {});
NDArray label = manager.create(new float[] {});
NDArray pred = manager.create(new float[0]);
NDArray label = manager.create(new float[0]);
// r = 2*(15 + 5) = 40
L1WeightDecay decay =
Loss.l1WeightedDecay("", 2.0f, new NDList(parameters1, parameters2));
Expand All @@ -46,8 +46,8 @@ public void l2DecayTest() {
NDArray parameters1 = manager.create(new float[] {-1, -2, 3, 4, 5}); // 55
NDArray parameters2 = manager.create(new float[] {-1, -1, -1, -1, -1}); // 5
// Not used
NDArray pred = manager.create(new float[] {});
NDArray label = manager.create(new float[] {});
NDArray pred = manager.create(new float[0]);
NDArray label = manager.create(new float[0]);
// r = 2*(55 + 5) = 120
L2WeightDecay decay =
Loss.l2WeightedDecay("", 2.0f, new NDList(parameters1, parameters2));
Expand All @@ -62,8 +62,8 @@ public void elasticNetDecayTest() {
NDArray parameters1 = manager.create(new float[] {-1, -2, 3, 4, 5});
NDArray parameters2 = manager.create(new float[] {-1, -1, -1, -1, -1});
// Not used
NDArray pred = manager.create(new float[] {});
NDArray label = manager.create(new float[] {});
NDArray pred = manager.create(new float[0]);
NDArray label = manager.create(new float[0]);
// r = L1 + L2 = 2*20 + 1*60 = 100
ElasticNetWeightDecay decay =
Loss.elasticNetWeightedDecay(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public void runIntegrationTests() {
if (System.getProperty("os.name").startsWith("Win")) {
engines = new String[] {"MXNet"};
} else {
engines = new String[] {"MXNet", "PyTorch", "TensorFlow"};
engines = new String[] {"MXNet", "PyTorch", "TensorFlow", "XGBoost"};
}
} else {
engines = new String[] {defaultEngine};
Expand Down
13 changes: 13 additions & 0 deletions ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
public class XgbNDArray implements NDArrayAdapter {

private AtomicLong handle;
private String name;
private String uid;
private ByteBuffer data;
private XgbNDManager manager;
Expand Down Expand Up @@ -65,6 +66,18 @@ public XgbNDManager getManager() {
return manager;
}

/** {@inheritDoc} */
@Override
public String getName() {
return name;
}

/** {@inheritDoc} */
@Override
public void setName(String name) {
this.name = name;
}

/** {@inheritDoc} */
@Override
public String getUid() {
Expand Down
16 changes: 8 additions & 8 deletions ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ public Engine getEngine() {
@Override
public XgbNDArray create(Buffer data, Shape shape, DataType dataType) {
if (shape.dimension() != 2) {
throw new IllegalArgumentException("Shape must be in two dimension");
throw new UnsupportedOperationException("Shape must be in two dimension");
}
DataType inputType = DataType.fromBuffer(data);
if (inputType != DataType.FLOAT32) {
throw new IllegalArgumentException(
throw new UnsupportedOperationException(
"Only Float32 data type supported, actual " + inputType);
}
if (data.isDirect() && data instanceof ByteBuffer) {
Expand All @@ -82,7 +82,7 @@ public XgbNDArray create(Buffer data, Shape shape, DataType dataType) {
ByteBuffer buf = allocateDirect(size * numOfBytes);
buf.asFloatBuffer().put((FloatBuffer) data);
buf.rewind();
long handle = JniUtils.createDMatrix(data, shape, 0.0f);
long handle = JniUtils.createDMatrix(buf, shape, 0.0f);
return new XgbNDArray(this, handle, shape, SparseFormat.DENSE);
}

Expand All @@ -91,7 +91,7 @@ public XgbNDArray create(Buffer data, Shape shape, DataType dataType) {
public NDArray createCSR(
float[] data, long[] indptr, long[] indices, Shape shape, Device device) {
if (shape.dimension() != 2) {
throw new IllegalArgumentException("Shape must be in two dimension");
throw new UnsupportedOperationException("Shape must be in two dimension");
}
int[] intIndices = Arrays.stream(indices).mapToInt(Math::toIntExact).toArray();
long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data);
Expand All @@ -110,10 +110,10 @@ public NDArray createCSR(
@Override
public NDArray zeros(Shape shape, DataType dataType) {
if (dataType != DataType.FLOAT32) {
throw new IllegalArgumentException("Only float32 supported");
throw new UnsupportedOperationException("Only float32 supported");
}
if (shape.dimension() != 2) {
throw new IllegalArgumentException("Shape must be in two dimension");
throw new UnsupportedOperationException("Shape must be in two dimension");
}
int size = Math.toIntExact(4 * shape.size());
ByteBuffer buffer = allocateDirect(size);
Expand All @@ -124,10 +124,10 @@ public NDArray zeros(Shape shape, DataType dataType) {
@Override
public NDArray ones(Shape shape, DataType dataType) {
if (dataType != DataType.FLOAT32) {
throw new IllegalArgumentException("Only float32 supported");
throw new UnsupportedOperationException("Only float32 supported");
}
if (shape.dimension() != 2) {
throw new IllegalArgumentException("Shape must be in two dimension");
throw new UnsupportedOperationException("Shape must be in two dimension");
}
long size = shape.size();
int bytes = Math.toIntExact(4 * size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.onnxruntime.engine;

import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
Expand Down Expand Up @@ -42,6 +41,9 @@ public static OnnxTensor toTensor(OrtEnvironment env, NDArray array) throws OrtE

public static OnnxTensor toTensor(
OrtEnvironment env, Buffer data, Shape shape, DataType dataType) throws OrtException {
if (shape.size() == 0) {
throw new UnsupportedOperationException("OnnxRuntime doesn't support 0 length tensor.");
}
long[] sh = shape.getShape();
switch (dataType) {
case FLOAT32:
Expand All @@ -55,11 +57,13 @@ public static OnnxTensor toTensor(
case INT8:
case UINT8:
return OnnxTensor.createTensor(env, (ByteBuffer) data, sh, OnnxJavaType.INT8);
case STRING:
throw new UnsupportedOperationException(
"Use toTensor(OrtEnvironment env, String[] inputs, Shape shape) instead.");
case BOOLEAN:
return OnnxTensor.createTensor(env, (ByteBuffer) data, sh, OnnxJavaType.BOOL);
case FLOAT16:
default:
throw new EngineException("Data type not supported: " + dataType);
throw new UnsupportedOperationException("Data type not supported: " + dataType);
}
}

Expand Down

0 comments on commit 0fa6e7a

Please sign in to comment.