diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index a5fb0a2f109..61a6f545320 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -21,7 +21,6 @@ import ai.djl.ndarray.index.dim.NDIndexPick; import ai.djl.ndarray.index.dim.NDIndexSlice; import ai.djl.ndarray.index.dim.NDIndexTake; -import ai.djl.ndarray.types.DataType; import java.util.ArrayList; import java.util.List; @@ -370,10 +369,11 @@ private int addIndexItem(String indexItem, int argIndex, Object[] args) { return argIndex + 1; } else if (arg instanceof NDArray) { NDArray array = (NDArray) arg; - if (array.getDataType() == DataType.BOOLEAN) { + if (array.getDataType().isBoolean()) { indices.add(new NDIndexBooleans(array)); return argIndex + 1; - } else if (array.getDataType().isInteger()) { + } else if (array.getDataType().isInteger() + || array.getDataType().isFloating()) { indices.add(new NDIndexTake(array)); return argIndex + 1; } diff --git a/api/src/main/java/ai/djl/ndarray/types/DataType.java b/api/src/main/java/ai/djl/ndarray/types/DataType.java index 4101054fccc..620cad66fcc 100644 --- a/api/src/main/java/ai/djl/ndarray/types/DataType.java +++ b/api/src/main/java/ai/djl/ndarray/types/DataType.java @@ -90,6 +90,15 @@ public boolean isInteger() { return format == Format.UINT || format == Format.INT; } + /** + * Checks whether it is a boolean data type. + * + * @return whether it is a boolean data type + */ + public boolean isBoolean() { + return format == Format.BOOLEAN; + } + /** * Returns the data type to use for a data buffer. * diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 7e38fe9f530..4920187a3c6 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -57,7 +57,7 @@ public void testPick() { public void testGather() { try (NDManager manager = NDManager.newBaseManager()) { NDArray original = manager.arange(20f).reshape(-1, 4); - NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2)); + NDArray index = manager.create(new float[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2)); NDArray actual = original.gather(index, 1); NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2)); Assert.assertEquals(actual, expected); @@ -68,7 +68,7 @@ public void testGather() { public void testTake() { try (NDManager manager = NDManager.newBaseManager()) { NDArray original = manager.arange(1, 7f).reshape(-1, 3); - NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); + NDArray index = manager.create(new float[] {0, 4, 1, 2}, new Shape(2, 2)); NDArray actual = original.take(index); NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); Assert.assertEquals(actual, expected); @@ -119,25 +119,28 @@ public void testGet() { expected = manager.arange(5).reshape(1, 5); Assert.assertEquals(original.get(bool), expected); - // get from int array + // get from integer array (higher rank included) or float array original = manager.arange(1, 7f).reshape(-1, 2); NDArray index = manager.create(new long[] {0, 0, 1, 2}, new Shape(2, 2)); + NDArray indexFloat = manager.create(new float[] {0, 0, 1, 2}, new Shape(2, 2)); NDArray actual = original.get(index); + NDArray actual2 = original.get(indexFloat); expected = manager.create(new float[] {1, 2, 1, 2, 3, 4, 5, 6}, new Shape(2, 2, 2)); Assert.assertEquals(actual, expected); + Assert.assertEquals(actual2, expected); - // indexing with boolean, broadcast int array and slice + // indexing with boolean, slice, and integer array (higher rank included) or float array original = manager.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3); NDArray bool1 = manager.create(new boolean[] {true, false, true}); NDArray index1 = manager.create(new long[] {2, 2}, new Shape(1, 2)); - NDArray index2 = manager.create(new long[] {0, 1}, new Shape(1, 2)); + NDArray index2 = manager.create(new float[] {0, 1}, new Shape(1, 2)); actual = original.get(":{}, {}, {}, {}", 2, index1, bool1, index2); expected = manager.create(new int[] {18, 25, 45, 52}, new Shape(2, 1, 2)); Assert.assertEquals(actual, expected); - // indexing with null, broadcast int array and slice + // indexing with null, slice and integer array (higher rank included) or float array original = manager.arange(3 * 3 * 3).reshape(3, 3, 3); - index1 = manager.create(new long[] {0, 1}, new Shape(2)); + index1 = manager.create(new float[] {0, 1}, new Shape(2)); index2 = manager.create(new long[] {0, 0, 2, 1}, new Shape(2, 2)); actual = original.get(":{}, {}, {}, {}", 2, index1, index2, null); expected = manager.create(new int[] {0, 3, 2, 4, 9, 12, 11, 13}, new Shape(2, 2, 2, 1)); @@ -160,7 +163,7 @@ public void testSetArray() { // set by index array original = manager.arange(1, 10).reshape(3, 3); - NDArray index = manager.create(new long[] {0, 1}, new Shape(2)); + NDArray index = manager.create(new float[] {0, 1}, new Shape(2)); value = manager.create(new int[] {666, 777, 888, 999}, new Shape(2, 2)); original.set(new NDIndex("{}, :{}", index, 2), value); expected = @@ -258,7 +261,7 @@ public void testPut() { try (NDManager manager = NDManager.newBaseManager()) { NDArray original = manager.create(new float[] {1, 2, 3, 4}, new Shape(2, 2)); NDArray expected = manager.create(new float[] {1, 8, 666, 77}, new Shape(2, 2)); - NDArray idx = manager.create(new long[] {2, 3, 1}, new Shape(3)); + NDArray idx = manager.create(new float[] {2, 3, 1}, new Shape(3)); NDArray data = manager.create(new float[] {666, 77, 8}, new Shape(3)); Assert.assertEquals(original.put(idx, data), expected); }