Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Advanced indexing that supports all indexing features on PyTorch (getter) #1747

Closed
wants to merge 10 commits into from
1 change: 1 addition & 0 deletions api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies {
}
testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
testRuntimeOnly project(":engines:pytorch:pytorch-model-zoo")
testRuntimeOnly project(":engines:pytorch:pytorch-jni")
}

javadoc {
Expand Down
10 changes: 9 additions & 1 deletion api/src/main/java/ai/djl/ndarray/index/NDIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexFixed;
import ai.djl.ndarray.index.dim.NDIndexNone;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexSlice;
import ai.djl.ndarray.types.DataType;
Expand Down Expand Up @@ -50,7 +51,7 @@ public class NDIndex {
/* Android regex requires escape } char as well */
private static final Pattern ITEM_PATTERN =
Pattern.compile(
"(\\*)|((-?\\d+|\\{\\})?:(-?\\d+|\\{\\})?(:(-?\\d+|\\{\\}))?)|(-?\\d+|\\{\\})");
"(\\*)|((-?\\d+|\\{\\})?:(-?\\d+|\\{\\})?(:(-?\\d+|\\{\\}))?)|(-?\\d+|\\{\\})|None");

private int rank;
private List<NDIndexElement> indices;
Expand Down Expand Up @@ -105,6 +106,8 @@ public NDIndex() {
*
* // Uses ellipsis to select all the dimensions except for last axis where we only get a subsection.
* assertEquals(a.get(new NDIndex("..., 2")).getShape(), new Shape(5, 4));
*
* // TODO: Add doc for the new indexings
* </pre>
*
* @param indices a comma separated list of indices corresponding to either subsections,
Expand Down Expand Up @@ -335,6 +338,11 @@ private int addIndexItem(String indexItem, int argIndex, Object[] args) {
if (!m.matches()) {
throw new IllegalArgumentException("Invalid argument index: " + indexItem);
}
// "None" case
if ("None".equals(indexItem)) {
indices.add(new NDIndexNone());
return argIndex;
}
// "*" case
String star = m.group(1);
if (star != null) {
Expand Down
23 changes: 23 additions & 0 deletions api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray.index.dim;

/** An {@code NDIndexElement} to return all values in a particular dimension. */
public class NDIndexNone implements NDIndexElement {

/** {@inheritDoc} */
@Override
public int getRank() {
return 1;
}
}
16 changes: 8 additions & 8 deletions api/src/main/java/ai/djl/ndarray/index/dim/NDIndexPick.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
/** An {@link NDIndexElement} that gets elements by index in the specified axis. */
public class NDIndexPick implements NDIndexElement {

private NDArray indices;
private NDArray index;

/**
* Constructs a pick.
*
* @param indices the indices to pick
* @param index the index to pick
*/
public NDIndexPick(NDArray indices) {
this.indices = indices;
public NDIndexPick(NDArray index) {
this.index = index;
}

/** {@inheritDoc} */
Expand All @@ -35,11 +35,11 @@ public int getRank() {
}

/**
* Returns the indices to pick.
* Returns the index to pick.
*
* @return the indices to pick
* @return the index to pick
*/
public NDArray getIndices() {
return indices;
public NDArray getIndex() {
return index;
}
}
45 changes: 45 additions & 0 deletions api/src/main/java/ai/djl/ndarray/index/dim/NDIndexTake.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray.index.dim;

import ai.djl.ndarray.NDArray;

/** An {@link NDIndexElement} that gets elements by index in the specified axis. */
public class NDIndexTake implements NDIndexElement {

private NDArray index;

/**
* Constructs a pick.
*
* @param index the index to pick
*/
public NDIndexTake(NDArray index) {
this.index = index;
}

/** {@inheritDoc} */
@Override
public int getRank() {
return 1;
}

/**
* Returns the index to pick.
*
* @return the index to pick
*/
public NDArray getIndex() {
return index;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public static Optional<NDIndexFullPick> fromIndex(NDIndex index, Shape target) {
axis++;
} else if (el instanceof NDIndexPick) {
if (fullPick == null) {
fullPick = new NDIndexFullPick(((NDIndexPick) el).getIndices(), axis);
fullPick = new NDIndexFullPick(((NDIndexPick) el).getIndex(), axis);
} else {
// Don't support multiple picks
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
Expand Down Expand Up @@ -49,6 +50,18 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) {
}
}

/** {@inheritDoc} */
@Override
public NDArray get(NDArray array, NDIndex index) {
if (index.getRank() == 0) {
if (array.getShape().isScalar()) {
return array.duplicate();
}
index.addAllDim();
}
return JniUtils.indexAdv(manager.from(array), index);
}

/** {@inheritDoc} */
@Override
public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public ByteBuffer allocateDirect(int capacity) {
/** {@inheritDoc} */
@Override
public PtNDArray from(NDArray array) {
if (array == null || array instanceof PtNDArray) {
if (array == null || array instanceof PtNDArray && array.getManager() == this) {
return (PtNDArray) array;
}
return create(array.toByteBuffer(), array.getShape(), array.getDataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexFixed;
import ai.djl.ndarray.index.dim.NDIndexNone;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexSlice;
import ai.djl.ndarray.index.dim.NDIndexTake;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
Expand All @@ -35,6 +45,8 @@
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;

/**
Expand Down Expand Up @@ -337,6 +349,65 @@ public static PtNDArray index(
ndArray.getHandle(), minIndices, maxIndices, stepIndices));
}

public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index) {
List<NDIndexElement> indices = index.getIndices();
long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size());
ListIterator<NDIndexElement> it = indices.listIterator();
while (it.hasNext()) {
if (it.nextIndex() == index.getEllipsisIndex()) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
}

NDIndexElement elem = it.next();
if (elem instanceof NDIndexNone) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false);
} else if (elem instanceof NDIndexSlice) {
Long min = ((NDIndexSlice) elem).getMin();
Long max = ((NDIndexSlice) elem).getMax();
Long step = ((NDIndexSlice) elem).getStep();
int null_slice_bin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0);
// null_slice_bin encodes whether (min, max) is null:
// is_null == 1, ! is_null == 0;
// 0b11 == 3, 0b10 = 2, ...
PyTorchLibrary.LIB.torchIndexAppendSlice(
torchIndexHandle,
min == null ? 0 : min,
max == null ? 0 : max,
step == null ? 1 : step,
null_slice_bin);
} else if (elem instanceof NDIndexAll) {
PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, 0, 0, 1, 3);
} else if (elem instanceof NDIndexFixed) {
PyTorchLibrary.LIB.torchIndexAppendFixed(
torchIndexHandle, ((NDIndexFixed) elem).getIndex());
} else if (elem instanceof NDIndexBooleans) {
PtNDArray index_arr = (PtNDArray) ((NDIndexBooleans) elem).getIndex();
PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, index_arr.getHandle());
} else if (elem instanceof NDIndexTake) {
PtNDArray index_arr = (PtNDArray) ((NDIndexTake) elem).getIndex();
if (index_arr.getDataType() != DataType.INT64) {
index_arr = index_arr.toType(DataType.INT64, true);
}
PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, index_arr.getHandle());
} else if (elem instanceof NDIndexPick) {
//noinspection OptionalGetWithoutIsPresent
NDIndexFullPick fullPick =
NDIndexFullPick.fromIndex(index, ndArray.getShape()).get();
return pick(
ndArray,
ndArray.getManager().from(fullPick.getIndices()),
fullPick.getAxis());
}
}
if (indices.size() == index.getEllipsisIndex()) {
PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true);
}

return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchIndexReturn(ndArray.getHandle(), torchIndexHandle));
}

public static void indexSet(
PtNDArray ndArray,
PtNDArray value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,4 +600,17 @@ native void sgdUpdate(
native long torchNorm(long handle, int ord, long[] axis, boolean keepDims);

native long torchNonZeros(long handle);

native long torchIndexInit(int size);

native long torchIndexReturn(long handle, long torchIndexHandle);

native void torchIndexAppendNoneEllipsis(long torchIndexHandle, boolean is_ellipsis);

native void torchIndexAppendSlice(
long torchIndexHandle, long min, long max, long step, int null_slice_binary);

native void torchIndexAppendFixed(long torchIndexHandle, long idx);

native void torchIndexAppendArray(long torchIndexHandle, long arrayHandle);
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,69 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndex(JNIEnv
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexInit(JNIEnv* env, jobject jthis, jint jsize) {
API_BEGIN()
std::vector<at::indexing::TensorIndex> *index_ptr = new std::vector<at::indexing::TensorIndex>;
index_ptr->reserve(jsize);
return reinterpret_cast<uintptr_t>(index_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexReturn(JNIEnv* env, jobject jthis,
jlong jhandle, jlong jtorch_index_handle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
auto* index_ptr = reinterpret_cast<std::vector<at::indexing::TensorIndex> *>(jtorch_index_handle);
torch::Tensor* ret_ptr = new torch::Tensor(tensor_ptr->index(*index_ptr));
return reinterpret_cast<uintptr_t>(ret_ptr);
API_END_RETURN()
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendNoneEllipsis(JNIEnv* env, jobject jthis,
jlong jtorch_index_handle, jboolean jis_ellipsis) {
API_BEGIN()
auto* index_ptr = reinterpret_cast<std::vector<at::indexing::TensorIndex> *>(jtorch_index_handle);
if (jis_ellipsis) {
index_ptr->emplace_back(torch::indexing::Ellipsis);
} else {
index_ptr->emplace_back(torch::indexing::None);
}
API_END()
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendSlice(JNIEnv* env, jobject jthis,
jlong jtorch_index_handle, jlong jmin, jlong jmax, jlong jstep, jint jnull_slice_binary) {
API_BEGIN()
auto* index_ptr = reinterpret_cast<std::vector<at::indexing::TensorIndex> *>(jtorch_index_handle);
if (jnull_slice_binary == 0) {
index_ptr->emplace_back(torch::indexing::Slice(jmin, jmax, jstep));
} else if (jnull_slice_binary == 1) {
index_ptr->emplace_back(torch::indexing::Slice(jmin, torch::indexing::None, jstep));
} else if (jnull_slice_binary == 2) {
index_ptr->emplace_back(torch::indexing::Slice(torch::indexing::None, jmax, jstep));
} else if (jnull_slice_binary == 3) {
index_ptr->emplace_back(torch::indexing::Slice(torch::indexing::None, torch::indexing::None, jstep));
}
API_END()
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendFixed(JNIEnv* env, jobject jthis,
jlong jtorch_index_handle, jlong jidx) {
API_BEGIN()
auto* index_ptr = reinterpret_cast<std::vector<at::indexing::TensorIndex> *>(jtorch_index_handle);
index_ptr->emplace_back((int) jidx);
API_END()
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendArray(JNIEnv* env, jobject jthis,
jlong jtorch_index_handle, jlong jarray) {
API_BEGIN()
auto* index_ptr = reinterpret_cast<std::vector<at::indexing::TensorIndex> *>(jtorch_index_handle);
auto* array_ptr = reinterpret_cast<torch::Tensor*>(jarray);
index_ptr->emplace_back(*array_ptr);
API_END()
}

JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexPut(JNIEnv* env, jobject jthis, jlong jhandle,
jlong jvalue_handle, jlongArray jmin_indices, jlongArray jmax_indices, jlongArray jstep_indices) {
API_BEGIN()
Expand Down