Skip to content

Commit

Permalink
Review comment: Additional check for correct number of arguments in c…
Browse files Browse the repository at this point in the history
…omposite function

Change-Id: I7ffd6074bbbe9020b6efe64d48b80f79714ce8bd
  • Loading branch information
ashutosh-arm committed Jun 17, 2022
1 parent 976cf2a commit f984e0e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
1 change: 0 additions & 1 deletion python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ def qnn_max_pool2d_pattern():
def check_qnn_max_pool2d(pattern):
"""Check if max pool2d is supported by CMSIS-NN."""
output = pattern
input_op = None

if str(pattern.op.name) == "clip":
pooling = pattern.args[0]
Expand Down
7 changes: 7 additions & 0 deletions tests/python/contrib/test_cmsisnn/test_binary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ def test_same_input_to_binary_op(op, relu_type):
# validate pattern matching
assert_partitioned_function(orig_mod, cmsisnn_mod)

# Check if the number of internal function parameter is 1
cmsisnn_global_func = cmsisnn_mod["tvmgen_default_cmsis_nn_main_0"]
assert (
isinstance(cmsisnn_global_func.body, tvm.relay.expr.Call)
and len(cmsisnn_global_func.body.args) == 1
), "Composite function for the binary op should have only 1 parameter."

# validate the output
in_min, in_max = get_range_for_dtype_str(dtype)
inputs = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_duplicate_constant_arguments():
dtype = "int8"
shape = (1, 3, 3, 32)
operand0 = generate_variable("operand0", shape, dtype)
operand1 = generate_variable("operand0", shape, dtype)
operand1 = generate_variable("operand1", shape, dtype)
binary_op = make_binary_op(
relay.qnn.op.add,
operand0,
Expand Down

0 comments on commit f984e0e

Please sign in to comment.