Add UINT8 datatype support to Java (#8401)
Add UINT8 datatype support
Add inference test for UINT8 model
frankfliu authored Jul 23, 2021
1 parent 950fe5e commit 002e427
Showing 8 changed files with 46 additions and 1 deletion.
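In practice, this change lets callers build UINT8 tensors straight from a ByteBuffer and run them through a session. Below is a minimal sketch mirroring the new tests in this commit; the model path is a placeholder for any model with a UINT8 input, not something shipped here.

```java
import ai.onnxruntime.*;
import java.nio.ByteBuffer;
import java.util.Collections;

public class Uint8Example {
  public static void main(String[] args) throws OrtException {
    try (OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        OrtSession session = env.createSession("test_types_UINT8.pb", options)) {
      String inputName = session.getInputNames().iterator().next();
      // Java bytes are signed; (byte) 0xFF carries the unsigned value 255.
      ByteBuffer data = ByteBuffer.wrap(new byte[] {1, 2, 3, 4, (byte) 0xFF});
      long[] shape = new long[] {1, 5};
      try (OnnxTensor input = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8);
          OrtSession.Result result = session.run(Collections.singletonMap(inputName, input))) {
        // A 1x5 tensor comes back from getValue() as a nested byte array.
        byte[][] output = (byte[][]) result.get(0).getValue();
        System.out.println(output[0].length);
      }
    }
  }
}
```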
4 changes: 3 additions & 1 deletion java/src/main/java/ai/onnxruntime/OnnxJavaType.java
@@ -16,9 +16,10 @@ public enum OnnxJavaType {
INT64(6, long.class, 8),
BOOL(7, boolean.class, 1),
STRING(8, String.class, 4),
UINT8(9, byte.class, 1),
UNKNOWN(0, Object.class, 0);

private static final OnnxJavaType[] values = new OnnxJavaType[9];
private static final OnnxJavaType[] values = new OnnxJavaType[10];

static {
for (OnnxJavaType ot : OnnxJavaType.values()) {
@@ -62,6 +63,7 @@ public static OnnxJavaType mapFromInt(int value) {
public static OnnxJavaType mapFromOnnxTensorType(OnnxTensorType onnxValue) {
switch (onnxValue) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return OnnxJavaType.UINT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
return OnnxJavaType.INT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
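One caveat the enum makes visible: UINT8 rides in Java's signed byte (byte.class, width 1), so element values 128–255 read back as negative bytes and need reinterpreting on the Java side. A quick sketch using plain JDK calls, not anything added by this commit:

```java
// Reinterpreting UINT8 data that arrives as Java (signed) bytes.
byte raw = (byte) 0xC0;                  // bit pattern of the UINT8 value 192
int viaHelper = Byte.toUnsignedInt(raw); // 192
int viaMask = raw & 0xFF;                // equivalent spelling, also 192
```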
1 change: 1 addition & 0 deletions java/src/main/java/ai/onnxruntime/OnnxMap.java
@@ -77,6 +77,7 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) {
return OnnxMapValueType.LONG;
case STRING:
return OnnxMapValueType.STRING;
case UINT8:
case INT8:
case INT16:
case INT32:
1 change: 1 addition & 0 deletions java/src/main/java/ai/onnxruntime/OnnxSequence.java
@@ -100,6 +100,7 @@ public List<Object> getValue() throws OrtException {
list.addAll(Arrays.asList(strings));
return list;
case BOOL:
case UINT8:
case INT8:
case INT16:
case INT32:
2 changes: 2 additions & 0 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
@@ -78,6 +78,7 @@ public Object getValue() throws OrtException {
return getFloat(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value);
case DOUBLE:
return getDouble(OnnxRuntime.ortApiHandle, nativeHandle);
case UINT8:
case INT8:
return getByte(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value);
case INT16:
@@ -744,6 +745,7 @@ private static OnnxTensor createTensor(
case DOUBLE:
tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
break;
case UINT8:
case INT8:
// buffer is already a ByteBuffer, no cast needed.
tmp = buffer.put((ByteBuffer) data);
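Since UINT8 and INT8 now share the same byte-based paths in both getValue and createTensor, the explicit OnnxJavaType argument to the buffer-based createTensor overload is what selects the ONNX element type. A hedged sketch, assuming an OrtEnvironment env is already in scope and that a wrapped heap buffer is copied (hence the rewind) before reuse:

```java
ByteBuffer pixels = ByteBuffer.wrap(new byte[] {0, 127, (byte) 200, (byte) 255});
long[] shape = new long[] {4};
// Same bytes, two element types: the OnnxJavaType argument picks the ONNX tensor type.
try (OnnxTensor asUint8 = OnnxTensor.createTensor(env, pixels, shape, OnnxJavaType.UINT8)) {
  pixels.rewind(); // the heap buffer's position advances when its contents are copied
  try (OnnxTensor asInt8 = OnnxTensor.createTensor(env, pixels, shape, OnnxJavaType.INT8)) {
    // asUint8 and asInt8 hold identical bits but report different ONNX element types.
  }
}
```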
1 change: 1 addition & 0 deletions java/src/main/java/ai/onnxruntime/OrtUtil.java
@@ -434,6 +434,7 @@ static Object convertBoxedPrimitiveToArray(OnnxJavaType javaType, Object data) {
double[] doubleArr = new double[1];
doubleArr[0] = (Double) data;
return doubleArr;
case UINT8:
case INT8:
byte[] byteArr = new byte[1];
byteArr[0] = (Byte) data;
3 changes: 3 additions & 0 deletions java/src/main/java/ai/onnxruntime/TensorInfo.java
@@ -80,6 +80,8 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
case INT8:
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
case UINT8:
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
case INT16:
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
case INT32:
@@ -179,6 +181,7 @@ public Object makeCarrier() throws OrtException {
return OrtUtil.newFloatArray(shape);
case DOUBLE:
return OrtUtil.newDoubleArray(shape);
case UINT8:
case INT8:
return OrtUtil.newByteArray(shape);
case INT16:
22 changes: 22 additions & 0 deletions java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -1247,6 +1247,28 @@ public void testModelInputINT8() throws OrtException {
}
}

@Test
public void testModelInputUINT8() throws OrtException {
String modelPath = getResourcePath("/test_types_UINT8.pb").toString();

try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputUINT8");
SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
Map<String, OnnxTensor> container = new HashMap<>();
byte[] flatInput = new byte[] {1, 2, -3, Byte.MIN_VALUE, Byte.MAX_VALUE};
ByteBuffer data = ByteBuffer.wrap(flatInput);
long[] shape = new long[] {1, 5};
OnnxTensor ov = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8);
container.put(inputName, ov);
try (OrtSession.Result res = session.run(container)) {
byte[] resultArray = TestHelpers.flattenByte(res.get(0).getValue());
assertArrayEquals(flatInput, resultArray);
}
OnnxValue.close(container);
}
}

@Test
public void testModelInputINT16() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
13 changes: 13 additions & 0 deletions java/src/test/java/ai/onnxruntime/TensorCreationTest.java
@@ -4,6 +4,7 @@
*/
package ai.onnxruntime;

import java.nio.ByteBuffer;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

@@ -112,4 +113,16 @@ public void testStringCreation() throws OrtException {
}
}
}

@Test
public void testUint8Creation() throws OrtException {
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
byte[] buf = new byte[] {0, 1};
ByteBuffer data = ByteBuffer.wrap(buf);
long[] shape = new long[] {2};
try (OnnxTensor t = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8)) {
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
}
}
}
}
