From 7958322f499c0eb4c9c032e41d89293294998cc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A1udio?= Date: Tue, 15 Oct 2024 21:59:25 +0100 Subject: [PATCH] [JAVA_API] Add get_element_type() to Tensor --- modules/java_api/src/main/cpp/openvino_java.hpp | 1 + modules/java_api/src/main/cpp/tensor.cpp | 13 +++++++++++++ .../src/main/java/org/intel/openvino/Tensor.java | 7 +++++++ .../test/java/org/intel/openvino/TensorTests.java | 3 +++ 4 files changed, 24 insertions(+) diff --git a/modules/java_api/src/main/cpp/openvino_java.hpp b/modules/java_api/src/main/cpp/openvino_java.hpp index 42b2a9755..d289760b2 100644 --- a/modules/java_api/src/main/cpp/openvino_java.hpp +++ b/modules/java_api/src/main/cpp/openvino_java.hpp @@ -64,6 +64,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_org_intel_openvino_Tensor_TensorLong(JNIEnv *, jobject, jintArray, jlongArray); JNIEXPORT jint JNICALL Java_org_intel_openvino_Tensor_GetSize(JNIEnv *, jobject, jlong); JNIEXPORT jintArray JNICALL Java_org_intel_openvino_Tensor_GetShape(JNIEnv *, jobject, jlong); + JNIEXPORT jint JNICALL Java_org_intel_openvino_Tensor_GetElementType(JNIEnv *, jobject, jlong); JNIEXPORT jfloatArray JNICALL Java_org_intel_openvino_Tensor_asFloat(JNIEnv *, jobject, jlong); JNIEXPORT jintArray JNICALL Java_org_intel_openvino_Tensor_asInt(JNIEnv *, jobject, jlong); JNIEXPORT void JNICALL Java_org_intel_openvino_Tensor_delete(JNIEnv *, jobject, jlong); diff --git a/modules/java_api/src/main/cpp/tensor.cpp b/modules/java_api/src/main/cpp/tensor.cpp index 87e430382..93db00c6a 100644 --- a/modules/java_api/src/main/cpp/tensor.cpp +++ b/modules/java_api/src/main/cpp/tensor.cpp @@ -114,6 +114,19 @@ JNIEXPORT jintArray JNICALL Java_org_intel_openvino_Tensor_GetShape(JNIEnv *env, return 0; } +JNIEXPORT jint JNICALL Java_org_intel_openvino_Tensor_GetElementType(JNIEnv *env, jobject, jlong addr) +{ + JNI_METHOD( + "GetElementType", + Tensor *ov_tensor = (Tensor *)addr; + + element::Type_t t_type = ov_tensor->get_element_type(); + jint type = static_cast(t_type); + return type; + ) + return 0; +} + JNIEXPORT jfloatArray JNICALL Java_org_intel_openvino_Tensor_asFloat(JNIEnv *env, jobject, jlong addr) { JNI_METHOD( diff --git a/modules/java_api/src/main/java/org/intel/openvino/Tensor.java b/modules/java_api/src/main/java/org/intel/openvino/Tensor.java index b3236e26a..2ce905014 100644 --- a/modules/java_api/src/main/java/org/intel/openvino/Tensor.java +++ b/modules/java_api/src/main/java/org/intel/openvino/Tensor.java @@ -56,6 +56,11 @@ public int[] get_shape() { return GetShape(nativeObj); } + /** Returns the tensor element type. */ + public ElementType get_element_type() { + return ElementType.valueOf(GetElementType(nativeObj)); + } + /** Returns a tensor data as floating point array. */ public float[] data() { return asFloat(nativeObj); @@ -77,6 +82,8 @@ public int[] as_int() { private static native int[] GetShape(long addr); + private static native int GetElementType(long addr); + private static native float[] asFloat(long addr); private static native int[] asInt(long addr); diff --git a/modules/java_api/src/test/java/org/intel/openvino/TensorTests.java b/modules/java_api/src/test/java/org/intel/openvino/TensorTests.java index 45a0f3cd8..52c5de640 100644 --- a/modules/java_api/src/test/java/org/intel/openvino/TensorTests.java +++ b/modules/java_api/src/test/java/org/intel/openvino/TensorTests.java @@ -16,6 +16,7 @@ public void testGetTensorFromFloat() { assertArrayEquals(tensor.get_shape(), dimsArr); assertArrayEquals(tensor.data(), data, 0.0f); + assertEquals(ElementType.f32, tensor.get_element_type()); } @Test @@ -29,6 +30,7 @@ public void testGetTensorFromInt() { assertArrayEquals(dimsArr, tensor.get_shape()); assertArrayEquals(inputData, tensor.as_int()); assertEquals(size, tensor.get_size()); + assertEquals(ElementType.i32, tensor.get_element_type()); } @Test @@ -41,5 +43,6 @@ public void testGetTensorFromLong() { assertArrayEquals(dimsArr, tensor.get_shape()); assertEquals(size, tensor.get_size()); + assertEquals(ElementType.i64, tensor.get_element_type()); } }