[Relay, ONNX] Support gather_nd batch_dims attribute for TF/ONNX #8084

Merged: 17 commits, May 21, 2021
7 changes: 7 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -144,6 +144,13 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
}
};

struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
Integer batch_dims;

TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") {
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
}
};

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer batch_dims;
Integer axis;
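Once the attribute is registered, batch_dims shows up on the call node like any other op attribute. A quick sketch from the Python side (assumes a TVM build with this patch applied; nothing here is part of the diff itself):

import tvm
from tvm import relay

# Build a gather_nd call and inspect the new attribute.
x = relay.var("x", shape=(2, 2, 2), dtype="float32")
y = relay.var("y", shape=(1, 2), dtype="int32")
call = relay.gather_nd(x, y, batch_dims=1)
print(call.attrs.batch_dims)  # -> 1 (defaults to 0 when omitted)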
10 changes: 7 additions & 3 deletions include/tvm/topi/transform.h
@@ -1238,13 +1238,14 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
*
* \param data The source array.
* \param indices The indices of the values to extract.
* \param batch_dims The number of batch dimensions.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the gather_nd operation
*/
-inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string name = "T_gather_nd",
-std::string tag = kInjective) {
+inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
+std::string name = "T_gather_nd", std::string tag = kInjective) {
size_t ndim_d = data->shape.size();
size_t ndim_i = indices->shape.size();
ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimension";
@@ -1255,7 +1256,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
for (size_t i = 1; i < ndim_i; ++i) {
out_shape.push_back(indices->shape[i]);
}
-for (size_t i = indices_dim0; i < ndim_d; ++i) {
+for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
out_shape.push_back(data->shape[i]);
}
return compute(
@@ -1267,6 +1268,9 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
indices_position.push_back(out_index[i]);
}
Array<PrimExpr> real_indices;
for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
real_indices.push_back(out_index[i]);
}
for (size_t i = 0; i < indices_dim0; ++i) {
indices_position.Set(0, make_const(DataType::Int(32), i));
if (indices->dtype.is_int()) {
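To pin down the indexing rule this compute implements, here is a hedged NumPy reference (the helper gather_nd_ref is ours for illustration, not part of the patch): indices stores the index tuples along its first axis, and the first batch_dims axes of the output index the shared batch dimensions directly.

import numpy as np

def gather_nd_ref(data, indices, batch_dims=0):
    # indices has shape (M, Y_0, ..., Y_{K-1}); the first batch_dims of the
    # Y axes must match data's leading axes.
    m = indices.shape[0]
    y_shape = indices.shape[1:]
    out = np.empty(y_shape + data.shape[m + batch_dims:], dtype=data.dtype)
    for pos in np.ndindex(*y_shape):
        # Batch axes come straight from the output position; the remaining
        # axes come from the index tuple stored along axis 0 of indices.
        idx = pos[:batch_dims] + tuple(indices[(i,) + pos] for i in range(m))
        out[pos] = data[idx]
    return out

# The batch_dims=1 example from the relay docstring further down:
data = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
indices = np.array([[1, 0]])
print(gather_nd_ref(data, indices, batch_dims=1))  # [[2 3] [4 5]]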
15 changes: 12 additions & 3 deletions python/tvm/relay/frontend/onnx.py
@@ -1413,11 +1413,20 @@ def _impl_v1(cls, inputs, attr, params):
class GatherND(OnnxOpConverter):
"""Operator converter for GatherND."""

@classmethod
def _impl_common(cls, data, indices, batch_dims=0):
indices_dims = len(infer_shape(indices))
indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
return _op.gather_nd(data, indices, batch_dims)

@classmethod
def _impl_v1(cls, inputs, attr, params):
-indices_dims = len(infer_shape(inputs[1]))
-indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1)))
-return _op.gather_nd(inputs[0], indices)
+return cls._impl_common(inputs[0], inputs[1])

@classmethod
def _impl_v12(cls, inputs, attr, params):
batch_dims = attr.get("batch_dims", 0)
return cls._impl_common(inputs[0], inputs[1], batch_dims)


class Scatter(OnnxOpConverter):
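For context, here is a minimal end-to-end sketch of the v12 path (the graph construction mirrors the tests further down; names and shapes are illustrative, not from the diff):

from onnx import helper, TensorProto
from tvm import relay

# One-node graph: GatherND with the opset-12 batch_dims attribute.
node = helper.make_node("GatherND", ["data", "indices"], ["out"], batch_dims=1)
graph = helper.make_graph(
    [node],
    "gather_nd_batch_dims",
    inputs=[
        helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 2, 2]),
        helper.make_tensor_value_info("indices", TensorProto.INT64, [2, 1]),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)])

# from_onnx dispatches to _impl_v12 above, which forwards batch_dims.
mod, params = relay.frontend.from_onnx(model, shape={"data": (2, 2, 2), "indices": (2, 1)})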
11 changes: 9 additions & 2 deletions python/tvm/relay/op/transform.py
@@ -1072,7 +1072,7 @@ def gather(data, axis, indices):
return _make.gather(data, axis, indices)


-def gather_nd(data, indices):
+def gather_nd(data, indices, batch_dims=0):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.

@@ -1084,6 +1084,9 @@ def gather_nd(data, indices):
indices : relay.Expr
The indices of the values to gather. The leading axis holds the components of each index tuple; the remaining axes determine the output shape.

batch_dims : int, optional
The number of batch dimensions (defaults to 0).

Returns
-------
ret : relay.Expr
@@ -1100,8 +1103,12 @@ def gather_nd(data, indices):
data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
indices = [[0, 1], [1, 0]]
relay.gather_nd(data, indices) = [[3, 4], [5, 6]]

data = [[[0,1],[2,3]],[[4,5],[6,7]]]
indices = [[1, 0]]
relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
"""
-return _make.gather_nd(data, indices)
+return _make.gather_nd(data, indices, batch_dims)


def sequence_mask(data, valid_length, mask_value=0, axis=0):
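The batch_dims example in the docstring can be checked directly; a quick sketch (assumes this patch is applied; the executor choice is arbitrary):

import numpy as np
from tvm import relay

x = relay.var("x", shape=(2, 2, 2), dtype="float32")
y = relay.var("y", shape=(1, 2), dtype="int32")
func = relay.Function([x, y], relay.gather_nd(x, y, batch_dims=1))

x_data = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype="float32")
y_data = np.array([[1, 0]], dtype="int32")
res = relay.create_executor("graph").evaluate(func)(x_data, y_data)
print(res)  # [[2. 3.] [4. 5.]]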
38 changes: 30 additions & 8 deletions src/relay/op/tensor/transform.cc
@@ -3350,21 +3350,34 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const size_t kdim = indices->shape.size() - 1;
ICHECK(size_t(mdim->value) <= ndim) << "GatherND: index tuple length must not exceed data rank.";

const auto param = attrs.as<GatherNDAttrs>();
ICHECK(param != nullptr);

for (int i = 0; i < param->batch_dims->value; ++i) {
ICHECK(reporter->AssertEQ(
data->shape[i], indices->shape[i + 1])); // +1 since the first axis is the index tuple
}

Array<IndexExpr> oshape;
for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]);
-for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]);
+for (size_t i = mdim->value + param->batch_dims->value; i < ndim; ++i)
+oshape.push_back(data->shape[i]);
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}

Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
-return {topi::gather_nd(inputs[0], inputs[1])};
+const auto* param = attrs.as<GatherNDAttrs>();
+ICHECK(param);
+return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
}

-Expr MakeGatherND(Expr data, Expr indices) {
+Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0) {
static const Op& op = Op::Get("gather_nd");
-return Call(op, {data, indices}, {});
+auto attrs = make_object<GatherNDAttrs>();
+attrs->batch_dims = batch_dims;
+return Call(op, {data, indices}, Attrs(attrs));
}

TVM_REGISTER_GLOBAL("relay.op._make.gather_nd").set_body_typed(MakeGatherND);
@@ -3373,10 +3386,19 @@ RELAY_REGISTER_OP("gather_nd")
.describe(R"code(Gather elements or slices from data and store to
a tensor whose shape is defined by indices.

-Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with
-shape (M, Y_0, ..., Y_{K-1}), the output will have shape
-(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N,
-output shape will simply be (Y_0, ..., Y_{K-1}).
+Optionally, batch_dims, the number of batch dimensions, can be given; its
+default value is 0.
+
+Let B denote batch_dims, and let data and indices have shapes
+(X_0, X_1, ..., X_{N-1}) and (M, Y_0, ..., Y_{K-1}) respectively.
+
+When B > 0, indexing starts from the B-th axis, and it must hold that
+X_0, ..., X_{B-1} == Y_0, ..., Y_{B-1}. The output then has shape
+(X_0, ..., X_{B-1}, Y_B, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N.
+
+When B == 0 (the default), the output shape is (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}).
+
+In both cases, if M + B == N, the output shape is simply (Y_0, ..., Y_{K-1}).
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
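The relation above can be sanity-checked from Python with type inference; a sketch using one of the batch_dims=2 shapes exercised in the tests below:

import tvm
from tvm import relay

# data: (X_0, ..., X_4) = (3, 2, 2, 3, 4); indices: (M, Y_0, Y_1, Y_2) = (2, 3, 2, 2).
# With B = 2, Y_0..Y_1 must equal X_0..X_1, and the output shape is
# (Y_0, Y_1, Y_2) + (X_{M+B}, ..., X_4) = (3, 2, 2, 4).
data = relay.var("data", shape=(3, 2, 2, 3, 4), dtype="float32")
indices = relay.var("indices", shape=(2, 3, 2, 2), dtype="int32")
out = relay.gather_nd(data, indices, batch_dims=2)
mod = tvm.IRModule.from_expr(relay.Function([data, indices], out))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # Tensor[(3, 2, 2, 4), float32]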
25 changes: 23 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
import numpy as np
import pytest
import scipy
@@ -218,6 +219,12 @@ def make_constant_node(name, data_type, dims, vals):
)


def is_version_greater_than(ver):
# Lexicographic comparison of "major.minor.patch" strings; adequate for the
# single-digit versions compared here (e.g. "1.6.0"), but not for "1.10.0".
return "".join(re.findall(r"(\d+\.)(\d+\.)(\d)", onnx.__version__)[0]) > "".join(
re.findall(r"(\d+\.)(\d+\.)(\d)", ver)[0]
)


@tvm.testing.uses_gpu
def test_reshape():
in_shape = (4, 3, 3, 4)
@@ -1002,12 +1009,16 @@ def test_isnan():
_test_finite_ops((2, 4, 5, 6), np.isnan, {}, "float32", "IsNaN", {})


def verify_gather_nd(in_shape, indices, out_shape, dtype="float32"):
def verify_gather_nd(in_shape, indices, out_shape, dtype="float32", batch_dims=0, opset=11):
x = np.random.uniform(size=in_shape).astype(dtype)
indices = np.array(indices, dtype="int64")

y = helper.make_node("GatherND", ["in", "indices"], ["out"])

if opset >= 12:
batch_dims_attr = helper.make_attribute("batch_dims", batch_dims)
y.attribute.append(batch_dims_attr)

graph = helper.make_graph(
[y],
"gather_test",
@@ -1024,7 +1035,7 @@ def verify_gather_nd(in_shape, indices, out_shape, dtype="float32"):
],
)
model = helper.make_model(graph, producer_name="gather_test")
-verify_with_ort_with_inputs(model, [x, indices], [out_shape])
+verify_with_ort_with_inputs(model, [x, indices], [out_shape], opset=opset)


@tvm.testing.uses_gpu
@@ -1034,6 +1045,16 @@ def test_gather_nd():
verify_gather_nd([2, 2, 2], [[0, 1], [1, 0]], [2, 2])
verify_gather_nd([2, 2, 2], [[[0, 1]], [[1, 0]]], [2, 1, 2])

if is_version_greater_than("1.6.0"):
verify_gather_nd([2, 2, 2], [[1], [0]], [2, 2], batch_dims=1, opset=12)
verify_gather_nd(
(3, 2, 2, 3, 4),
np.random.randint(low=0, high=2, size=(3, 2, 3), dtype="int64"),
(3, 2),
batch_dims=2,
opset=12,
)


@tvm.testing.uses_gpu
def test_onehot():
55 changes: 52 additions & 3 deletions tests/python/relay/test_op_level3.py
@@ -1252,14 +1252,40 @@ def verify_gather(data, axis, indices, ref_res):

@tvm.testing.uses_gpu
def test_gather_nd():
-def verify_gather_nd(xshape, yshape, y_data):
+def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
x = relay.var("x", relay.TensorType(xshape, "float32"))
y = relay.var("y", relay.TensorType(yshape, "int32"))
-z = relay.gather_nd(x, y)
+z = relay.gather_nd(x, y, batch_dims)

func = relay.Function([x, y], z)

x_data = np.random.uniform(size=xshape).astype("float32")
-ref_res = x_data[tuple(y_data)]

if y_data:
y_data = np.array(y_data, dtype="int32")
else:
y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32")

def gather_nd_batch_dims_1_ref(data, indices):
res = []
for i, row in enumerate(data):
indices_tuple = tuple(indices[:, i]) # the indices for the i-th batch
res.append(row[indices_tuple])
# stack on the batch dim
return np.stack(res, 0)

if batch_dims > 1:
x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:])
y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :])

ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape)

out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:]
ref_res = np.reshape(ref_res, out_shape)
elif batch_dims == 1:
ref_res = gather_nd_batch_dims_1_ref(x_data, y_data)
else:
ref_res = x_data[tuple(y_data)]

for target, dev in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
@@ -1272,6 +1298,29 @@ def verify_gather_nd(xshape, yshape, y_data):
verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])

# Examples from tensorflow gather_nd doc
# https://www.tensorflow.org/api_docs/python/tf/gather_nd
verify_gather_nd((2, 2, 2), (1, 2), [[1, 0]], 1)
verify_gather_nd((2, 2, 2), (1, 2, 1), [[[1], [0]]], 1)
verify_gather_nd((2, 2, 2), (2, 2, 1), [[[1], [0]], [[0], [1]]], 1)

# Test cases from tensorflow gather_nd tests kernel_tests/array_ops_test.py
verify_gather_nd((2, 2, 2), (1, 2), None, 1)
verify_gather_nd((2, 2, 2), (2, 2), None, 1)
verify_gather_nd((2, 2, 3, 2), (3, 2), None, 1)
verify_gather_nd((2, 2, 3, 2), (2, 2), None, 1)
verify_gather_nd((2, 2, 3, 2), (1, 2), None, 1)
verify_gather_nd((2, 2, 3, 2), (3, 2, 1), None, 1)
verify_gather_nd((2, 2, 3, 2), (2, 2, 2), None, 1)
verify_gather_nd((2, 2, 3, 2), (1, 2, 3), None, 1)

verify_gather_nd((3, 2, 2, 3, 4), (3, 3, 2), None, 2)
verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2), None, 2)
verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2), None, 2)
verify_gather_nd((3, 2, 2, 3, 4), (3, 3, 2, 1), None, 2)
verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2, 2), None, 2)
verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2, 3), None, 2)
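The reference above reduces batch_dims > 1 to the batch_dims == 1 case by flattening the shared batch axes, which is the identity the reshape logic in verify_gather_nd relies on. Stated directly in NumPy (a standalone sketch, not part of the test file):

import numpy as np

# gather_nd with batch_dims=2 over data of shape (B0, B1, ...) equals
# gather_nd with batch_dims=1 over data reshaped to (B0*B1, ...), with the
# two batch axes of the indices flattened the same way.
x = np.random.uniform(size=(3, 2, 2, 3, 4)).astype("float32")
y = np.random.randint(0, 2, size=(2, 3, 2)).astype("int32")  # (M, Y_0, Y_1)

x_flat = x.reshape((-1,) + x.shape[2:])  # (6, 2, 3, 4)
y_flat = y.reshape((y.shape[0], -1))     # (2, 6)
rows = [x_flat[i][tuple(y_flat[:, i])] for i in range(x_flat.shape[0])]
res = np.stack(rows, 0).reshape((3, 2) + rows[0].shape)  # (3, 2, 4)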


def _verify_infiniteness_ops(relay_op, ref_op):
for dtype in ["float32", "float16", "float16", "int32", "int16"]:
Expand Down