Skip to content

Commit

Permalink
suggested editions
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Oct 26, 2022
1 parent a113a7f commit dc5c328
Show file tree
Hide file tree
Showing 9 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType

/** {@inheritDoc} */
@Override
public NDArray randPerm(long n) {
public NDArray randomPermutation(long n) {
throw new UnsupportedOperationException("Not supported!");
}

Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,7 @@ default NDArray linspace(float start, float stop, int num, boolean endpoint, Dev
* @param n (int) – the upper bound (exclusive)
* @return a random permutation of integers from 0 to n - 1.
*/
NDArray randPerm(long n);
NDArray randomPermutation(long n);

/**
* Draws samples from a uniform distribution.
Expand Down
8 changes: 4 additions & 4 deletions api/src/main/java/ai/djl/nn/core/Linear.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@
* <p>It has the following shapes:
*
* <ul>
* <li>input X: [batch_num, input_dim]
* <li>input X: [x1, x2, …, xn, input_dim]
* <li>weight W: [units, input_dim]
* <li>Bias b: [units]
* <li>output Y: [batch_num, units]
* <li>output Y: [y1, y2, …, yn, units]
* </ul>
*
* <p>The Linear block should be constructed using {@link Linear.Builder}.
Expand Down Expand Up @@ -165,7 +165,7 @@ public void loadMetadata(byte loadVersion, DataInputStream is)
*
* @param input input X: [x1, x2, …, xn, input_dim]
* @param weight weight W: [units, input_dim]
* @return output Y: [x1, x2, …, xn, units]
* @return output Y: [y1, y2, …, yn, units]
*/
public static NDList linear(NDArray input, NDArray weight) {
return linear(input, weight, null);
Expand All @@ -177,7 +177,7 @@ public static NDList linear(NDArray input, NDArray weight) {
* @param input input X: [x1, x2, …, xn, input_dim]
* @param weight weight W: [units, input_dim]
* @param bias bias b: [units]
* @return output Y: [x1, x2, …, xn, units]
* @return output Y: [y1, y2, …, yn, units]
*/
public static NDList linear(NDArray input, NDArray weight, NDArray bias) {
return input.getNDArrayInternal().linear(input, weight, bias);
Expand Down
4 changes: 2 additions & 2 deletions api/src/main/java/ai/djl/translate/StackBatchifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ public NDList batchify(NDList[] inputs) {
NDArray currInput = input.get(i);
if (!currInput.getShape().equals(kindDataShape)) {
throw new IllegalArgumentException(
"You cannot batch data with different input shapes"
"You cannot batch data with different input shapes. currInput: "
+ currInput.getShape()
+ " vs "
+ " vs kindDataShape: "
+ kindDataShape,
e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType

/** {@inheritDoc} */
@Override
public NDArray randPerm(long n) {
public NDArray randomPermutation(long n) {
throw new UnsupportedOperationException("Not supported!");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType

/** {@inheritDoc} */
@Override
public NDArray randPerm(long n) {
public NDArray randomPermutation(long n) {
NDArray array = arange(0, n, 1, DataType.INT64);
MxOpParams params = new MxOpParams();
return invoke("_npi_shuffle", new NDList(array), params).singletonOrThrow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType

/** {@inheritDoc} */
@Override
public NDArray randPerm(long n) {
public NDArray randomPermutation(long n) {
return JniUtils.randperm(this, n, DataType.INT64, device);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public static TrainingResult runExample(String[] args)
ZooModel<NDList, NDList> embedding = criteria.loadModel();

Block baseBlock = embedding.getBlock();
baseBlock.freezeParameters(false);
Block blocks =
new SequentialBlock()
.add(baseBlock)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,10 @@ public void testRandomInteger() {
}

@Test
public void testRandPerm() {
public void testRandomPermutation() {
try (NDManager manager = NDManager.newBaseManager()) {
long size = 3;
NDArray array = manager.randPerm(size);
NDArray array = manager.randomPermutation(size);
Assert.assertTrue(
array.contentEquals(manager.create(new long[] {0, 1, 2}))
|| array.contentEquals(manager.create(new long[] {0, 2, 1}))
Expand Down

0 comments on commit dc5c328

Please sign in to comment.