Skip to content

Commit

Permalink
[BugFix][Relay] fix scatter_nd type relation
Browse files Browse the repository at this point in the history
add testcase for scatter_nd with m != k
  • Loading branch information
Jiang.Zhongzhou committed May 6, 2023
1 parent 3043487 commit da3b61e
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1942,6 +1942,26 @@ def before():

test_scatter_nd_large_shape()

def test_scatter_nd_inequal_m_k():
def before():
data = relay.const(np.zeros((1, 1, 10), dtype="float32"), dtype="float32")
indices = relay.const(np.zeros((2, 1, 1, 1), dtype="float32"), dtype="int64")
update = relay.const(np.ones((1, 1, 1, 10), dtype="float32"), dtype="float32")
b = relay.op.scatter_nd(data, indices, update)
return relay.Function(relay.analysis.free_vars(b), b)

passes = tvm.transform.Sequential(
[
relay.transform.InferType(),
relay.transform.FoldConstant(),
]
)
before_mod = tvm.IRModule.from_expr(before())
with tvm.transform.PassContext(opt_level=3):
after_mod = passes(before_mod)

test_scatter_nd_inequal_m_k()

def verify_scatter_nd(
data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, atol=1e-5
):
Expand Down

0 comments on commit da3b61e

Please sign in to comment.