Skip to content

Commit

Permalink
Keep test EngineAgnostic
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed May 7, 2022
1 parent 239c908 commit a720822
Showing 1 changed file with 7 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.integration.tests.ndarray;

import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
Expand Down Expand Up @@ -56,13 +55,12 @@ public void testPick() {

@Test
public void testGather() {
// Currently in windows gradle cannot find all the engines.
// Currently in windows gradle cannot find all the engines to fill in -classpath, except for
// MXNet.
// In the dependencies, changing runtimeOnly to api however will remedy the problem.
// TODO: remove this when gradle problem is fixed.
TestRequirements.notWindows();
Engine engine = Engine.getEngine("PyTorch");
try (NDManager manager = engine.newBaseManager()) {
// try (NDManager manager = NDManager.newBaseManager()) {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray arr = manager.arange(20f).reshape(-1, 4);
NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2));
NDArray actual = arr.gather(index, 1);
Expand All @@ -73,9 +71,10 @@ public void testGather() {

@Test
public void testTake() {
Engine engine = Engine.getEngine("PyTorch");
try (NDManager manager = engine.newBaseManager()) {
NDArray arr = manager.arange(1,7f).reshape(-1, 3);
// TODO: remove this when gradle problem in windows shown above is fixed.
TestRequirements.notWindows();
try (NDManager manager = NDManager.newBaseManager()) {
NDArray arr = manager.arange(1, 7f).reshape(-1, 3);
NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2));
NDArray actual = arr.take(index);
NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2));
Expand Down

0 comments on commit a720822

Please sign in to comment.