From f984e0e616424a9298f8bf04d1a7d9069a399886 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Thu, 16 Jun 2022 17:16:36 +0100 Subject: [PATCH] Review comment: Additional check for correct number of arguments in composite function Change-Id: I7ffd6074bbbe9020b6efe64d48b80f79714ce8bd --- python/tvm/relay/op/contrib/cmsisnn.py | 1 - tests/python/contrib/test_cmsisnn/test_binary_ops.py | 7 +++++++ .../contrib/test_cmsisnn/test_scalar_to_tensor_constant.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index 09831929e527..8d714b7269d9 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -223,7 +223,6 @@ def qnn_max_pool2d_pattern(): def check_qnn_max_pool2d(pattern): """Check if max pool2d is supported by CMSIS-NN.""" output = pattern - input_op = None if str(pattern.op.name) == "clip": pooling = pattern.args[0] diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py index b42b2ffd0d9d..26604da0a64a 100644 --- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py +++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py @@ -178,6 +178,13 @@ def test_same_input_to_binary_op(op, relu_type): # validate pattern matching assert_partitioned_function(orig_mod, cmsisnn_mod) + # Check if the number of internal function parameter is 1 + cmsisnn_global_func = cmsisnn_mod["tvmgen_default_cmsis_nn_main_0"] + assert ( + isinstance(cmsisnn_global_func.body, tvm.relay.expr.Call) + and len(cmsisnn_global_func.body.args) == 1 + ), "Composite function for the binary op should have only 1 parameter." + # validate the output in_min, in_max = get_range_for_dtype_str(dtype) inputs = { diff --git a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py index 8a015e26c1f4..df54f7ce55f1 100644 --- a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py +++ b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py @@ -262,7 +262,7 @@ def test_duplicate_constant_arguments(): dtype = "int8" shape = (1, 3, 3, 32) operand0 = generate_variable("operand0", shape, dtype) - operand1 = generate_variable("operand0", shape, dtype) + operand1 = generate_variable("operand1", shape, dtype) binary_op = make_binary_op( relay.qnn.op.add, operand0,