[TOPI][RELAY][ONNX] Scatter ND (#7927)
* passing topi tests

* passing relay tests, needs better shape checking still

* support ONNX operator

* add shape checking back in

* fix lint

* update docstring
Matthew Brookhart authored Apr 28, 2021
1 parent dee3133 commit 8fce895
Showing 14 changed files with 248 additions and 178 deletions.
5 changes: 3 additions & 2 deletions include/tvm/relay/attrs/transform.h
@@ -126,10 +126,11 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
};

struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
Array<Integer> out_shape;
String mode;

TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
TVM_ATTR_FIELD(out_shape).describe("Output shape of the scatter.");
TVM_ATTR_FIELD(mode).describe(
"Accumulation mode of the scatter, either \"update\" or \"add\".");
}
};

13 changes: 13 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -1376,6 +1376,18 @@ def _impl_v1(cls, inputs, attr, params):
return _op.scatter(inputs[0], inputs[1], inputs[2], axis)


class ScatterND(OnnxOpConverter):
"""Operator converter for Scatter."""

@classmethod
def _impl_v11(cls, inputs, attr, params):
indices_dim = len(infer_shape(inputs[1]))
axes = list(range(indices_dim))
return _op.scatter_nd(
inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], "update"
)


class Greater(OnnxOpConverter):
"""Operator logical greater."""

@@ -2874,6 +2886,7 @@ def _get_convert_map(opset):
"Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}),
"Scatter": Scatter.get_converter(opset),
"ScatterElements": Scatter.get_converter(opset),
"ScatterND": ScatterND.get_converter(opset),
"Squeeze": AttrCvt("squeeze", {"axes": "axis"}),
"Unsqueeze": Unsqueeze.get_converter(opset),
"Pad": Pad.get_converter(opset),
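The transpose in the ScatterND converter above is needed because ONNX ScatterND stores the coordinate dimension last in indices (shape (Y_0, ..., Y_{K-1}, M)), while Relay's scatter_nd expects it first (shape (M, Y_0, ..., Y_{K-1})). A minimal NumPy sketch of the intended semantics; the reference helper and the sample values below are illustrative, not part of the commit:

    import numpy as np

    # Reference for ONNX ScatterND semantics: copy `data`, then write each update
    # at the coordinate given by the trailing axis of `indices`.
    def onnx_scatter_nd_ref(data, indices, updates):
        out = data.copy()
        for idx in np.ndindex(*indices.shape[:-1]):
            out[tuple(indices[idx])] = updates[idx]
        return out

    data = np.zeros((4, 4), dtype="float32")
    indices = np.array([[0, 1], [2, 3]])              # ONNX layout: shape (2, M=2)
    updates = np.array([5.0, 7.0], dtype="float32")

    # The converter moves the trailing coordinate axis to the front before calling scatter_nd.
    axes = list(range(indices.ndim))
    relay_indices = np.transpose(indices, axes[-1:] + axes[:-1])   # shape (M=2, 2)

    print(onnx_scatter_nd_ref(data, indices, updates))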
21 changes: 4 additions & 17 deletions python/tvm/relay/frontend/pytorch.py
@@ -2118,26 +2118,13 @@ def index_put(self, inputs, input_types):
indices = inputs[1]
values = inputs[2]
accumulate = inputs[3]
# accumulate parameter is ignored.
# torch.index_put default is False but Relay.scatter_nd accumulates values.
# We assume there is no duplicate indices in torch.index_put input
if not accumulate:
logging.warning(
"torch.index_put accumulate parameter is False. "
"TVM uses tvm.relay.scatter_nd operator which accumulates values. "
"Make sure there is no duplicate indices in torch.index_put input."
)
# Relay scatter_nd does not support input tensor
# We assume that torch.index_put is used with empty zero-values input tensor
# scatter_nd will create empty zero-values tensor with a given shape
out_shape = self.infer_shape(in_tensor)
logging.warning(
"tvm.relay.scatter_nd operator does not support input tensor parameter. "
"TVM assumes that torch.index_put is used with empty zero-values input tensor"
)
mode = "update"
else:
mode = "add"
# Combine array of index tensors into one index tensor with shape (N,_)
index_tensor = _op.stack(indices, axis=0)
return _op.transform.scatter_nd(values, index_tensor, out_shape)
return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)

def scalar_tensor(self, inputs, input_types):
data = inputs[0]
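With scatter_nd now taking the input tensor and a mode, the index_put conversion no longer needs the warnings removed above: accumulate=False maps to mode="update", accumulate=True to mode="add", and the per-dimension index tensors are stacked along axis 0 to get the (N, ...) layout scatter_nd expects. A rough NumPy picture of the two modes; variable names and values are illustrative:

    import numpy as np

    in_tensor = np.zeros((3, 3), dtype="float32")
    rows = np.array([0, 2])
    cols = np.array([1, 1])
    values = np.array([4.0, 6.0], dtype="float32")

    # _op.stack(indices, axis=0) builds one (N, ...) index tensor from the per-dimension indices.
    index_tensor = np.stack([rows, cols], axis=0)     # shape (2, 2)

    # accumulate=False -> mode="update": values overwrite in_tensor at the indexed positions.
    out_update = in_tensor.copy()
    out_update[rows, cols] = values

    # accumulate=True -> mode="add": values are added to in_tensor at the indexed positions.
    out_add = in_tensor.copy()
    np.add.at(out_add, (rows, cols), values)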
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_tensor_grad.py
@@ -834,7 +834,7 @@ def gather_nd_grad(orig, grad):
Returns the gradient of gather_nd, which is simply scatter_nd.
"""
data, indices = orig.args
return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
return [scatter_nd(zeros_like(data), indices, grad, mode="add"), zeros_like(indices)]


@register_gradient("reshape_like")
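The updated gradient mirrors the usual adjoint of a gather: scatter the incoming gradient into a zero tensor shaped like data, accumulating wherever an index repeats, which is exactly mode="add". A small NumPy check of that rule; shapes and values are chosen for illustration:

    import numpy as np

    # gather_nd with indices of shape (M=1, 3) picks rows 1, 1 and 3 of `data`;
    # its vector-Jacobian product scatter-adds the upstream gradient back to those rows.
    data = np.arange(12, dtype="float32").reshape(4, 3)
    indices = np.array([[1, 1, 3]])
    grad = np.ones((3, 3), dtype="float32")           # one gradient row per gathered row

    d_data = np.zeros_like(data)                      # zeros_like(data) in the Relay expression
    np.add.at(d_data, indices[0], grad)               # row 1 accumulates twice, row 3 once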
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
@@ -145,7 +145,7 @@ def compute_scatter_add(attrs, inputs, output_type):
@_reg.register_compute("scatter_nd")
def compute_scatter_nd(attrs, inputs, output_type):
"""Compute definition of scatter_nd"""
return [topi.scatter_nd(inputs[0], inputs[1], attrs.out_shape)]
return [topi.scatter_nd(inputs[0], inputs[1], inputs[2], attrs.mode)]


_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
@@ -1288,7 +1288,7 @@ def wrap_compute_scatter_nd(topi_compute):
"""Wrap scatter_nd topi compute"""

def _compute_scatter_nd(attrs, inputs, _):
return [topi_compute(inputs[0], inputs[1], attrs.out_shape)]
return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.mode)]

return _compute_scatter_nd

13 changes: 8 additions & 5 deletions python/tvm/relay/op/transform.py
@@ -310,8 +310,8 @@ def scatter_add(data, indices, updates, axis):
return _make.scatter_add(data, indices, updates, axis)


def scatter_nd(data, indices, out_shape):
"""Scatter values from an array.
def scatter_nd(data, indices, updates, mode="update"):
"""Scatter values from an array and update.
See :py:func:`tvm.topi.scatter` for how data is scattered.
@@ -323,15 +323,18 @@ def scatter_nd(data, indices, out_shape):
indices : relay.Expr
The index locations to update.
out_shape : Union[Tuple[int], List[int]]
Output shape of the scatter.
updates : relay.Expr
The values to update.
mode : string
The accumulation mode for scatter. "update" or "add"
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.scatter_nd(data, indices, out_shape)
return _make.scatter_nd(data, indices, updates, mode)


def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
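A minimal usage sketch of the new Relay signature, constructing (but not compiling) a function; the shapes and dtypes are illustrative, and building and running the module is omitted:

    from tvm import relay

    data = relay.var("data", shape=(4,), dtype="float32")
    indices = relay.var("indices", shape=(1, 2), dtype="int64")   # (M, Y_0): two points into a 1-D tensor
    updates = relay.var("updates", shape=(2,), dtype="float32")

    # mode="add" accumulates `updates` into a copy of `data` at the given indices.
    out = relay.op.scatter_nd(data, indices, updates, mode="add")
    func = relay.Function([data, indices, updates], out)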
66 changes: 38 additions & 28 deletions python/tvm/topi/cuda/scatter.py
@@ -723,11 +723,12 @@ def update_func(dst_ptr, dst_index, update):
return out


def scatter_nd(data, indices, shape):
def scatter_nd(data, indices, updates, mode):
"""Scatter elements from a n-dimension array.
Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
(M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
(M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}),
scatter_nd computes
.. code-block::
@@ -737,9 +738,9 @@ def scatter_nd(data, indices, shape):
x_M,
...,
x_{N-1}
] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])
all other entries in the output are 0. Repeated indices are summed.
where the update function f is determined by the mode.
Parameters
----------
@@ -749,35 +750,41 @@ def scatter_nd(data, indices, shape):
indices : tvm.te.Tensor
The indices of the values to extract.
shape : Sequence[int]
The output shape. This must be specified because it cannot be inferred.
updates : tvm.te.Tensor
The updates to apply at the indices.
mode : string
The update mode for the algorithm, either "update" or "add".
If "update", the update values will replace the input data.
If "add", the update values will be added to the input data.
Returns
-------
ret : tvm.te.Tensor
"""
_verify_scatter_nd_inputs(data, indices, shape)
_verify_scatter_nd_inputs(data, indices, updates)

def gen_ir(data_ptr, indices_ptr, out_ptr):
def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data_ptr)
indices = ib.buffer_ptr(indices_ptr)
updates = ib.buffer_ptr(updates_ptr)
out = ib.buffer_ptr(out_ptr)

# We combine all the indices dimensions but the first one into a single
# dimension so we can iterate it in single loop instead of an arbitrary
# number of loops. We do the same thing for all the data dimensions.
# number of loops. We do the same thing for all the update dimensions.
fused_indices_dimension = 1
for i in indices_ptr.shape[1:]:
fused_indices_dimension *= i

fused_data_dimension = 1
for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
fused_data_dimension *= i
fused_updates_dimension = 1
for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
fused_updates_dimension *= i

fused_shape = 1
for i in shape:
for i in data_ptr.shape:
fused_shape *= i

# For now we avoid parallizing over dimensions indexed by `indices` as
@@ -789,38 +796,41 @@ def gen_ir(data_ptr, indices_ptr, out_ptr):
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
tdim = min(max_threads, fused_data_dimension)
tdim = min(max_threads, fused_updates_dimension)
ib.scope_attr(tx, "thread_extent", tdim)
bdim = ceil_div(fused_data_dimension, tdim)
bdim = ceil_div(fused_updates_dimension, tdim)
ib.scope_attr(bx, "thread_extent", bdim)

# copy the input data into the output
# TODO(tkonolige): could this initial copy be done as a separate bulk copy instead?
with ib.for_range(0, ceil_div(fused_shape, bdim)) as i:
index = i * fused_data_dimension + bx * tdim + tx
index = i * fused_updates_dimension + bx * tdim + tx
with ib.if_scope(index < fused_shape):
out[index] = tvm.tir.Cast(data_ptr.dtype, 0)
out[index] = data[index]

with ib.for_range(0, fused_indices_dimension) as i:
j = bx * tdim + tx
with ib.if_scope(j < fused_data_dimension):
offset = fused_data_dimension
with ib.if_scope(j < fused_updates_dimension):
offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
# of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[i * l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
index += offset * indices[i + l * fused_indices_dimension]
offset *= shape[l]
out[index] += data[i * fused_data_dimension + j]
offset *= data_ptr.shape[l]
if mode == "update":
out[index] = updates[i * fused_updates_dimension + j]
elif mode == "add":
out[index] += updates[i * fused_updates_dimension + j]
else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)

return ib.get()

out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf")
return te.extern(
[shape],
[data, indices],
lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
[data.shape],
[data, indices, updates],
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_nd_cuda",
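The flattened-loop structure in gen_ir can be read as the following plain-Python reference. This is a rough mirror for understanding the indexing, not the generated CUDA kernel; the helper name and sample values are illustrative:

    import numpy as np

    def scatter_nd_ref(data, indices, updates, mode):
        # Flattened-loop mirror of gen_ir: start from a copy of `data`, then for each
        # column of `indices` compute a flat offset into the output and write the update.
        out = data.reshape(-1).copy()
        num_indices = int(np.prod(indices.shape[1:]))          # fused_indices_dimension
        upd = updates.reshape(num_indices, -1)                 # rows of length fused_updates_dimension
        idx = indices.reshape(indices.shape[0], num_indices)
        for i in range(num_indices):
            offset = upd.shape[1]
            index = 0
            for l in reversed(range(idx.shape[0])):
                index += offset * int(idx[l, i])
                offset *= data.shape[l]
            for j in range(upd.shape[1]):
                if mode == "update":
                    out[index + j] = upd[i, j]
                else:  # mode == "add"
                    out[index + j] += upd[i, j]
        return out.reshape(data.shape)

    # Tiny check against the docstring's semantics (values are illustrative).
    data = np.zeros((2, 2), dtype="float32")
    indices = np.array([[0, 1], [1, 0]])       # shape (M=2, 2): points (0, 1) and (1, 0)
    updates = np.array([3.0, 4.0], dtype="float32")
    print(scatter_nd_ref(data, indices, updates, "add"))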
(diffs for the remaining changed files are not shown)
