diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index cd8240d57d8d..a0111ff7cdbf 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1185,7 +1185,7 @@ bool ScatterNDRel(const Array& types, int num_inputs, const Attrs& attrs, const size_t kdim = indices->shape.size() - 1; const size_t ndim = out_shape.size(); ICHECK_LE(size_t(mdim->value), ndim) - << "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices " + << "ScatterND: Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices " "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N."; // Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's. for (size_t i = 0; i < kdim; i++) { @@ -1197,9 +1197,9 @@ bool ScatterNDRel(const Array& types, int num_inputs, const Attrs& attrs, oshape.push_back(x); } - // data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1} + // updates: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1} for (size_t i = mdim->value; i < ndim; i++) { - reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]); + reporter->AssertEQ(updates->shape[i - mdim->value + kdim], oshape[i]); } reporter->Assign(types[3], TensorType(data->shape, data->dtype)); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 493bf00fc6ad..5e86ab8da76d 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1942,6 +1942,26 @@ def before(): test_scatter_nd_large_shape() + def test_scatter_nd_inequal_m_k(): + def before(): + data = relay.const(np.zeros((1, 1, 10), dtype="float32"), dtype="float32") + indices = relay.const(np.zeros((2, 1, 1, 1), dtype="float32"), dtype="int64") + update = relay.const(np.ones((1, 1, 1, 10), dtype="float32"), dtype="float32") + b = relay.op.scatter_nd(data, indices, update) + return relay.Function(relay.analysis.free_vars(b), b) + + passes = tvm.transform.Sequential( + [ + relay.transform.InferType(), + relay.transform.FoldConstant(), + ] + ) + before_mod = tvm.IRModule.from_expr(before()) + with tvm.transform.PassContext(opt_level=3): + after_mod = passes(before_mod) + + test_scatter_nd_inequal_m_k() + def verify_scatter_nd( data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, atol=1e-5 ):