diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 723f9ecdab90..cc97a94a1406 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -144,6 +144,13 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
   }
 };
 
+struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
+  Integer batch_dims;
+
+  TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") {
+    TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
+  }
+};
 struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
   Integer batch_dims;
   Integer axis;
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index b2132b75fab9..781c1cbeb311 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1238,13 +1238,14 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
  *
  * \param data The source array.
  * \param indices The indices of the values to extract.
+ * \param batch_dims The number of batch dimensions.
 * \param name The name of the operation.
 * \param tag The tag to mark the operation.
 *
 * \return A Tensor whose op member is the gather_nd operation
 */
-inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string name = "T_gather_nd",
-                        std::string tag = kInjective) {
+inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
+                        std::string name = "T_gather_nd", std::string tag = kInjective) {
   size_t ndim_d = data->shape.size();
   size_t ndim_i = indices->shape.size();
   ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions";
@@ -1255,7 +1256,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
   for (size_t i = 1; i < ndim_i; ++i) {
     out_shape.push_back(indices->shape[i]);
   }
-  for (size_t i = indices_dim0; i < ndim_d; ++i) {
+  for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
     out_shape.push_back(data->shape[i]);
   }
   return compute(
@@ -1267,6 +1268,9 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
           indices_position.push_back(out_index[i]);
         }
         Array<PrimExpr> real_indices;
+        for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
+          real_indices.push_back(out_index[i]);
+        }
         for (size_t i = 0; i < indices_dim0; ++i) {
           indices_position.Set(0, make_const(DataType::Int(32), i));
           if (indices->dtype.is_int()) {
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index b9fabdebb330..e70167a6aa57 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1413,11 +1413,20 @@ def _impl_v1(cls, inputs, attr, params):
 class GatherND(OnnxOpConverter):
     """Operator converter for GatherND."""
 
+    @classmethod
+    def _impl_common(cls, data, indices, batch_dims=0):
+        indices_dims = len(infer_shape(indices))
+        indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
+        return _op.gather_nd(data, indices, batch_dims)
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        indices_dims = len(infer_shape(inputs[1]))
-        indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1)))
-        return _op.gather_nd(inputs[0], indices)
+        return cls._impl_common(inputs[0], inputs[1])
+
+    @classmethod
+    def _impl_v12(cls, inputs, attr, params):
+        batch_dims = attr.get("batch_dims", 0)
+        return cls._impl_common(inputs[0], inputs[1], batch_dims)
 
 
 class Scatter(OnnxOpConverter):
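Reviewer note, not part of the patch: ONNX GatherND carries the index tuple on the last axis of indices, while relay.gather_nd expects it on the first axis; that is what the transpose in _impl_common compensates for. A minimal numpy sketch of the equivalence (the array names here are illustrative only):

    import numpy as np

    data = np.arange(8).reshape(2, 2, 2)
    onnx_indices = np.array([[0, 0], [1, 0]])           # tuple-last layout, shape (2, 2)
    relay_indices = np.transpose(onnx_indices, (1, 0))  # tuple-first layout

    # ONNX semantics: gather data[0, 0] and data[1, 0]
    expected = np.stack([data[0, 0], data[1, 0]])
    # Relay-style indexing on the transposed indices yields the same rows
    assert (data[tuple(relay_indices)] == expected).all()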
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 8744e7b5c6ad..eeb8644d4328 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1072,7 +1072,7 @@ def gather(data, axis, indices):
     return _make.gather(data, axis, indices)
 
 
-def gather_nd(data, indices):
+def gather_nd(data, indices, batch_dims=0):
     """Gather elements or slices from data and store to a tensor whose
     shape is defined by indices.
 
@@ -1084,6 +1084,9 @@ def gather_nd(data, indices):
     indices : relay.Expr
         The shape of output tensor.
 
+    batch_dims : int
+        The number of batch dimensions.
+
     Returns
     -------
     ret : relay.Expr
@@ -1100,8 +1103,12 @@ def gather_nd(data, indices):
         data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
        indices = [[0, 1], [1, 0]]
         relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
+
+        data = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
+        indices = [[1, 0]]
+        relay.gather_nd(data, indices, batch_dims=1) = [[2, 3], [4, 5]]
     """
-    return _make.gather_nd(data, indices)
+    return _make.gather_nd(data, indices, batch_dims)
 
 
 def sequence_mask(data, valid_length, mask_value=0, axis=0):
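For reference, not part of the patch: the new docstring example can be exercised end to end. A sketch assuming the interpreter API already used in the test suite (relay.create_executor):

    import numpy as np
    from tvm import relay

    x = relay.var("x", relay.TensorType((2, 2, 2), "float32"))
    y = relay.var("y", relay.TensorType((1, 2), "int32"))
    func = relay.Function([x, y], relay.gather_nd(x, y, batch_dims=1))

    x_data = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype="float32")
    y_data = np.array([[1, 0]], dtype="int32")  # row 1 of batch 0, row 0 of batch 1

    res = relay.create_executor("debug").evaluate(func)(x_data, y_data)
    # res should equal [[2, 3], [4, 5]], matching the docstring example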
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index df60aeb16bf3..d6c19f9a7034 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3350,21 +3350,34 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   const size_t kdim = indices->shape.size() - 1;
   ICHECK(size_t(mdim->value) <= ndim) << "GatherND: indices shape does satisfy.";
 
+  const auto param = attrs.as<GatherNDAttrs>();
+  ICHECK(param != nullptr);
+
+  for (int i = 0; i < param->batch_dims->value; ++i) {
+    ICHECK(reporter->AssertEQ(
+        data->shape[i], indices->shape[i + 1]));  // +1 since the first axis is the index tuple
+  }
+
   Array<IndexExpr> oshape;
   for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]);
-  for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]);
+  for (size_t i = mdim->value + param->batch_dims->value; i < ndim; ++i)
+    oshape.push_back(data->shape[i]);
   reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
 
 Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                   const Type& out_type) {
-  return {topi::gather_nd(inputs[0], inputs[1])};
+  const auto* param = attrs.as<GatherNDAttrs>();
+  ICHECK(param);
+  return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
 }
 
-Expr MakeGatherND(Expr data, Expr indices) {
+Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0) {
   static const Op& op = Op::Get("gather_nd");
-  return Call(op, {data, indices}, {});
+  auto attrs = make_object<GatherNDAttrs>();
+  attrs->batch_dims = batch_dims;
+  return Call(op, {data, indices}, Attrs(attrs));
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.gather_nd").set_body_typed(MakeGatherND);
@@ -3373,10 +3386,19 @@ RELAY_REGISTER_OP("gather_nd")
     .describe(R"code(Gather elements or slices from data and store to a tensor whose
 shape is defined by indices.
 
-Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with
-shape (M, Y_0, ..., Y_{K-1}), the output will have shape
-(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N,
-output shape will simply be (Y_0, ..., Y_{K-1}).
+Optionally, batch_dims, the number of leading batch dimensions, can be given;
+its default value is 0.
+
+Let B denote batch_dims, and let data and indices have shapes
+(X_0, X_1, ..., X_{N-1}) and (M, Y_0, ..., Y_{K-1}) respectively.
+
+When B > 0, indexing starts from the B-th axis of data, and it must hold that
+X_0, ..., X_{B-1} == Y_0, ..., Y_{B-1}. The output will have shape
+(X_0, ..., X_{B-1}, Y_B, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N.
+
+When B == 0 (the default case), the output shape will be (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}).
+
+In both cases, if M + B == N, the output shape will simply be (Y_0, ..., Y_{K-1}).
 )code" TVM_ADD_FILELINE)
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index fdb8d205a244..aaf524cc9dcc 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import re
 import numpy as np
 import pytest
 import scipy
@@ -218,6 +219,12 @@ def make_constant_node(name, data_type, dims, vals):
     )
 
 
+def is_version_greater_than(ver):
+    return "".join(re.findall(r"(\d+\.)(\d+\.)(\d)", onnx.__version__)[0]) > "".join(
+        re.findall(r"(\d+\.)(\d+\.)(\d)", ver)[0]
+    )
+
+
 @tvm.testing.uses_gpu
 def test_reshape():
     in_shape = (4, 3, 3, 4)
@@ -1002,12 +1009,16 @@ def test_isnan():
     _test_finite_ops((2, 4, 5, 6), np.isnan, {}, "float32", "IsNaN", {})
 
 
-def verify_gather_nd(in_shape, indices, out_shape, dtype="float32"):
+def verify_gather_nd(in_shape, indices, out_shape, dtype="float32", batch_dims=0, opset=11):
     x = np.random.uniform(size=in_shape).astype(dtype)
     indices = np.array(indices, dtype="int64")
     y = helper.make_node("GatherND", ["in", "indices"], ["out"])
 
+    if opset >= 12:
+        batch_dims_attr = helper.make_attribute("batch_dims", batch_dims)
+        y.attribute.append(batch_dims_attr)
+
     graph = helper.make_graph(
         [y],
         "gather_test",
         inputs=[
@@ -1024,7 +1035,7 @@ def verify_gather_nd(in_shape, indices, out_shape, dtype="float32"):
         ],
     )
     model = helper.make_model(graph, producer_name="gather_test")
-    verify_with_ort_with_inputs(model, [x, indices], [out_shape])
+    verify_with_ort_with_inputs(model, [x, indices], [out_shape], opset=opset)
 
 
 @tvm.testing.uses_gpu
@@ -1034,6 +1045,16 @@ def test_gather_nd():
     verify_gather_nd([2, 2, 2], [[0, 1], [1, 0]], [2, 2])
     verify_gather_nd([2, 2, 2], [[[0, 1]], [[1, 0]]], [2, 1, 2])
 
+    if is_version_greater_than("1.6.0"):
+        verify_gather_nd([2, 2, 2], [[1], [0]], [2, 2], batch_dims=1, opset=12)
+        verify_gather_nd(
+            (3, 2, 2, 3, 4),
+            np.random.randint(low=0, high=2, size=(3, 2, 3), dtype="int64"),
+            (3, 2),
+            batch_dims=2,
+            opset=12,
+        )
+
 
 @tvm.testing.uses_gpu
 def test_onehot():
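Sanity check, not part of the patch: the (3, 2) out_shape in the new batch_dims=2 test follows from the ONNX GatherND shape rule, output.shape = indices.shape[:-1] + data.shape[batch_dims + indices.shape[-1]:]. A quick check in plain Python:

    data_shape = (3, 2, 2, 3, 4)
    indices_shape = (3, 2, 3)
    batch_dims = 2

    out_shape = indices_shape[:-1] + data_shape[batch_dims + indices_shape[-1]:]
    assert out_shape == (3, 2)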
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index e84b22b30ce1..b8bab295ba67 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1252,14 +1252,40 @@ def verify_gather(data, axis, indices, ref_res):
 
 @tvm.testing.uses_gpu
 def test_gather_nd():
-    def verify_gather_nd(xshape, yshape, y_data):
+    def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
         x = relay.var("x", relay.TensorType(xshape, "float32"))
         y = relay.var("y", relay.TensorType(yshape, "int32"))
-        z = relay.gather_nd(x, y)
+        z = relay.gather_nd(x, y, batch_dims)
 
         func = relay.Function([x, y], z)
+
         x_data = np.random.uniform(size=xshape).astype("float32")
-        ref_res = x_data[tuple(y_data)]
+
+        if y_data:
+            y_data = np.array(y_data, dtype="int32")
+        else:
+            y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32")
+
+        def gather_nd_batch_dims_1_ref(data, indices):
+            res = []
+            for i, row in enumerate(data):
+                indices_tuple = tuple(indices[:, i])  # the indices for the i-th batch
+                res.append(row[indices_tuple])
+            # stack on the batch dim
+            return np.stack(res, 0)
+
+        if batch_dims > 1:
+            x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:])
+            y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :])
+
+            ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape)
+
+            out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:]
+            ref_res = np.reshape(ref_res, out_shape)
+        elif batch_dims == 1:
+            ref_res = gather_nd_batch_dims_1_ref(x_data, y_data)
+        else:
+            ref_res = x_data[tuple(y_data)]
 
         for target, dev in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
@@ -1272,6 +1298,29 @@ def verify_gather_nd(xshape, yshape, y_data):
     verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
     verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
 
+    # Examples from tensorflow gather_nd doc
+    # https://www.tensorflow.org/api_docs/python/tf/gather_nd
+    verify_gather_nd((2, 2, 2), (1, 2), [[1, 0]], 1)
+    verify_gather_nd((2, 2, 2), (1, 2, 1), [[[1], [0]]], 1)
+    verify_gather_nd((2, 2, 2), (2, 2, 1), [[[1], [0]], [[0], [1]]], 1)
+
+    # Test cases from tensorflow gather_nd tests kernel_tests/array_ops_test.py
+    verify_gather_nd((2, 2, 2), (1, 2), None, 1)
+    verify_gather_nd((2, 2, 2), (2, 2), None, 1)
+    verify_gather_nd((2, 2, 3, 2), (3, 2), None, 1)
+    verify_gather_nd((2, 2, 3, 2), (2, 2), None, 1)
+    verify_gather_nd((2, 2, 3, 2), (1, 2), None, 1)
+    verify_gather_nd((2, 2, 3, 2), (3, 2, 1), None, 1)
+    verify_gather_nd((2, 2, 3, 2), (2, 2, 2), None, 1)
+    verify_gather_nd((2, 2, 3, 2), (1, 2, 3), None, 1)
+
+    verify_gather_nd((3, 2, 2, 3, 4), (3, 3, 2), None, 2)
+    verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2), None, 2)
+    verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2), None, 2)
+    verify_gather_nd((3, 2, 2, 3, 4), (3, 3, 2, 1), None, 2)
+    verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2, 2), None, 2)
+    verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2, 3), None, 2)
+
 
 def _verify_infiniteness_ops(relay_op, ref_op):
     for dtype in ["float32", "float16", "float16", "int32", "int16"]:
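One more reviewer-side sketch, not part of the patch, of the idea behind the batch_dims > 1 branch of the reference above: collapse all batch axes into a single axis, run the batch_dims=1 reference, then restore the batch shape. Shapes follow the Relay convention (index tuple on axis 0 of indices) and are illustrative only:

    import numpy as np

    xshape, yshape, batch_dims = (2, 3, 4, 5), (1, 2, 3), 2
    x = np.random.uniform(size=xshape).astype("float32")
    y = np.random.randint(low=0, high=4, size=yshape)

    # collapse the (2, 3) batch grid into one axis of size 6
    x_flat = x.reshape((-1,) + xshape[batch_dims:])                # (6, 4, 5)
    y_flat = y.reshape((yshape[0], -1) + yshape[batch_dims + 1:])  # (1, 6)

    # batch_dims=1 reference: index each batch row with its own tuple
    res = np.stack([x_flat[b][tuple(y_flat[:, b])] for b in range(x_flat.shape[0])], 0)
    assert res.shape == (6, 5)  # reshaped to (2, 3, 5) for the final comparison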