
[TOPI][RELAY][ONNX] Scatter ND #7927

Merged: 6 commits, Apr 28, 2021
5 changes: 3 additions & 2 deletions include/tvm/relay/attrs/transform.h
@@ -126,10 +126,11 @@ struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
};

struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
Array<Integer> out_shape;
String mode;

TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
TVM_ATTR_FIELD(out_shape).describe("Output shape of the scatter.");
TVM_ATTR_FIELD(mode).describe(
"Accumulation mode of the scatter, either \"update\" or \"add\".");
}
};

13 changes: 13 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -1329,6 +1329,18 @@ def _impl_v1(cls, inputs, attr, params):
return _op.scatter(inputs[0], inputs[1], inputs[2], axis)


class ScatterND(OnnxOpConverter):
"""Operator converter for Scatter."""

@classmethod
def _impl_v11(cls, inputs, attr, params):
indices_dim = len(infer_shape(inputs[1]))
axes = list(range(indices_dim))
return _op.scatter_nd(
inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], "update"
)


class Greater(OnnxOpConverter):
"""Operator logical greater."""

@@ -2820,6 +2832,7 @@ def _get_convert_map(opset):
"Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}),
"Scatter": Scatter.get_converter(opset),
"ScatterElements": Scatter.get_converter(opset),
"ScatterND": ScatterND.get_converter(opset),
"Squeeze": AttrCvt("squeeze", {"axes": "axis"}),
"Unsqueeze": Unsqueeze.get_converter(opset),
"Pad": Pad.get_converter(opset),
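Note on the ScatterND converter above: ONNX ScatterND stores indices with the index depth M in the last dimension, i.e. shape (Y_0, ..., Y_{K-1}, M), while Relay's scatter_nd expects (M, Y_0, ..., Y_{K-1}); the transpose in `_impl_v11` moves that last axis to the front. A minimal NumPy sketch of the ONNX semantics being matched (the helper name is hypothetical, for illustration only):

```python
import numpy as np

def onnx_scatter_nd_reference(data, indices, updates):
    # ONNX ScatterND: indices has shape (Y_0, ..., Y_{K-1}, M), where each
    # trailing vector of length M addresses the first M axes of data.
    output = np.copy(data)
    for idx in np.ndindex(*indices.shape[:-1]):
        output[tuple(indices[idx])] = updates[idx]
    return output

# Relay's scatter_nd wants the index depth first, shape (M, Y_0, ..., Y_{K-1}),
# which is what _op.transpose(indices, axes[-1:] + axes[:-1]) produces.
```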
21 changes: 4 additions & 17 deletions python/tvm/relay/frontend/pytorch.py
@@ -2118,26 +2118,13 @@ def index_put(self, inputs, input_types):
indices = inputs[1]
values = inputs[2]
accumulate = inputs[3]
# accumulate parameter is ignored.
# torch.index_put default is False but Relay.scatter_nd accumulates values.
# We assume there is no duplicate indices in torch.index_put input
if not accumulate:
logging.warning(
"torch.index_put accumulate parameter is False. "
"TVM uses tvm.relay.scatter_nd operator which accumulates values. "
"Make sure there is no duplicate indices in torch.index_put input."
)
# Relay scatter_nd does not support input tensor
# We assume that torch.index_put is used with empty zero-values input tensor
# scatter_nd will create empty zero-values tensor with a given shape
out_shape = self.infer_shape(in_tensor)
logging.warning(
"tvm.relay.scatter_nd operator does not support input tensor parameter. "
"TVM assumes that torch.index_put is used with empty zero-values input tensor"
)
mode = "update"
else:
mode = "add"
# Combine array of index tensors into one index tensor with shape (N,_)
index_tensor = _op.stack(indices, axis=0)
return _op.transform.scatter_nd(values, index_tensor, out_shape)
return _op.transform.scatter_nd(in_tensor, index_tensor, values, mode)

def scalar_tensor(self, inputs, input_types):
data = inputs[0]
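Note on the index_put change above: with scatter_nd now taking the input tensor and an accumulation mode, index_put maps onto it directly. The list of per-dimension index tensors is stacked into a single (M, ...) index tensor and `accumulate` selects between "update" and "add". A small PyTorch sketch of the behaviour being reproduced (values are illustrative):

```python
import torch

x = torch.zeros(3, 3)
rows = torch.tensor([0, 1, 2])
cols = torch.tensor([2, 1, 0])
vals = torch.tensor([1.0, 2.0, 3.0])

# accumulate=False corresponds to relay scatter_nd(..., mode="update")
updated = x.index_put((rows, cols), vals, accumulate=False)

# accumulate=True corresponds to relay scatter_nd(..., mode="add")
added = x.index_put((rows, cols), vals, accumulate=True)

# The frontend stacks (rows, cols) along axis 0 into shape (2, 3),
# matching the (M, Y_0, ..., Y_{K-1}) layout relay scatter_nd expects.
index_tensor = torch.stack((rows, cols), dim=0)
```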
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_tensor_grad.py
@@ -834,7 +834,7 @@ def gather_nd_grad(orig, grad):
Returns the gradient of gather_nd, which is simply scatter_nd.
"""
data, indices = orig.args
return [scatter_nd(grad, indices, data.checked_type.concrete_shape), zeros_like(indices)]
return [scatter_nd(zeros_like(data), indices, grad, mode="add"), zeros_like(indices)]


@register_gradient("reshape_like")
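Note on the gather_nd_grad change above: gather_nd reads one element (or slab) of data per index column, so its adjoint writes the incoming gradient back to those positions, starting from zeros and accumulating wherever the same position was gathered more than once, hence mode="add". A NumPy sketch of the idea (not the Relay implementation itself):

```python
import numpy as np

def gather_nd_grad_reference(data_shape, indices, grad):
    # indices: (M, N) integer array; grad: (N,) gradient of the gathered values
    # (assumes data is M-dimensional, so each column addresses one element).
    dgrad = np.zeros(data_shape, dtype=grad.dtype)
    for i in range(indices.shape[1]):
        # Repeated columns hit the same element, so their gradients must add up.
        dgrad[tuple(indices[:, i])] += grad[i]
    return dgrad
```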
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
@@ -145,7 +145,7 @@ def compute_scatter_add(attrs, inputs, output_type):
@_reg.register_compute("scatter_nd")
def compute_scatter_nd(attrs, inputs, output_type):
"""Compute definition of scatter_nd"""
return [topi.scatter_nd(inputs[0], inputs[1], attrs.out_shape)]
return [topi.scatter_nd(inputs[0], inputs[1], inputs[2], attrs.mode)]


_reg.register_strategy("scatter_nd", strategy.scatter_nd_strategy)
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
@@ -1288,7 +1288,7 @@ def wrap_compute_scatter_nd(topi_compute):
"""Wrap scatter_nd topi compute"""

def _compute_scatter_nd(attrs, inputs, _):
return [topi_compute(inputs[0], inputs[1], attrs.out_shape)]
return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.mode)]

return _compute_scatter_nd

13 changes: 8 additions & 5 deletions python/tvm/relay/op/transform.py
@@ -310,8 +310,8 @@ def scatter_add(data, indices, updates, axis):
return _make.scatter_add(data, indices, updates, axis)


def scatter_nd(data, indices, out_shape):
"""Scatter values from an array.
def scatter_nd(data, indices, updates, mode="update"):
"""Scatter values from an array and update.

See :py:func:`tvm.topi.scatter` for how data is scattered.

@@ -323,15 +323,18 @@ def scatter_nd(data, indices, out_shape):
indices : relay.Expr
The index locations to update.

out_shape : Union[Tuple[int], List[int]]
Output shape of the scatter.
updates : relay.Expr
The values to update.

mode : string
The accumulation mode for scatter: either "update" or "add".

Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.scatter_nd(data, indices, out_shape)
return _make.scatter_nd(data, indices, updates, mode)


def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
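A short sketch of the updated scatter_nd Python API shown above (shapes and values are made up for illustration; executor details omitted):

```python
from tvm import relay

# data supplies the base values; updates are written at the given indices.
data = relay.var("data", shape=(4, 4), dtype="float32")
indices = relay.var("indices", shape=(1, 2), dtype="int64")    # (M, Y_0) with M=1
updates = relay.var("updates", shape=(2, 4), dtype="float32")  # (Y_0, X_1)

# mode="update" overwrites the addressed rows; mode="add" accumulates into them.
out = relay.op.scatter_nd(data, indices, updates, mode="add")
func = relay.Function([data, indices, updates], out)
```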
66 changes: 38 additions & 28 deletions python/tvm/topi/cuda/scatter.py
@@ -723,11 +723,12 @@ def update_func(dst_ptr, dst_index, update):
return out


def scatter_nd(data, indices, shape):
def scatter_nd(data, indices, updates, mode):
"""Scatter elements from a n-dimension array.

Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
(M, Y_0, ..., Y_{K-1}), and output with shape (X_0, X_1, ..., X_{N-1}), scatter_nd computes
Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), indices with shape
(M, Y_0, ..., Y_{K-1}), and output copied from data with shape (X_0, X_1, ..., X_{N-1}),
scatter_nd computes

.. code-block::

@@ -737,9 +738,9 @@ def scatter_nd(data, indices, shape):
x_M,
...,
x_{N-1}
] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
] = f(output[...], updates[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}])

all other entries in the output are 0. Repeated indices are summed.
where the update function f is determined by the mode.

Parameters
----------
@@ -749,35 +750,41 @@ def scatter_nd(data, indices, shape):
indices : tvm.te.Tensor
The indices of the values to extract.

shape : Sequence[int]
The output shape. This must be specified because it cannot be inferred.
updates : tvm.te.Tensor
The updates to apply at the indices.

mode : string
The update mode for the algorithm, either "update" or "add".
If "update", the update values will replace the input data.
If "add", the update values will be added to the input data.

Returns
-------
ret : tvm.te.Tensor
"""
_verify_scatter_nd_inputs(data, indices, shape)
_verify_scatter_nd_inputs(data, indices, updates)

def gen_ir(data_ptr, indices_ptr, out_ptr):
def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data_ptr)
indices = ib.buffer_ptr(indices_ptr)
updates = ib.buffer_ptr(updates_ptr)
out = ib.buffer_ptr(out_ptr)

# We combine all the indices dimensions but the first one into a single
# dimension so we can iterate it in single loop instead of an arbitrary
# number of loops. We do the same thing for all the data dimensions.
# number of loops. We do the same thing for all the update dimensions.
fused_indices_dimension = 1
for i in indices_ptr.shape[1:]:
fused_indices_dimension *= i

fused_data_dimension = 1
for i in data_ptr.shape[len(indices_ptr.shape) - 1 :]:
fused_data_dimension *= i
fused_updates_dimension = 1
for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
fused_updates_dimension *= i

fused_shape = 1
for i in shape:
for i in data_ptr.shape:
fused_shape *= i

# For now we avoid parallizing over dimensions indexed by `indices` as
@@ -789,38 +796,41 @@ def gen_ir(data_ptr, indices_ptr, out_ptr):
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
tdim = min(max_threads, fused_data_dimension)
tdim = min(max_threads, fused_updates_dimension)
ib.scope_attr(tx, "thread_extent", tdim)
bdim = ceil_div(fused_data_dimension, tdim)
bdim = ceil_div(fused_updates_dimension, tdim)
ib.scope_attr(bx, "thread_extent", bdim)

# copy the input data into the output
# TODO(tkonolige): could we use topi.full to zero it instead?
with ib.for_range(0, ceil_div(fused_shape, bdim)) as i:
index = i * fused_data_dimension + bx * tdim + tx
index = i * fused_updates_dimension + bx * tdim + tx
with ib.if_scope(index < fused_shape):
out[index] = tvm.tir.Cast(data_ptr.dtype, 0)
out[index] = data[index]

with ib.for_range(0, fused_indices_dimension) as i:
j = bx * tdim + tx
with ib.if_scope(j < fused_data_dimension):
offset = fused_data_dimension
with ib.if_scope(j < fused_updates_dimension):
offset = fused_updates_dimension
index = j # This is x_M, .. x_{N-1} part of the index into out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part
# of the index into out.
for l in reversed(range(indices_ptr.shape[0].value)):
# indices[i + l * fused_indices_dimension] = indices[l, y_0, ... y_{k-1}]
index += offset * indices[i + l * fused_indices_dimension]
offset *= shape[l]
out[index] += data[i * fused_data_dimension + j]
offset *= data_ptr.shape[l]
if mode == "update":
out[index] = updates[i * fused_updates_dimension + j]
elif mode == "add":
out[index] += updates[i * fused_updates_dimension + j]
else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)

return ib.get()

out_buf = tvm.tir.decl_buffer(shape, data.dtype, "out_buf")
out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf")
return te.extern(
[shape],
[data, indices],
lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
[data.shape],
[data, indices, updates],
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_nd_cuda",
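For reference, the semantics the IR builder above implements can be written out in NumPy roughly as follows (a reference model of the computation, not the CUDA code path):

```python
import numpy as np

def scatter_nd_reference(data, indices, updates, mode="update"):
    # data: (X_0, ..., X_{N-1}); indices: (M, Y_0, ..., Y_{K-1});
    # updates: (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}).
    out = data.copy()
    m = indices.shape[0]
    for y in np.ndindex(*indices.shape[1:]):
        # Assemble the (x_0, ..., x_{M-1}) prefix of the output index.
        idx = tuple(int(indices[(l,) + y]) for l in range(m))
        if mode == "update":
            out[idx] = updates[y]
        elif mode == "add":
            out[idx] += updates[y]
        else:
            raise NotImplementedError("scatter_nd mode not in [update, add]: " + mode)
    return out
```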