Skip to content

Commit

Permalink
[BugFix][Relay] fix scatter_nd type relation (#14773)
Browse files Browse the repository at this point in the history
* [BugFix][Relay] fix scatter_nd type relation

ScatterND requires updates.shape[K:] == output.shape[M:],
not data.shape[K:] == output.shape[M:]

* [BugFix][Relay] fix scatter_nd type relation
add testcase for scatter_nd with m != k

---------

Co-authored-by: Jiang.Zhongzhou <jack.river@evas.ai>
  • Loading branch information
JR4er and Jiang.Zhongzhou authored May 6, 2023
1 parent bf1be35 commit 571eff9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,7 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const size_t kdim = indices->shape.size() - 1;
const size_t ndim = out_shape.size();
ICHECK_LE(size_t(mdim->value), ndim)
<< "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices "
<< "ScatterND: Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices "
"with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N.";
// Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's.
for (size_t i = 0; i < kdim; i++) {
Expand All @@ -1197,9 +1197,9 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
oshape.push_back(x);
}

// data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}
// updates: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}
for (size_t i = mdim->value; i < ndim; i++) {
reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
reporter->AssertEQ(updates->shape[i - mdim->value + kdim], oshape[i]);
}

reporter->Assign(types[3], TensorType(data->shape, data->dtype));
Expand Down
20 changes: 20 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1942,6 +1942,26 @@ def before():

test_scatter_nd_large_shape()

def test_scatter_nd_inequal_m_k():
def before():
data = relay.const(np.zeros((1, 1, 10), dtype="float32"), dtype="float32")
indices = relay.const(np.zeros((2, 1, 1, 1), dtype="float32"), dtype="int64")
update = relay.const(np.ones((1, 1, 1, 10), dtype="float32"), dtype="float32")
b = relay.op.scatter_nd(data, indices, update)
return relay.Function(relay.analysis.free_vars(b), b)

passes = tvm.transform.Sequential(
[
relay.transform.InferType(),
relay.transform.FoldConstant(),
]
)
before_mod = tvm.IRModule.from_expr(before())
with tvm.transform.PassContext(opt_level=3):
after_mod = passes(before_mod)

test_scatter_nd_inequal_m_k()

def verify_scatter_nd(
data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, atol=1e-5
):
Expand Down

0 comments on commit 571eff9

Please sign in to comment.