[Relay][OP] Gather_nd exposed to relay #2945

Merged
merged 7 commits into from Apr 2, 2019
Changes from 5 commits
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
@@ -92,6 +92,7 @@ This level enables additional math and transform operators.
   tvm.relay.zeros_like
   tvm.relay.ones
   tvm.relay.ones_like
   tvm.relay.gather_nd
   tvm.relay.full
   tvm.relay.full_like
   tvm.relay.cast
@@ -225,6 +226,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.zeros_like
.. autofunction:: tvm.relay.ones
.. autofunction:: tvm.relay.ones_like
.. autofunction:: tvm.relay.gather_nd
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast
7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
@@ -631,6 +631,11 @@ def _mx_deformable_convolution(inputs, attrs):
    return res


def _mx_gather_nd(inputs):
    # inputs: [data, indices], both already converted to Relay expressions
    assert len(inputs) == 2
    return _op.gather_nd(inputs[0], inputs[1])
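For context (not part of the diff): the frontend calls each converter with the operator's already converted Relay inputs. A minimal illustrative sketch, with hypothetical free variables standing in for real MXNet tensors:

from tvm import relay

# Hypothetical inputs; in practice the MXNet frontend supplies these.
data = relay.var("data", shape=(2, 2), dtype="float32")
indices = relay.var("indices", shape=(2, 3), dtype="int32")
out = _mx_gather_nd([data, indices])  # equivalent to relay.gather_nd(data, indices)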


# Note: due to attribute conversion constraints,
# ops in the identity set must be attribute-free
_identity_list = [
@@ -768,6 +773,7 @@ def _mx_deformable_convolution(inputs, attrs):
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"smooth_l1" : _mx_smooth_l1,
"gather_nd" : _mx_gather_nd,
Member commented:

You can put gather_nd in the _identity_list.
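
For illustration, the suggestion would look roughly like this (a sketch, not part of this diff), since gather_nd takes no attributes and its Relay signature matches MXNet's:

_identity_list = [
    # ... existing attribute-free ops ...
    "gather_nd",
]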

    # vision
    "_contrib_BilinearResize2D" : _mx_upsampling,
    "_contrib_MultiBoxPrior" : _mx_multibox_prior,
@@ -782,7 +788,6 @@ def _mx_deformable_convolution(inputs, attrs):
    # TODO(tvm-tvm): support all operators.
    #
    # "broadcast_to",
    # "gather_nd",
    # "Crop" : _crop_like,
}

1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
@@ -32,6 +32,7 @@
_reg.register_schedule("stack", schedule_injective)
_reg.register_schedule("concatenate", schedule_injective)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective)

# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
33 changes: 33 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -651,3 +651,36 @@ def reverse_reshape(data, newshape):
    if isinstance(newshape, int):
        newshape = [newshape]
    return _make._contrib_reverse_reshape(data, list(newshape))


def gather_nd(data, indices):
    """Gather elements or slices from data and store them in a tensor whose
    shape is defined by indices.

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.

    indices : relay.Expr
        The indices of the values to gather. The first dimension of indices
        selects coordinates into the leading axes of data.

    Returns
    -------
    ret : relay.Expr
        The computed result.

    Examples
    --------
    .. code-block:: python

        data = [[0, 1], [2, 3]]
        indices = [[1, 1, 0], [0, 1, 0]]
        relay.gather_nd(data, indices) = [2, 3, 0]

        data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
        indices = [[0, 1], [1, 0]]
        relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
    """
    return _make.gather_nd(data, indices)
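
As a sanity check on the semantics (not part of the PR), a minimal NumPy sketch; gather_nd_ref is a hypothetical name, and it reproduces the docstring examples:

import numpy as np

def gather_nd_ref(data, indices):
    # Each column along the first axis of indices is one coordinate tuple
    # into the leading axes of data; tuple() triggers NumPy advanced indexing.
    data = np.asarray(data)
    indices = np.asarray(indices)
    return data[tuple(indices)]

assert (gather_nd_ref([[0, 1], [2, 3]], [[1, 1, 0], [0, 1, 0]]) == [2, 3, 0]).all()
assert (gather_nd_ref([[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                      [[0, 1], [1, 0]]) == [[3, 4], [5, 6]]).all()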
70 changes: 70 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -2122,5 +2122,75 @@ example below::
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// gather_nd operator
bool GatherNDRel(const Array<Type>& types,
                 int num_inputs,
                 const Attrs& attrs,
                 const TypeReporter& reporter) {
  // `types` contains: [data, indices, result]
  CHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* indices = types[1].as<TensorTypeNode>();
  if (data == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "GatherND: expect input data type to be TensorType but get "
        << types[0];
    return false;
  }
  if (indices == nullptr) {
    CHECK(types[1].as<IncompleteTypeNode>())
        << "GatherND: expect indices type to be TensorType but get "
        << types[1];
    return false;
  }
  const size_t ndim = data->shape.size();
  // M is the first dimension of indices, not of data.
  const IntImm* mdim = indices->shape[0].as<IntImm>();
  const size_t kdim = indices->shape.size() - 1;
  CHECK(size_t(mdim->value) <= ndim)
      << "GatherND: the first dimension of indices must be no greater than the rank of data.";

  Array<IndexExpr> oshape;
  for (size_t i = 1; i < kdim + 1; ++i)
    oshape.push_back(indices->shape[i]);
  for (size_t i = mdim->value; i < ndim; ++i)
    oshape.push_back(data->shape[i]);
  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
  return true;
}

Array<Tensor> GatherNDCompute(const Attrs& attrs,
                              const Array<Tensor>& inputs,
                              const Type& out_type,
                              const Target& target) {
  return { topi::gather_nd(inputs[0], inputs[1]) };
}
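
The compute simply forwards to TOPI. For context, a hedged standalone sketch of driving the same kernel from Python, assuming topi.gather_nd is exposed to Python (an assumption; only the C++ topi::gather_nd is visible in this diff) and using the TVM-0.5-era placeholder API:

import tvm
import topi

# Symbolic inputs; shapes mirror the first test case below.
data = tvm.placeholder((2, 2), name="data", dtype="float32")
indices = tvm.placeholder((2, 3), name="indices", dtype="int32")
out = topi.gather_nd(data, indices)  # inferred output shape: (3,)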

Expr MakeGatherND(Expr data,
                  Expr indices) {
  static const Op& op = Op::Get("gather_nd");
  return CallNode::make(op, {data, indices}, {});
}

TVM_REGISTER_API("relay.op._make.gather_nd")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 2>(MakeGatherND, args, rv);
});

RELAY_REGISTER_OP("gather_nd")
.describe(R"code(Gather elements or slices from data and store to
a tensor whose shape is defined by indices.

Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with
shape (M, Y_0, ..., Y_{K-1}), the output will have shape
(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N,
output shape will simply be (Y_0, ..., Y_{K-1}).
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("GatherND", GatherNDRel)
.set_attr<FTVMCompute>("FTVMCompute", GatherNDCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace relay
} // namespace tvm
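
To make the registered shape rule concrete, a small Python sketch (illustrative only; gather_nd_out_shape is a hypothetical helper) of the computation GatherNDRel performs:

def gather_nd_out_shape(data_shape, indices_shape):
    # data: (X_0, ..., X_{N-1}); indices: (M, Y_0, ..., Y_{K-1}) with M <= N.
    m = indices_shape[0]
    assert m <= len(data_shape), "first dim of indices must not exceed data rank"
    return list(indices_shape[1:]) + list(data_shape[m:])

assert gather_nd_out_shape((2, 2), (2, 3)) == [3]        # gathers elements
assert gather_nd_out_shape((2, 2, 2), (2, 2)) == [2, 2]  # gathers slices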
16 changes: 16 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -491,6 +491,21 @@ def verify(shape, indices_src, axis, mode="clip"):
    verify((3,4), [-1, 5], 1)
    verify((3,4), [-1, 5], 1, mode="wrap")

def test_forward_gather_nd():
    def verify(xshape, yshape, y_data):
        x_data = np.random.uniform(size=xshape).astype("float32")
        # int32 to match the dtype declared for "y_data" below
        y_data = np.array(y_data, dtype="int32")
        ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
        mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
        new_sym, _ = relay.frontend.from_mxnet(mx_sym,
                                               {"x_data": xshape, "y_data": yshape},
                                               {"x_data": "float32", "y_data": "int32"})
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(new_sym)(x_data, y_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
    verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
    verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]])


if __name__ == '__main__':
    test_forward_mlp()
    test_forward_vgg()
@@ -527,3 +542,4 @@ def verify(shape, indices_src, axis, mode="clip"):
    test_forward_embedding()
    test_forward_smooth_l1()
    test_forward_take()
    test_forward_gather_nd()
21 changes: 20 additions & 1 deletion tests/python/relay/test_op_level3.py
@@ -553,7 +553,6 @@ def verify_stack(dshapes, axis):
    verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)



def test_reverse():
    def verify_reverse(dshape, axis):
        x = relay.var("x", relay.TensorType(dshape, "float32"))
@@ -573,6 +572,25 @@ def verify_reverse(dshape, axis):
    verify_reverse((2, 3, 4), -1)


def test_gather_nd():
    def verify_gather_nd(xshape, yshape, y_data):
        x = relay.var("x", relay.TensorType(xshape, "float32"))
        y = relay.var("y", relay.TensorType(yshape, "int32"))
        z = relay.gather_nd(x, y)

        func = relay.Function([x, y], z)
        x_data = np.random.uniform(size=xshape).astype("float32")
        y_data = np.array(y_data, dtype="int32")
        # tuple() makes NumPy treat the first axis of y_data as separate
        # coordinate arrays, matching gather_nd semantics.
        ref_res = x_data[tuple(y_data)]

        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(x_data, y_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
    verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
    verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]])


if __name__ == "__main__":
    test_cast()
    test_zeros_ones()
@@ -601,3 +619,4 @@ def verify_reverse(dshape, axis):
    test_stack()
    test_tile()
    test_repeat()
    test_gather_nd()