There seems to be a backward-incompatible change in the behavior of the NDArray.get(NDArray) method. In DJL v0.17.0 the following code works as expected, but in DJL v0.18.0 it throws an IllegalArgumentException at the xTile.get(...) line.
try (NDManager manager = NDManager.newBaseManager()) {
    int nTrain = 5;
    NDArray xTrain = manager.randomUniform(0, 1, new Shape(nTrain)).mul(5).sort();
    NDArray xTile = xTrain.tile(new long[] {nTrain, 1});
    // Following line works as expected in v0.17 and returns the diagonal elements
    // Throws IllegalArgumentException in v0.18
    NDArray keys = xTile.get((manager.eye(nTrain)).reshape(new Shape(nTrain, -1)));
}
Expected Behavior
The above code should work and return an NDArray of shape (1, 5) containing the diagonal elements of xTile.
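Since every row of xTile is a copy of xTrain (tiling the 1-D array with repeats {nTrain, 1} stacks it as nTrain identical rows), the expected diagonal is simply xTrain itself. The following plain-Java sketch models this on ordinary arrays. It does not use DJL, and the class, helper names and sample values are hypothetical.

```java
import java.util.Arrays;

// Plain-Java model (no DJL): xTile is xTrain repeated as nTrain rows.
final class DiagonalSketch {

    // Hypothetical helper: repeat a 1-D array as the rows of an n x row.length matrix.
    static double[][] tile(double[] row, int n) {
        double[][] out = new double[n][];
        for (int i = 0; i < n; i++) {
            out[i] = row.clone();
        }
        return out;
    }

    // Hypothetical helper: collect m[i][i] for a square matrix.
    static double[] diagonal(double[][] m) {
        double[] out = new double[m.length];
        for (int i = 0; i < m.length; i++) {
            out[i] = m[i][i];
        }
        return out;
    }

    public static void main(String[] args) {
        // Made-up sorted sample values standing in for xTrain.
        double[] xTrain = {0.4, 1.7, 2.2, 3.9, 4.6};
        double[][] xTile = tile(xTrain, xTrain.length);
        // Every row equals xTrain, so element i of row i is xTrain[i].
        System.out.println(Arrays.toString(diagonal(xTile))); // prints [0.4, 1.7, 2.2, 3.9, 4.6]
    }
}
```

This also gives a simple sanity check for any fix: the masked result in DJL should hold the same values as xTrain.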
Here's the code I'm running. To reproduce the issue, run it against either master or the v0.18 tag. If you run it against the v0.17 tag, it works as expected.
I added this code to a file in the examples/inference module and ran it via ./gradlew run -Dmain=ai.djl.examples.inference.NDArrayIndexingBug. The indexing logic in NDIndex that rejects this call was roughly the same in v0.17, yet the code worked fine there.
I'm not sure what changed, but maybe we should check whether some default cases, such as eye, create NDArrays with a different data type (like int).
Environment Info
Please run the command ./gradlew debugEnv from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below:
@siddvenk Thanks for spotting this issue! I found the root cause. It worked in version 0.17.0 because NDArray keys = xTile.get((manager.eye(nTrain)).reshape(new Shape(nTrain, -1))); internally called take (see PR), which is also supported in MXNet (see PR). In later versions, indexing switched back to using NDIndex, so to use the take feature, take now has to be called explicitly.
I will also add type conversion for indexing with NDIndex in the current version.
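The type conversion mentioned above amounts to treating a numeric index array as a boolean mask (nonzero means keep) before masking. The following sketch models that in plain Java, assuming, as the report suspects, that the identity matrix from eye is floating-point rather than boolean. It does not call DJL, and the class and helper names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java model (no DJL) of "convert the index array to boolean, then mask".
final class MaskConversionSketch {

    // Hypothetical helper: nonzero -> true, mirroring a numeric-to-boolean cast.
    static boolean[][] toBooleanMask(double[][] numeric) {
        boolean[][] mask = new boolean[numeric.length][];
        for (int i = 0; i < numeric.length; i++) {
            mask[i] = new boolean[numeric[i].length];
            for (int j = 0; j < numeric[i].length; j++) {
                mask[i][j] = numeric[i][j] != 0.0;
            }
        }
        return mask;
    }

    // Hypothetical helper: keep data[i][j] wherever mask[i][j] is true, in row-major order.
    static double[] select(double[][] data, boolean[][] mask) {
        if (data.length != mask.length) {
            throw new IllegalArgumentException("mask shape must match data shape");
        }
        List<Double> kept = new ArrayList<>();
        for (int i = 0; i < data.length; i++) {
            if (data[i].length != mask[i].length) {
                throw new IllegalArgumentException("mask shape must match data shape");
            }
            for (int j = 0; j < data[i].length; j++) {
                if (mask[i][j]) {
                    kept.add(data[i][j]);
                }
            }
        }
        double[] out = new double[kept.size()];
        for (int k = 0; k < out.length; k++) {
            out[k] = kept.get(k);
        }
        return out;
    }

    // Floating-point identity matrix (1.0 on the diagonal, 0.0 elsewhere).
    static double[][] eye(int n) {
        double[][] m = new double[n][n];
        for (int i = 0; i < n; i++) {
            m[i][i] = 1.0;
        }
        return m;
    }

    public static void main(String[] args) {
        double[][] xTile = {
            {0.4, 1.7, 2.2},
            {0.4, 1.7, 2.2},
            {0.4, 1.7, 2.2}
        };
        double[] keys = select(xTile, toBooleanMask(eye(3)));
        System.out.println(Arrays.toString(keys)); // prints [0.4, 1.7, 2.2]
    }
}
```

In DJL itself, a workaround along the same lines would presumably be to cast the mask before indexing, e.g. xTile.get(manager.eye(nTrain).toType(DataType.BOOLEAN, false)); this has not been verified against v0.18.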
Error Message
How to Reproduce?
Steps to reproduce
Run ./gradlew run -Dmain=ai.djl.examples.inference.NDArrayIndexingBug from the examples directory.
Working on v0.17:
Same as above, but check out tags/v0.17.0 first.
What have you tried to solve it?
It seems the logic in djl/api/src/main/java/ai/djl/ndarray/index/NDIndex.java (lines 356 to 362 at commit e547f71) explains why this throws an error.