From 8fce89500c520c4dc6ce8733172fa87ead107709 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 27 Apr 2021 23:13:06 -0600 Subject: [PATCH] [TOPI][RELAY][ONNX] Scatter ND (#7927) * passing topi tests * passing relay tests, needs better shape checking still * support ONNX operator * add shape checking back in * fix lint * update docstring --- include/tvm/relay/attrs/transform.h | 5 +- python/tvm/relay/frontend/onnx.py | 13 +++ python/tvm/relay/frontend/pytorch.py | 21 +---- python/tvm/relay/op/_tensor_grad.py | 2 +- python/tvm/relay/op/_transform.py | 2 +- python/tvm/relay/op/strategy/generic.py | 2 +- python/tvm/relay/op/transform.py | 13 +-- python/tvm/topi/cuda/scatter.py | 66 ++++++++------ python/tvm/topi/scatter.py | 60 +++++++------ python/tvm/topi/x86/scatter.py | 60 +++++++------ src/relay/op/tensor/transform.cc | 26 ++++-- tests/python/frontend/onnx/test_forward.py | 1 - tests/python/relay/test_op_level3.py | 85 ++++++++++--------- tests/python/topi/python/test_topi_scatter.py | 70 +++++++++------ 14 files changed, 248 insertions(+), 178 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index a5544c8a8799..113c8209fe6a 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -126,10 +126,11 @@ struct ScatterAddAttrs : public tvm::AttrsNode { }; struct ScatterNDAttrs : public tvm::AttrsNode { - Array out_shape; + String mode; TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") { - TVM_ATTR_FIELD(out_shape).describe("Output shape of the scatter."); + TVM_ATTR_FIELD(mode).describe( + "Accumulation mode of the scatter, either \"update\" or \"add\"."); } }; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a695e0002b34..deb29480d807 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1376,6 +1376,18 @@ def _impl_v1(cls, inputs, attr, params): return _op.scatter(inputs[0], inputs[1], inputs[2], axis) +class ScatterND(OnnxOpConverter): + """Operator converter for Scatter.""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + indices_dim = len(infer_shape(inputs[1])) + axes = list(range(indices_dim)) + return _op.scatter_nd( + inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], "update" + ) + + class Greater(OnnxOpConverter): """Operator logical greater.""" @@ -2874,6 +2886,7 @@ def _get_convert_map(opset): "Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}), "Scatter": Scatter.get_converter(opset), "ScatterElements": Scatter.get_converter(opset), + "ScatterND": ScatterND.get_converter(opset), "Squeeze": AttrCvt("squeeze", {"axes": "axis"}), "Unsqueeze": Unsqueeze.get_converter(opset), "Pad": Pad.get_converter(opset), diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a31c44a369f9..025942bcfa22 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2118,26 +2118,13 @@ def index_put(self, inputs, input_types): indices = inputs[1] values = inputs[2] accumulate = inputs[3] - # accumulate parameter is ignored. - # torch.index_put default is False but Relay.scatter_nd accumulates values. - # We assume there is no duplicate indices in torch.index_put input if not accumulate: - logging.warning( - "torch.index_put accumulate parameter is False. " - "TVM uses tvm.relay.scatter_nd operator which accumulates values. " - "Make sure there is no duplicate indices in torch.index_put input." 
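Both frontends feed the same indices layout into the new op: scatter_nd expects the index-tuple components along the first axis, shape (M, Y_0, ..., Y_{K-1}). ONNX ScatterND keeps the tuple in the last axis, hence the transpose in the converter above, and torch.index_put supplies one index tensor per dimension, hence the stack along axis 0 below. A minimal NumPy sketch of that layout agreement (illustrative only, not part of the patch):

import numpy as np

# ONNX layout: one length-M index tuple per update site, stored along the LAST axis.
onnx_indices = np.array([[0, 1], [1, 1], [1, 0]])   # shape (3, M) with M == 2
axes = list(range(onnx_indices.ndim))
tvm_indices = np.transpose(onnx_indices, axes[-1:] + axes[:-1])   # shape (M, 3)

# torch.index_put layout: the same three points, given as one tensor per dimension.
torch_indices = [np.array([0, 1, 1]), np.array([1, 1, 0])]
stacked = np.stack(torch_indices, axis=0)            # shape (M, 3)

assert (tvm_indices == stacked).all()                # both yield the (M, ...) layout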
- ) - # Relay scatter_nd does not support input tensor - # We assume that torch.index_put is used with empty zero-values input tensor - # scatter_nd will create empty zero-values tensor with a given shape - out_shape = self.infer_shape(in_tensor) - logging.warning( - "tvm.relay.scatter_nd operator does not support input tensor parameter. " - "TVM assumes that torch.index_put is used with empty zero-values input tensor" - ) + mode = "update" + else: + mode = "add" # Combine array of index tensors into one index tensor with shape (N,_) index_tensor = _op.stack(indices, axis=0) - return _op.transform.scatter_nd(values, index_tensor, out_shape) + return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode) def scalar_tensor(self, inputs, input_types): data = inputs[0] diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 5836aebce393..108bef0242fe 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -834,7 +834,7 @@ def gather_nd_grad(orig, grad): Returns the gradient of gather_nd, which is simply scatter_nd. """ data, indices = orig.args - return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)] + return [scatter_nd(zeros_like(data), indices, grad, mode="add"), zeros_like(indices)] @register_gradient("reshape_like") diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 8220ad3bc736..2920c9955b9b 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -145,7 +145,7 @@ def compute_scatter_add(attrs, inputs, output_type): @_reg.register_compute("scatter_nd") def compute_scatter_nd(attrs, inputs, output_type): """Compute definition of scatter_nd""" - return [topi.scatter_nd(inputs[0], inputs[1], attrs.out_shape)] + return [topi.scatter_nd(inputs[0], inputs[1], inputs[2], attrs.mode)] _reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 70e021910ab0..7451b397265f 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1288,7 +1288,7 @@ def wrap_compute_scatter_nd(topi_compute): """Wrap scatter_nd topi compute""" def _compute_scatter_nd(attrs, inputs, _): - return [topi_compute(inputs[0], inputs[1], attrs.out_shape)] + return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.mode)] return _compute_scatter_nd diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index f94a00db2fb1..df2686196151 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -310,8 +310,8 @@ def scatter_add(data, indices, updates, axis): return _make.scatter_add(data, indices, updates, axis) -def scatter_nd(data, indices, out_shape): - """Scatter values from an array. +def scatter_nd(data, indices, updates, mode="update"): + """Scatter values from an array and update. See :py:func:`tvm.topi.scatter` for how data is scattered. @@ -323,15 +323,18 @@ def scatter_nd(data, indices, out_shape): indices : relay.Expr The index locations to update. - out_shape : Union[Tuple[int], List[int]] - Output shape of the scatter. + updates : relay.Expr + The values to update. + + mode : string + The accumulation mode for scatter. "update" or "add" Returns ------- ret : relay.Expr The computed result. 
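As a usage sketch of the new four-argument form (illustrative only, assuming an LLVM-enabled TVM build; the values are the same ones the updated tests at the bottom of this patch use):

import numpy as np
import tvm
from tvm import relay

data_np = np.zeros((2, 2), dtype="float32")
indices_np = np.array([[1, 1, 0], [0, 1, 0]], dtype="int64")  # (M=2, 3): three index tuples
updates_np = np.array([2.0, 3.0, 0.0], dtype="float32")       # one update per tuple

data = relay.var("data", shape=data_np.shape, dtype="float32")
indices = relay.var("indices", shape=indices_np.shape, dtype="int64")
updates = relay.var("updates", shape=updates_np.shape, dtype="float32")

func = relay.Function(
    [data, indices, updates], relay.op.scatter_nd(data, indices, updates, "add")
)
intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm")
result = intrp.evaluate(func)(data_np, indices_np, updates_np)
# With mode="add" on all-zero data this reproduces [[0, 0], [2, 3]].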
""" - return _make.scatter_nd(data, indices, out_shape) + return _make.scatter_nd(data, indices, updates, mode) def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None): diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index fd05904ba8e7..cee13d7e01a2 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -723,11 +723,12 @@ def update_func(dst_ptr, dst_index, update): return out -def scatter_nd(data, indices, shape): +def scatter_nd(data, indices, updates, mode): """Scatter elements from a n-dimension array. - Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape - (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes + Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape + (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}), + scatter_nd computes .. code-block:: @@ -737,9 +738,9 @@ def scatter_nd(data, indices, shape): x_M, ..., x_{N-1} - ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] + ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]) - all other entries in the output are 0. Repeated indices are summed. + where the update function f is determinted by the mode. Parameters ---------- @@ -749,35 +750,41 @@ def scatter_nd(data, indices, shape): indices : tvm.te.Tensor The indices of the values to extract. - shape : Sequence[int] - The output shape. This must be specified because it cannot be inferred. + updates : tvm.te.Tensor + The updates to apply at the Indices + + mode : string + The update mode for the algorithm, either "update" or "add" + If update, the update values will replace the input data + If add, the update values will be added to the input data Returns ------- ret : tvm.te.Tensor """ - _verify_scatter_nd_inputs(data, indices, shape) + _verify_scatter_nd_inputs(data, indices, updates) - def gen_ir(data_ptr, indices_ptr, out_ptr): + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): ib = tvm.tir.ir_builder.create() data = ib.buffer_ptr(data_ptr) indices = ib.buffer_ptr(indices_ptr) + updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) # We combine all the indices dimensions but the first one into a single # dimension so we can iterate it in single loop instead of an arbitrary - # number of loops. We do the same thing for all the data dimensions. + # number of loops. We do the same thing for all the update dimensions. 
fused_indices_dimension = 1 for i in indices_ptr.shape[1:]: fused_indices_dimension *= i - fused_data_dimension = 1 - for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]: - fused_data_dimension *= i + fused_updates_dimension = 1 + for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]: + fused_updates_dimension *= i fused_shape = 1 - for i in shape: + for i in data_ptr.shape: fused_shape *= i # For now we avoid parallizing over dimensions indexed by `indices` as @@ -789,38 +796,41 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - tdim = min(max_threads, fused_data_dimension) + tdim = min(max_threads, fused_updates_dimension) ib.scope_attr(tx, "thread_extent", tdim) - bdim = ceil_div(fused_data_dimension, tdim) + bdim = ceil_div(fused_updates_dimension, tdim) ib.scope_attr(bx, "thread_extent", bdim) - # zero data - # TODO(tkonolige): could we use topi.full to zero it instead? with ib.for_range(0, ceil_div(fused_shape, bdim)) as i: - index = i * fused_data_dimension + bx * tdim + tx + index = i * fused_updates_dimension + bx * tdim + tx with ib.if_scope(index < fused_shape): - out[index] = tvm.tir.Cast(data_ptr.dtype, 0) + out[index] = data[index] with ib.for_range(0, fused_indices_dimension) as i: j = bx * tdim + tx - with ib.if_scope(j < fused_data_dimension): - offset = fused_data_dimension + with ib.if_scope(j < fused_updates_dimension): + offset = fused_updates_dimension index = j # This is x_M, .. x_{N-1} part of the index into out. # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part # of the index into out. for l in reversed(range(indices_ptr.shape[0].value)): # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] index += offset * indices[i + l * fused_indices_dimension] - offset *= shape[l] - out[index] += data[i * fused_data_dimension + j] + offset *= data_ptr.shape[l] + if mode == "update": + out[index] = updates[i * fused_updates_dimension + j] + elif mode == "add": + out[index] += updates[i * fused_updates_dimension + j] + else: + raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) return ib.get() - out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf") + out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf") return te.extern( - [shape], - [data, indices], - lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + [data.shape], + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), dtype=data.dtype, out_buffers=[out_buf], name="scatter_nd_cuda", diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index a376963aa55a..d7b008c4c33f 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -16,7 +16,7 @@ # under the License. 
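Before the generic TOPI kernel below, it may help to see what the two modes do when index tuples repeat: "add" accumulates every contribution, while "update" keeps whichever write lands last (the kernels in this patch apply updates in index order). A NumPy sketch, illustrative only:

import numpy as np

data = np.zeros(4)
indices = np.array([[0, 2, 2]])          # (M=1, 3); position 2 is hit twice
updates = np.array([1.0, 10.0, 100.0])

add_out = data.copy()
upd_out = data.copy()
for k in range(indices.shape[1]):
    add_out[tuple(indices[:, k])] += updates[k]   # "add": repeated hits sum
    upd_out[tuple(indices[:, k])] = updates[k]    # "update": later hits overwrite

# add_out -> [1., 0., 110., 0.]
# upd_out -> [1., 0., 100., 0.]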
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks """Scatter operator""" -from ..tir import decl_buffer, ir_builder, Cast, AssertStmt, StringImm, Evaluate +from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate from ..te import extern, hybrid @@ -199,22 +199,22 @@ def scatter(data, indices, updates, axis=0): raise ValueError("scatter only support for 1-4 dimensions") -def _verify_scatter_nd_inputs(data, indices, shape): +def _verify_scatter_nd_inputs(data, indices, updates): mdim = int(indices.shape[0]) - assert mdim <= len(shape), ( + assert mdim <= len(data.shape), ( f"The first dimension of the indices ({mdim}) must be less than or equal to " f"the length of the shape of the output ({len(shape)})." ) for i in range(len(indices.shape) - 1): - assert indices.shape[i + 1] == data.shape[i], ( + assert indices.shape[i + 1] == updates.shape[i], ( f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of " - f"data[{i}] ({data.shape[i]})." + f"updates[{i}] ({updates.shape[i]})." ) - for i in range(mdim, len(shape)): + for i in range(mdim, len(data.shape)): data_ind = i - mdim + len(indices.shape) - 1 - assert data.shape[data_ind] == shape[i], ( - f"Dimension of data[{data_ind}] ({data.shape[data_ind]}) must equal dimension " - f"of out_shape[{i}] ({shape[i]})." + assert updates.shape[data_ind] == data.shape[i], ( + f"Dimension of updates[{data_ind}] ({updates.shape[data_ind]}) must equal dimension " + f"of out_shape[{i}] ({data.shape[i]})." ) assert ( @@ -222,11 +222,12 @@ def _verify_scatter_nd_inputs(data, indices, shape): ), f"Indices must be a tensor of integers, but its elements are {indices.dtype}." -def scatter_nd(data, indices, shape): +def scatter_nd(data, indices, updates, mode): """Scatter elements from a n-dimension array. - Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape - (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes + Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape + (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}), + scatter_nd computes .. code-block:: @@ -236,9 +237,9 @@ def scatter_nd(data, indices, shape): x_M, ..., x_{N-1} - ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] + ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]) - all other entries in the output are 0. Repeated indices are summed. + where the update function f is determinted by the mode. Parameters ---------- @@ -248,29 +249,33 @@ def scatter_nd(data, indices, shape): indices : tvm.te.Tensor The indices of the values to extract. - shape : Sequence[int] - The output shape. This must be specified because it cannot be inferred. + updates : tvm.te.Tensor + The updates to apply at the Indices + + mode : string + The update mode for the algorithm, either "update" or "add" + If update, the update values will replace the input data + If add, the update values will be added to the input data Returns ------- ret : tvm.te.Tensor """ - _verify_scatter_nd_inputs(data, indices, shape) + _verify_scatter_nd_inputs(data, indices, updates) - def gen_ir(data_ptr, indices_ptr, out_ptr): + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): ib = ir_builder.create() data = ib.buffer_ptr(data_ptr) indices = ib.buffer_ptr(indices_ptr) + updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) - # zero data - # TODO(tkonolige): could we use topi.full to zero it instead? 
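The checks in _verify_scatter_nd_inputs above amount to the following shape contract, shown here with the shapes the tests at the end of this patch use (the assertions are an illustrative restatement, not part of the patch):

import numpy as np

data = np.random.random((2, 7, 3))               # output shape (X_0, X_1, X_2), N = 3
indices = np.random.randint(0, 2, size=(2, 5))   # (M, Y_0) with M = 2 <= N
updates = np.random.random((5, 3))               # (Y_0, X_M, ..., X_{N-1})

m = indices.shape[0]
assert m <= data.ndim                                            # M <= N
assert indices.shape[1:] == updates.shape[: indices.ndim - 1]    # Y_0, ..., Y_{K-1} match
assert updates.shape[indices.ndim - 1 :] == data.shape[m:]       # trailing X dims match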
fused_shape = 1 - for i in shape: + for i in data.shape: fused_shape *= i with ib.for_range(0, fused_shape) as i: - out[i] = Cast(data_ptr.dtype, 0) + out[i] = data[i] # We combine all the indices dimensions but the first one into a single # dimension so we can iterate it in single loop instead of an arbitrary @@ -300,15 +305,20 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): ) ) offset *= shape[l] - out[index] += data[i * fused_data_dimension + j] + if mode == "add": + out[index] += updates[i * fused_data_dimension + j] + elif mode == "update": + out[index] = updates[i * fused_data_dimension + j] + else: + raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) return ib.get() out_buf = decl_buffer(shape, data.dtype, "out_buf") return extern( [shape], - [data, indices], - lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), dtype=data.dtype, out_buffers=[out_buf], name="scatter_nd_generic", diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py index 8bb3f57e82e4..5eb5e6e99b6c 100644 --- a/python/tvm/topi/x86/scatter.py +++ b/python/tvm/topi/x86/scatter.py @@ -20,11 +20,12 @@ from ..scatter import _verify_scatter_nd_inputs -def scatter_nd(data, indices, shape): +def scatter_nd(data, indices, updates, mode): """Scatter elements from a n-dimension array. - Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape - (M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes + Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape + (M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}), + scatter_nd computes .. code-block:: @@ -34,9 +35,9 @@ def scatter_nd(data, indices, shape): x_M, ..., x_{N-1} - ] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] + ] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]) - all other entries in the output are 0. Repeated indices are summed. + where the update function f is determinted by the mode. Parameters ---------- @@ -46,62 +47,71 @@ def scatter_nd(data, indices, shape): indices : tvm.te.Tensor The indices of the values to extract. - shape : Sequence[int] - The output shape. This must be specified because it cannot be inferred. + updates : tvm.te.Tensor + The updates to apply at the Indices + + mode : string + The update mode for the algorithm, either "update" or "add" + If update, the update values will replace the input data + If add, the update values will be added to the input data Returns ------- ret : tvm.te.Tensor """ - _verify_scatter_nd_inputs(data, indices, shape) + _verify_scatter_nd_inputs(data, indices, updates) - def gen_ir(data_ptr, indices_ptr, out_ptr): + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # pylint: disable=invalid-name ib = tvm.tir.ir_builder.create() data = ib.buffer_ptr(data_ptr) indices = ib.buffer_ptr(indices_ptr) + updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) # We combine all the indices dimensions but the first one into a single # dimension so we can iterate it in single loop instead of an arbitrary - # number of loops. We do the same thing for all the data dimensions. + # number of loops. We do the same thing for all the update dimensions. 
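A compact NumPy reference for the whole operator, mirroring what the TOPI kernels in this patch compute (copy data into the output, then walk the index tuples); illustrative only, written against the (M, Y_0, ..., Y_{K-1}) indices layout:

import numpy as np

def scatter_nd_ref(data, indices, updates, mode="update"):
    """NumPy reference for scatter_nd with the (M, Y_0, ..., Y_{K-1}) indices layout."""
    out = data.copy()
    m = indices.shape[0]
    # One column per update site: flatten the Y_0 ... Y_{K-1} axes of `indices`.
    flat_indices = indices.reshape(m, -1)
    flat_updates = updates.reshape(flat_indices.shape[1], *data.shape[m:])
    for k in range(flat_indices.shape[1]):
        dst = tuple(flat_indices[:, k])
        if mode == "update":
            out[dst] = flat_updates[k]
        elif mode == "add":
            out[dst] += flat_updates[k]
        else:
            raise ValueError("mode must be 'update' or 'add'")
    return out

# Matches the first test case below:
# scatter_nd_ref(np.zeros((2, 2)), np.array([[1, 1, 0], [0, 1, 0]]), np.array([2, 3, 0]), "add")
# -> [[0., 0.], [2., 3.]]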
fused_indices_dimension = 1 for i in indices_ptr.shape[1:]: fused_indices_dimension *= i - fused_data_dimension = 1 - for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]: - fused_data_dimension *= i + fused_updates_dimension = 1 + for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]: + fused_updates_dimension *= i fused_shape = 1 - for i in shape: + for i in data_ptr.shape: fused_shape *= i - # zero data - # TODO(tkonolige): could we use topi.full to zero it instead? with ib.for_range(0, fused_shape) as i: - out[i] = tvm.tir.Cast(data_ptr.dtype, 0) + out[i] = data[i] with ib.for_range(0, fused_indices_dimension) as i: - with ib.for_range(0, fused_data_dimension, kind="parallel") as j: - offset = fused_data_dimension + with ib.for_range(0, fused_updates_dimension, kind="parallel") as j: + offset = fused_updates_dimension index = j # This is x_M, .. x_{N-1} part of the index into out. # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part # of the index into out. for l in reversed(range(indices_ptr.shape[0].value)): # indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}] index += offset * indices[i + l * fused_indices_dimension] - offset *= shape[l] - out[index] += data[i * fused_data_dimension + j] + offset *= data_ptr.shape[l] + if mode == "update": + out[index] = updates[i * fused_updates_dimension + j] + elif mode == "add": + out[index] += updates[i * fused_updates_dimension + j] + else: + raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) return ib.get() - out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf") + out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf") return te.extern( - [shape], - [data, indices], - lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]), + [data.shape], + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), dtype=data.dtype, out_buffers=[out_buf], name="scatter_nd_x86", diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 11e94cb4b93e..e937cb0c7b1f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1096,10 +1096,11 @@ TVM_REGISTER_NODE_TYPE(ScatterNDAttrs); bool ScatterNDRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - // `types` contains: [data, indices, result] - ICHECK_EQ(types.size(), 3); + // `types` contains: [data, indices, updates, result] + ICHECK_EQ(types.size(), 4); const auto* data = types[0].as(); const auto* indices = types[1].as(); + const auto* updates = types[2].as(); if (data == nullptr) { ICHECK(types[0].as()) << "ScatterND: expect input data type to be TensorType but got " << types[0]; @@ -1110,8 +1111,14 @@ bool ScatterNDRel(const Array& types, int num_inputs, const Attrs& attrs, << "ScatterND: expect indices type to be TensorType but got " << types[1]; return false; } + if (updates == nullptr) { + ICHECK(types[2].as()) + << "ScatterND: expect updates type to be TensorType but got " << types[2]; + return false; + } ICHECK(indices->dtype.is_int()) << "ScatterND: indices must be a tensor of integers."; - const auto out_shape = attrs.as()->out_shape; + + const auto out_shape = data->shape; const IntImmNode* mdim = indices->shape[0].as(); const size_t kdim = indices->shape.size() - 1; const size_t ndim = out_shape.size(); @@ -1120,7 +1127,7 @@ bool ScatterNDRel(const Array& types, int num_inputs, const Attrs& attrs, "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N."; // Indices: (M, Y_0, 
.. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's. for (size_t i = 0; i < kdim; i++) { - reporter->AssertEQ(indices->shape[i + 1], data->shape[i]); + reporter->AssertEQ(indices->shape[i + 1], updates->shape[i]); } std::vector oshape; @@ -1133,15 +1140,15 @@ bool ScatterNDRel(const Array& types, int num_inputs, const Attrs& attrs, reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]); } - reporter->Assign(types[2], TensorType(oshape, data->dtype)); + reporter->Assign(types[3], TensorType(data->shape, data->dtype)); return true; } -Expr MakeScatterND(Expr data, Expr indices, const Array out_shape) { +Expr MakeScatterND(Expr data, Expr indices, Expr updates, String mode) { auto attrs = make_object(); - attrs->out_shape = out_shape; + attrs->mode = std::move(mode); static const Op& op = Op::Get("scatter_nd"); - return Call(op, {data, indices}, Attrs(attrs), {}); + return Call(op, {data, indices, updates}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND); @@ -1156,9 +1163,10 @@ whose shape is defined by indices. Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with shape (M, Y_0, ..., Y_{K-1}), the output will have shape (X_0, X_1, ..., X_{N-1}). )code" TVM_ADD_FILELINE) - .set_num_inputs(2) + .set_num_inputs(3) .add_argument("data", "Tensor", "The input tensor.") .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("updates", "Tensor", "The input tensor.") .set_support_level(3) .add_type_rel("ScatterND", ScatterNDRel) .set_attr("TOpPattern", kOpaque); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 1a3d0d4ac6e0..e11689cc1232 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4237,7 +4237,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_round/", "test_scan9_sum/", "test_scan_sum/", - "test_scatternd/", "test_simple_rnn_defaults/", "test_simple_rnn_with_initial_bias/", "test_strnormalizer_export_monday_casesensintive_lower/", diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index bf0a7e4952e5..e84b22b30ce1 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1833,38 +1833,39 @@ def test_cumprod(target, dev): @tvm.testing.parametrize_targets def test_scatter_nd(target, dev): - def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5): + def verify_scatter_nd( + data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, atol=1e-5 + ): data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype)) indices = relay.var("indices", shape=indices_np.shape, dtype=str(indices_np.dtype)) + updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype)) - out = relay.op.scatter_nd(data, indices, shape) - func = relay.Function([data, indices], out) + out = relay.op.scatter_nd(data, indices, updates, mode) + func = relay.Function([data, indices, updates], out) for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(data_np, indices_np) + op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol) - def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5): + def verify_scatter_nd_with_stack( + data_np, indices_np, updates_np, 
ref_res, mode="add", rtol=1e-5, atol=1e-5 + ): data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype)) indices_vars = [ relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np) ] + updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype)) # test if scatter_nd works in case indices are prepared by another Relay operator indices = relay.op.stack(indices_vars, axis=0) - out = relay.op.scatter_nd(data, indices, shape) + out = relay.op.scatter_nd(data, indices, updates, mode) func = relay.Function( - [ - data, - ] - + indices_vars, + [data, updates] + indices_vars, out, ) - fargs = [ - data_np, - ] + fargs = [data_np, updates_np] for a in indices_np: fargs.append(a) for kind in ["graph", "debug"]: @@ -1872,39 +1873,47 @@ def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, op_res = intrp.evaluate(func)(*fargs) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol) - data = np.array([2, 3, 0]) + data = np.zeros((2, 2)).astype("int64") indices = np.array([[1, 1, 0], [0, 1, 0]]) - shape = (2, 2) + updates = np.array([2, 3, 0]) out = np.array([[0, 0], [2, 3]]) - verify_scatter_nd(data, indices, shape, out) - verify_scatter_nd_with_stack(data, indices, shape, out) + verify_scatter_nd(data, indices, updates, out) + verify_scatter_nd_with_stack(data, indices, updates, out) - data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + data = np.zeros((2, 2, 2, 2)).astype("int64") indices = np.array([[0, 1], [1, 1]]) - shape = (2, 2, 2, 2) + updates = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]]) - verify_scatter_nd(data, indices, shape, out) - verify_scatter_nd_with_stack(data, indices, shape, out) + verify_scatter_nd(data, indices, updates, out) + verify_scatter_nd_with_stack(data, indices, updates, out) - data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32") indices = np.array([[1, 0, 0]]) + updates = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32") shape = (2, 1560) - out = np.zeros(shape).astype("float32") - out[1, :] += data[0, :] - out[0, :] += data[1, :] - out[0, :] += data[2, :] - verify_scatter_nd(data, indices, shape, out) - verify_scatter_nd_with_stack(data, indices, shape, out) - - data = np.ones((5, 3)).astype("float64") - indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64") - shape = (2, 7, 3) - out = np.zeros(shape).astype("float64") - for i in range(indices.shape[1]): - for j in range(data.shape[1]): - out[indices[0, i], indices[1, i], j] += data[i, j] - verify_scatter_nd(data, indices, shape, out) - verify_scatter_nd_with_stack(data, indices, shape, out) + data = np.zeros(shape).astype("float32") + out = data.copy() + out[1, :] += updates[0, :] + out[0, :] += updates[1, :] + out[0, :] += updates[2, :] + verify_scatter_nd(data, indices, updates, out) + verify_scatter_nd_with_stack(data, indices, updates, out) + + for mode in ["add", "update"]: + indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype( + "int64" + ) + updates = np.ones((5, 3)).astype("float64") + shape = (2, 7, 3) + data = np.random.random(shape).astype("float64") + out = data.copy() + for i in range(indices.shape[1]): + for j in range(updates.shape[1]): + if mode == "add": + out[indices[0, i], indices[1, i], j] += updates[i, j] + elif mode == "update": + out[indices[0, i], indices[1, i], j] = updates[i, j] + 
verify_scatter_nd(data, indices, updates, out, mode) + verify_scatter_nd_with_stack(data, indices, updates, out, mode) def test_unique(): diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py index ad73bb51f2d3..648ef62a04ee 100644 --- a/tests/python/topi/python/test_topi_scatter.py +++ b/tests/python/topi/python/test_topi_scatter.py @@ -23,44 +23,64 @@ @tvm.testing.parametrize_targets def test_scatter_nd(dev, target): - def check_scatter_nd(data, indices, shape, out): + def check_scatter_nd(data, indices, updates, out, mode="add"): implementations = { - "generic": (lambda x, y: topi.scatter_nd(x, y, shape), topi.generic.schedule_extern), - "gpu": (lambda x, y: topi.cuda.scatter_nd(x, y, shape), topi.generic.schedule_extern), - "cpu": (lambda x, y: topi.x86.scatter_nd(x, y, shape), topi.generic.schedule_extern), + "generic": ( + lambda x, y, z: topi.scatter_nd(x, y, z, mode), + topi.generic.schedule_extern, + ), + "gpu": ( + lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode), + topi.generic.schedule_extern, + ), + "cpu": ( + lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode), + topi.generic.schedule_extern, + ), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) - tvm.topi.testing.compare_numpy_tvm([data, indices], out, target, dev, fcompute, fschedule) + tvm.topi.testing.compare_numpy_tvm( + [data, indices, updates], out, target, dev, fcompute, fschedule + ) - data = np.array([2, 3, 0]) + data = np.zeros((2, 2)).astype("int64") indices = np.array([[1, 1, 0], [0, 1, 0]]) - shape = (2, 2) + updates = np.array([2, 3, 0]) out = np.array([[0, 0], [2, 3]]) - check_scatter_nd(data, indices, shape, out) + check_scatter_nd(data, indices, updates, out) - data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + data = np.zeros((2, 2, 2, 2)).astype("int64") indices = np.array([[0, 1], [1, 1]]) - shape = (2, 2, 2, 2) + updates = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]]) - check_scatter_nd(data, indices, shape, out) + check_scatter_nd(data, indices, updates, out) - data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32") indices = np.array([[1, 0, 0]]) + updates = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32") shape = (2, 1560) - out = np.zeros(shape).astype("float32") - out[1, :] += data[0, :] - out[0, :] += data[1, :] - out[0, :] += data[2, :] - check_scatter_nd(data, indices, shape, out) + data = np.zeros(shape).astype("float32") + out = data.copy() + out[1, :] += updates[0, :] + out[0, :] += updates[1, :] + out[0, :] += updates[2, :] + check_scatter_nd(data, indices, updates, out) - data = np.ones((5, 3)).astype("float64") - indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64") - shape = (2, 7, 3) - out = np.zeros(shape).astype("float64") - for i in range(indices.shape[1]): - for j in range(data.shape[1]): - out[indices[0, i], indices[1, i], j] += data[i, j] - check_scatter_nd(data, indices, shape, out) + for mode in ["add", "update"]: + updates = np.ones((5, 3)).astype("float64") + indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype( + "int64" + ) + shape = (2, 7, 3) + data = np.random.random(shape).astype("float64") + out = data.copy() + for i in range(indices.shape[1]): + for j in range(updates.shape[1]): + if mode == "add": + out[indices[0, i], indices[1, i], j] += updates[i, j] + elif mode == "update": + out[indices[0, 
i], indices[1, i], j] = updates[i, j] + + check_scatter_nd(data, indices, updates, out, mode) if __name__ == "__main__":