Skip to content

Commit

Permalink
[JAVA_API] Add get_element_type() to Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudioPaul0 committed Oct 15, 2024
1 parent 4272f47 commit 7958322
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 0 deletions.
1 change: 1 addition & 0 deletions modules/java_api/src/main/cpp/openvino_java.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ extern "C"
JNIEXPORT jlong JNICALL Java_org_intel_openvino_Tensor_TensorLong(JNIEnv *, jobject, jintArray, jlongArray);
JNIEXPORT jint JNICALL Java_org_intel_openvino_Tensor_GetSize(JNIEnv *, jobject, jlong);
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_Tensor_GetShape(JNIEnv *, jobject, jlong);
JNIEXPORT jint JNICALL Java_org_intel_openvino_Tensor_GetElementType(JNIEnv *, jobject, jlong);
JNIEXPORT jfloatArray JNICALL Java_org_intel_openvino_Tensor_asFloat(JNIEnv *, jobject, jlong);
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_Tensor_asInt(JNIEnv *, jobject, jlong);
JNIEXPORT void JNICALL Java_org_intel_openvino_Tensor_delete(JNIEnv *, jobject, jlong);
Expand Down
13 changes: 13 additions & 0 deletions modules/java_api/src/main/cpp/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,19 @@ JNIEXPORT jintArray JNICALL Java_org_intel_openvino_Tensor_GetShape(JNIEnv *env,
return 0;
}

JNIEXPORT jint JNICALL Java_org_intel_openvino_Tensor_GetElementType(JNIEnv *env, jobject, jlong addr)
{
JNI_METHOD(
"GetElementType",
Tensor *ov_tensor = (Tensor *)addr;

element::Type_t t_type = ov_tensor->get_element_type();
jint type = static_cast<jint>(t_type);
return type;
)
return 0;
}

JNIEXPORT jfloatArray JNICALL Java_org_intel_openvino_Tensor_asFloat(JNIEnv *env, jobject, jlong addr)
{
JNI_METHOD(
Expand Down
7 changes: 7 additions & 0 deletions modules/java_api/src/main/java/org/intel/openvino/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ public int[] get_shape() {
return GetShape(nativeObj);
}

/** Returns the tensor element type. */
public ElementType get_element_type() {
return ElementType.valueOf(GetElementType(nativeObj));
}

/** Returns a tensor data as floating point array. */
public float[] data() {
return asFloat(nativeObj);
Expand All @@ -77,6 +82,8 @@ public int[] as_int() {

private static native int[] GetShape(long addr);

private static native int GetElementType(long addr);

private static native float[] asFloat(long addr);

private static native int[] asInt(long addr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public void testGetTensorFromFloat() {

assertArrayEquals(tensor.get_shape(), dimsArr);
assertArrayEquals(tensor.data(), data, 0.0f);
assertEquals(ElementType.f32, tensor.get_element_type());
}

@Test
Expand All @@ -29,6 +30,7 @@ public void testGetTensorFromInt() {
assertArrayEquals(dimsArr, tensor.get_shape());
assertArrayEquals(inputData, tensor.as_int());
assertEquals(size, tensor.get_size());
assertEquals(ElementType.i32, tensor.get_element_type());
}

@Test
Expand All @@ -41,5 +43,6 @@ public void testGetTensorFromLong() {

assertArrayEquals(dimsArr, tensor.get_shape());
assertEquals(size, tensor.get_size());
assertEquals(ElementType.i64, tensor.get_element_type());
}
}

0 comments on commit 7958322

Please sign in to comment.