Commit 689e10c

Merge branch 'deepjavalibrary:master' into bMultiplicationBlock

patins1 authored Nov 1, 2022
2 parents ab92d7a + b0427a8 commit 689e10c
Showing 12 changed files with 87 additions and 7 deletions.
api/src/main/java/ai/djl/ndarray/BaseNDManager.java (6 additions, 0 deletions)
@@ -193,6 +193,12 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType
         throw new UnsupportedOperationException("Not supported!");
     }
 
+    /** {@inheritDoc} */
+    @Override
+    public NDArray randomPermutation(long n) {
+        throw new UnsupportedOperationException("Not supported!");
+    }
+
     /** {@inheritDoc} */
     @Override
     public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) {
api/src/main/java/ai/djl/ndarray/NDManager.java (8 additions, 0 deletions)
@@ -1176,6 +1176,14 @@ default NDArray linspace(float start, float stop, int num, boolean endpoint, Dev
      */
     NDArray randomInteger(long low, long high, Shape shape, DataType dataType);
 
+    /**
+     * Returns a random permutation of integers from 0 to n - 1.
+     *
+     * @param n the upper bound (exclusive)
+     * @return a random permutation of integers from 0 to n - 1
+     */
+    NDArray randomPermutation(long n);
+
     /**
      * Draws samples from a uniform distribution.
      *
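A minimal usage sketch of the new API (hypothetical, not part of the commit; it assumes the default engine on the classpath implements randomPermutation, which the PyTorch and MXNet engines do after this change):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public final class RandomPermutationExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            // A random ordering of the integers 0..9, e.g. [3, 0, 7, 1, 9, 4, 2, 8, 5, 6]
            NDArray perm = manager.randomPermutation(10);
            System.out.println(perm);
        }
    }
}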
api/src/main/java/ai/djl/training/dataset/BatchSampler.java (4 additions, 4 deletions)
@@ -73,7 +73,7 @@ class Iterate implements Iterator<List<Long>> {
 
         private long size;
         private long current;
-        private Iterator<Long> itemSampler;
+        private Iterator<Long> subSample;
 
         Iterate(RandomAccessDataset dataset) {
             current = 0;
@@ -82,7 +82,7 @@ class Iterate implements Iterator<List<Long>> {
             } else {
                 this.size = (dataset.size() + batchSize - 1) / batchSize;
             }
-            itemSampler = subSampler.sample(dataset);
+            subSample = subSampler.sample(dataset);
         }
 
         /** {@inheritDoc} */
@@ -95,8 +95,8 @@ public boolean hasNext() {
         @Override
         public List<Long> next() {
             List<Long> batchIndices = new ArrayList<>();
-            while (itemSampler.hasNext()) {
-                batchIndices.add(itemSampler.next());
+            while (subSample.hasNext()) {
+                batchIndices.add(subSample.next());
                 if (batchIndices.size() == batchSize) {
                     break;
                 }
@@ -143,6 +143,12 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType
         throw new UnsupportedOperationException(UNSUPPORTED);
     }
 
+    /** {@inheritDoc} */
+    @Override
+    public NDArray randomPermutation(long n) {
+        throw new UnsupportedOperationException("Not supported!");
+    }
+
     /** {@inheritDoc} */
     @Override
     public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) {
@@ -233,6 +233,14 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType
         return invoke("_npi_random_randint", params);
     }
 
+    /** {@inheritDoc} */
+    @Override
+    public NDArray randomPermutation(long n) {
+        NDArray array = arange(0, n, 1, DataType.INT64);
+        MxOpParams params = new MxOpParams();
+        return invoke("_npi_shuffle", new NDList(array), params).singletonOrThrow();
+    }
+
     /** {@inheritDoc} */
     @Override
     public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) {
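The MXNet implementation above materializes arange(0, n) and shuffles it with the _npi_shuffle op. The same permute-the-indices idea works at the user level, for example to shuffle data; a hypothetical sketch (it assumes NDArray.take, which gathers by index from the flattened array, so it is applied to a 1-D input here):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public final class ShuffleWithPermutation {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray data = manager.arange(10f);
            // Draw a random permutation of the indices 0..9, then gather by it
            NDArray perm = manager.randomPermutation(data.size());
            NDArray shuffled = data.take(perm);
            System.out.println(shuffled);
        }
    }
}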
@@ -152,6 +152,12 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType
         return JniUtils.randint(this, low, high, shape, dataType, device);
     }
 
+    /** {@inheritDoc} */
+    @Override
+    public NDArray randomPermutation(long n) {
+        return JniUtils.randperm(this, n, DataType.INT64, device);
+    }
+
     /** {@inheritDoc} */
     @Override
     public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) {
@@ -1114,6 +1114,18 @@ public static PtNDArray randint(
                         false));
     }
 
+    public static PtNDArray randperm(
+            PtNDManager manager, long n, DataType dataType, Device device) {
+        return new PtNDArray(
+                manager,
+                PyTorchLibrary.LIB.torchRandPerm(
+                        n,
+                        dataType.ordinal(),
+                        layoutMapper(SparseFormat.DENSE, device),
+                        new int[] {PtDeviceType.toDeviceType(device), device.getDeviceId()},
+                        false));
+    }
+
     public static PtNDArray normal(
             PtNDManager manager,
             double mean,
@@ -345,6 +345,8 @@ native long torchRandint(
             int[] device,
             boolean requiredGrad);
 
+    native long torchRandPerm(long n, int dType, int layout, int[] device, boolean requireGrad);
+
     native long torchNormal(
             double mean,
             double std,
@@ -36,6 +36,21 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchRandint(JNIE
   API_END_RETURN()
 }
 
+JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchRandPerm(
+    JNIEnv* env, jobject jthis, jlong jn, jint jdtype, jint jlayout, jintArray jdevice, jboolean jrequire_grad) {
+  API_BEGIN()
+  const auto options = utils::CreateTensorOptions(env, jdtype, jlayout, jdevice, jrequire_grad);
+  torch::Tensor tensor = torch::randperm(jn, options);
+  // Tensor Option for mkldnn is not working
+  // explicitly convert to mkldnn
+  if (jlayout == 2) {
+    tensor = tensor.to_mkldnn();
+  }
+  const auto* result_ptr = new torch::Tensor(tensor);
+  return reinterpret_cast<uintptr_t>(result_ptr);
+  API_END_RETURN()
+}
+
 JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNormal(JNIEnv* env, jobject jthis, jdouble jmean,
     jdouble jstd, jlongArray jsizes, jint jdtype, jint jlayout, jintArray jdevice, jboolean jrequire_grad) {
   API_BEGIN()
@@ -97,8 +97,8 @@ public static TrainingResult runExample(String[] args)
         for (Pair<String, Parameter> paramPair : baseBlock.getParameters()) {
             learningRateTrackerBuilder.put(paramPair.getValue().getId(), 0.1f * lr);
         }
-        Optimizer optimizer =
-                Adam.builder().optLearningRateTracker(learningRateTrackerBuilder.build()).build();
+        FixedPerVarTracker learningRateTracker = learningRateTrackerBuilder.build();
+        Optimizer optimizer = Adam.builder().optLearningRateTracker(learningRateTracker).build();
         config.optOptimizer(optimizer);
 
         Trainer trainer = model.newTrainer(config);
@@ -131,7 +131,7 @@ private static RandomAccessDataset getData(Dataset.Usage usage, int batchSize)
         float[] mean = {0.485f, 0.456f, 0.406f};
         float[] std = {0.229f, 0.224f, 0.225f};
 
-        // If the user wants to use local repository, then the dataset can be loaded as follows
+        // If users want to use local repository, then the dataset can be loaded as follows
         // Repository repository = Repository.newInstance("banana", Paths.get(LOCAL_FOLDER/{train OR
         // test}));
         // FruitsFreshAndRotten dataset =
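The refactor above only lifts the built tracker into a local variable; the surrounding code sets a per-parameter learning rate for fine-tuning. A trimmed-down, hypothetical sketch of that pattern (the setDefaultValue call and the parameter id are illustrative and not shown in this diff):

import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.FixedPerVarTracker;

public final class PerVarLearningRate {
    public static Optimizer buildOptimizer(String paramId, float lr) {
        // Default rate for every parameter, with a reduced rate for one parameter id
        FixedPerVarTracker.Builder builder = FixedPerVarTracker.builder().setDefaultValue(lr);
        builder.put(paramId, 0.1f * lr);
        FixedPerVarTracker learningRateTracker = builder.build();
        return Adam.builder().optLearningRateTracker(learningRateTracker).build();
    }
}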
@@ -161,6 +161,8 @@ public void testBoundingBoxes() {
 
     @Test
     public void testDrawImage() throws IOException {
+        TestRequirements.notWindows(); // failed on Windows ServerCore container
+        TestRequirements.notArm();
         ImageFactory factory = ImageFactory.getInstance();
         int[] pixels = new int[64];
         int index = 0;
@@ -394,6 +394,21 @@ public void testRandomInteger() {
         }
     }
 
+    @Test
+    public void testRandomPermutation() {
+        try (NDManager manager = NDManager.newBaseManager()) {
+            long size = 3;
+            NDArray array = manager.randomPermutation(size);
+            Assert.assertTrue(
+                    array.contentEquals(manager.create(new long[] {0, 1, 2}))
+                            || array.contentEquals(manager.create(new long[] {0, 2, 1}))
+                            || array.contentEquals(manager.create(new long[] {1, 0, 2}))
+                            || array.contentEquals(manager.create(new long[] {1, 2, 0}))
+                            || array.contentEquals(manager.create(new long[] {2, 0, 1}))
+                            || array.contentEquals(manager.create(new long[] {2, 1, 0})));
+        }
+    }
+
     @Test
     public void testRandomUniform() {
         try (NDManager manager = NDManager.newBaseManager()) {
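The assertion above enumerates all 3! = 6 possible orderings, which only scales to tiny inputs. A size-independent variant (hypothetical, not in the commit) would sort the result inside the same try-with-resources block and compare it against the identity sequence:

// Any valid permutation of 0..n-1 sorts back to arange(n)
long n = 100;
NDArray perm = manager.randomPermutation(n);
Assert.assertTrue(perm.sort().contentEquals(manager.arange(0, n, 1, DataType.INT64)));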
