diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 863f2160e8f..59b69fbe729 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -193,6 +193,12 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType throw new UnsupportedOperationException("Not supported!"); } + /** {@inheritDoc} */ + @Override + public NDArray randomPermutation(long n) { + throw new UnsupportedOperationException("Not supported!"); + } + /** {@inheritDoc} */ @Override public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) { diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index 72d4c08e15f..26ffcd7aa4a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -1176,6 +1176,14 @@ default NDArray linspace(float start, float stop, int num, boolean endpoint, Dev */ NDArray randomInteger(long low, long high, Shape shape, DataType dataType); + /** + * Returns a random permutation of integers from 0 to n - 1. + * + * @param n (int) – the upper bound (exclusive) + * @return a random permutation of integers from 0 to n - 1. + */ + NDArray randomPermutation(long n); + /** * Draws samples from a uniform distribution. * diff --git a/api/src/main/java/ai/djl/training/dataset/BatchSampler.java b/api/src/main/java/ai/djl/training/dataset/BatchSampler.java index b44cec801f2..432c7190532 100644 --- a/api/src/main/java/ai/djl/training/dataset/BatchSampler.java +++ b/api/src/main/java/ai/djl/training/dataset/BatchSampler.java @@ -73,7 +73,7 @@ class Iterate implements Iterator> { private long size; private long current; - private Iterator itemSampler; + private Iterator subSample; Iterate(RandomAccessDataset dataset) { current = 0; @@ -82,7 +82,7 @@ class Iterate implements Iterator> { } else { this.size = (dataset.size() + batchSize - 1) / batchSize; } - itemSampler = subSampler.sample(dataset); + subSample = subSampler.sample(dataset); } /** {@inheritDoc} */ @@ -95,8 +95,8 @@ public boolean hasNext() { @Override public List next() { List batchIndices = new ArrayList<>(); - while (itemSampler.hasNext()) { - batchIndices.add(itemSampler.next()); + while (subSample.hasNext()) { + batchIndices.add(subSample.next()); if (batchIndices.size() == batchSize) { break; } diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java index 1a0ac0d4ef5..bde8a89137a 100644 --- a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java @@ -143,6 +143,12 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType throw new UnsupportedOperationException(UNSUPPORTED); } + /** {@inheritDoc} */ + @Override + public NDArray randomPermutation(long n) { + throw new UnsupportedOperationException("Not supported!"); + } + /** {@inheritDoc} */ @Override public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java index a4bb4b2fea9..7fc78b5427f 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java @@ -233,6 +233,14 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType return invoke("_npi_random_randint", params); } + /** {@inheritDoc} */ + @Override + public NDArray randomPermutation(long n) { + NDArray array = arange(0, n, 1, DataType.INT64); + MxOpParams params = new MxOpParams(); + return invoke("_npi_shuffle", new NDList(array), params).singletonOrThrow(); + } + /** {@inheritDoc} */ @Override public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 40af6b7a7db..401da0942d6 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -152,6 +152,12 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType return JniUtils.randint(this, low, high, shape, dataType, device); } + /** {@inheritDoc} */ + @Override + public NDArray randomPermutation(long n) { + return JniUtils.randperm(this, n, DataType.INT64, device); + } + /** {@inheritDoc} */ @Override public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index a4bb78cb6d2..eae7b68afef 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -1114,6 +1114,18 @@ public static PtNDArray randint( false)); } + public static PtNDArray randperm( + PtNDManager manager, long n, DataType dataType, Device device) { + return new PtNDArray( + manager, + PyTorchLibrary.LIB.torchRandPerm( + n, + dataType.ordinal(), + layoutMapper(SparseFormat.DENSE, device), + new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()}, + false)); + } + public static PtNDArray normal( PtNDManager manager, double mean, diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 378efeae477..7bab3e1540a 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -345,6 +345,8 @@ native long torchRandint( int[] device, boolean requiredGrad); + native long torchRandPerm(long n, int dType, int layout, int[] device, boolean requireGrad); + native long torchNormal( double mean, double std, diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_random_sampling.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_random_sampling.cc index 6a174a35baf..0f53b2f1efa 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_random_sampling.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_random_sampling.cc @@ -36,6 +36,21 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchRandint(JNIE API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchRandPerm( + JNIEnv* env, jobject jthis, jlong jn, jint jdtype, jint jlayout, jintArray jdevice, jboolean jrequire_grad) { + API_BEGIN() + const auto options = utils::CreateTensorOptions(env, jdtype, jlayout, jdevice, jrequire_grad); + torch::Tensor tensor = torch::randperm(jn, options); + // Tensor Option for mkldnn is not working + // explicitly convert to mkldnn + if (jlayout == 2) { + tensor = tensor.to_mkldnn(); + } + const auto* result_ptr = new torch::Tensor(tensor); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNormal(JNIEnv* env, jobject jthis, jdouble jmean, jdouble jstd, jlongArray jsizes, jint jdtype, jint jlayout, jintArray jdevice, jboolean jrequire_grad) { API_BEGIN() diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java index acebbc93ec3..79c730abf00 100644 --- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java +++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java @@ -97,8 +97,8 @@ public static TrainingResult runExample(String[] args) for (Pair paramPair : baseBlock.getParameters()) { learningRateTrackerBuilder.put(paramPair.getValue().getId(), 0.1f * lr); } - Optimizer optimizer = - Adam.builder().optLearningRateTracker(learningRateTrackerBuilder.build()).build(); + FixedPerVarTracker learningRateTracker = learningRateTrackerBuilder.build(); + Optimizer optimizer = Adam.builder().optLearningRateTracker(learningRateTracker).build(); config.optOptimizer(optimizer); Trainer trainer = model.newTrainer(config); @@ -131,7 +131,7 @@ private static RandomAccessDataset getData(Dataset.Usage usage, int batchSize) float[] mean = {0.485f, 0.456f, 0.406f}; float[] std = {0.229f, 0.224f, 0.225f}; - // If the user wants to use local repository, then the dataset can be loaded as follows + // If users want to use local repository, then the dataset can be loaded as follows // Repository repository = Repository.newInstance("banana", Paths.get(LOCAL_FOLDER/{train OR // test})); // FruitsFreshAndRotten dataset = diff --git a/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java b/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java index 69c683bcf40..14e822148c7 100644 --- a/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java +++ b/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java @@ -161,6 +161,8 @@ public void testBoundingBoxes() { @Test public void testDrawImage() throws IOException { + TestRequirements.notWindows(); // failed on Windows ServerCore container + TestRequirements.notArm(); ImageFactory factory = ImageFactory.getInstance(); int[] pixels = new int[64]; int index = 0; diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayCreationOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayCreationOpTest.java index 37e18db1b91..00b346eabba 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayCreationOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayCreationOpTest.java @@ -394,6 +394,21 @@ public void testRandomInteger() { } } + @Test + public void testRandomPermutation() { + try (NDManager manager = NDManager.newBaseManager()) { + long size = 3; + NDArray array = manager.randomPermutation(size); + Assert.assertTrue( + array.contentEquals(manager.create(new long[] {0, 1, 2})) + || array.contentEquals(manager.create(new long[] {0, 2, 1})) + || array.contentEquals(manager.create(new long[] {1, 0, 2})) + || array.contentEquals(manager.create(new long[] {1, 2, 0})) + || array.contentEquals(manager.create(new long[] {2, 0, 1})) + || array.contentEquals(manager.create(new long[] {2, 1, 0}))); + } + } + @Test public void testRandomUniform() { try (NDManager manager = NDManager.newBaseManager()) {