Skip to content

Commit

Permalink
Add UINT8 datatype support
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jul 15, 2021
1 parent c503806 commit 09720df
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 1 deletion.
4 changes: 3 additions & 1 deletion java/src/main/java/ai/onnxruntime/OnnxJavaType.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ public enum OnnxJavaType {
INT64(6, long.class, 8),
BOOL(7, boolean.class, 1),
STRING(8, String.class, 4),
UINT8(9, byte.class, 1),
UNKNOWN(0, Object.class, 0);

private static final OnnxJavaType[] values = new OnnxJavaType[9];
private static final OnnxJavaType[] values = new OnnxJavaType[10];

static {
for (OnnxJavaType ot : OnnxJavaType.values()) {
Expand Down Expand Up @@ -62,6 +63,7 @@ public static OnnxJavaType mapFromInt(int value) {
public static OnnxJavaType mapFromOnnxTensorType(OnnxTensorType onnxValue) {
switch (onnxValue) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return OnnxJavaType.UINT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
return OnnxJavaType.INT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
Expand Down
1 change: 1 addition & 0 deletions java/src/main/java/ai/onnxruntime/OnnxMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) {
return OnnxMapValueType.LONG;
case STRING:
return OnnxMapValueType.STRING;
case UINT8:
case INT8:
case INT16:
case INT32:
Expand Down
1 change: 1 addition & 0 deletions java/src/main/java/ai/onnxruntime/OnnxSequence.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public List<Object> getValue() throws OrtException {
list.addAll(Arrays.asList(strings));
return list;
case BOOL:
case UINT8:
case INT8:
case INT16:
case INT32:
Expand Down
2 changes: 2 additions & 0 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ public Object getValue() throws OrtException {
return getFloat(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value);
case DOUBLE:
return getDouble(OnnxRuntime.ortApiHandle, nativeHandle);
case UINT8:
case INT8:
return getByte(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value);
case INT16:
Expand Down Expand Up @@ -738,6 +739,7 @@ private static OnnxTensor createTensor(
case DOUBLE:
tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
break;
case UINT8:
case INT8:
// buffer is already a ByteBuffer, no cast needed.
tmp = buffer.put((ByteBuffer) data);
Expand Down
1 change: 1 addition & 0 deletions java/src/main/java/ai/onnxruntime/OrtUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ static Object convertBoxedPrimitiveToArray(OnnxJavaType javaType, Object data) {
double[] doubleArr = new double[1];
doubleArr[0] = (Double) data;
return doubleArr;
case UINT8:
case INT8:
byte[] byteArr = new byte[1];
byteArr[0] = (Byte) data;
Expand Down
3 changes: 3 additions & 0 deletions java/src/main/java/ai/onnxruntime/TensorInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
case INT8:
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
case UINT8:
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
case INT16:
return OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
case INT32:
Expand Down Expand Up @@ -175,6 +177,7 @@ public Object makeCarrier() throws OrtException {
return OrtUtil.newFloatArray(shape);
case DOUBLE:
return OrtUtil.newDoubleArray(shape);
case UINT8:
case INT8:
return OrtUtil.newByteArray(shape);
case INT16:
Expand Down
14 changes: 14 additions & 0 deletions java/src/test/java/ai/onnxruntime/TensorCreationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.nio.ByteBuffer;

public class TensorCreationTest {

@Test
Expand Down Expand Up @@ -81,4 +83,16 @@ public void testScalarCreation() throws OrtException {
}
}
}

@Test
public void testUint8Creation() throws OrtException {
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
byte[] buf = new byte[]{0, 1};
ByteBuffer data = ByteBuffer.wrap(buf);
long[] shape = new long[]{2};
try (OnnxTensor t = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8)) {
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
}
}
}
}

0 comments on commit 09720df

Please sign in to comment.