[CMSIS-NN] Fixed the case with duplicate operands in the QNN binary ops
ashutosh-arm committed Jun 17, 2022
1 parent 1b8f3b5 commit 976cf2a
Showing 6 changed files with 158 additions and 4 deletions.
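For context, the bug surfaces whenever both operands of a QNN binary op are the same expression. A minimal sketch of that pattern, assuming the public relay.qnn.op API (shape, scale, and zero point mirror the new test_same_input_to_binary_op below):

from tvm import relay

# Both operands of qnn.add are the very same Var.
x = relay.var("input", shape=(1, 16, 16, 3), dtype="int8")
scale = relay.const(0.256, "float32")
zero_point = relay.const(33, "int32")
call = relay.qnn.op.add(x, x, scale, zero_point, scale, zero_point, scale, zero_point)
assert call.args[0].same_as(call.args[1])  # the duplicate-operand case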
13 changes: 12 additions & 1 deletion src/relay/backend/contrib/cmsisnn/extract_constants.cc
@@ -164,7 +164,18 @@ class ExtractConstantsMutator : public MixedModeMutator {
function_signature.push_back(arg);
} else {
if (arg.as<VarNode>()) {
-function_signature.push_back(arg);
+// Push the argument only if it is not already present: an input var with
+// multiple consumers must appear only once in the function signature.
+bool found_in_existing_signature = false;
+for (auto& sign : function_signature) {
+  if (arg.same_as(sign)) {
+    found_in_existing_signature = true;
+    break;
+  }
+}
+if (!found_in_existing_signature) {
+  function_signature.push_back(arg);
+}
}
new_args.push_back(arg);
}
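The deduplication matters because Relay rebuilds partitioned functions from their free variables, and FreeVars lists each Var once. A hedged sketch of that behaviour using the public relay API:

from tvm import relay

x = relay.var("x", shape=(8, 8))
body = x + x  # the same Var consumed twice
func = relay.Function(relay.analysis.free_vars(body), body)
print(len(func.params))  # 1: free_vars returns the shared Var only once

The signature accumulated above must therefore contain each input var once, or the call arguments and the function parameters would drift apart.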
14 changes: 12 additions & 2 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -556,7 +556,12 @@ class RelayToTIRVisitor : public MixedModeMutator {

BufferCreator buffer_creator;
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
-tir::Var input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+tir::Var input_1;
+if (mul_call->args[0].same_as(mul_call->args[1])) {
+  input_1 = input_0;
+} else {
+  input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+}
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));

tvm::Array<PrimExpr> args = {
@@ -626,7 +631,12 @@ class RelayToTIRVisitor : public MixedModeMutator {

BufferCreator buffer_creator;
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
-tir::Var input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+tir::Var input_1;
+if (add_call->args[0].same_as(add_call->args[1])) {
+  input_1 = input_0;
+} else {
+  input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+}
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));

tvm::Array<PrimExpr> args = {
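Note that same_as is reference identity, not structural equality, so only a literally shared operand collapses onto the input_0 buffer. A hedged illustration against the Python API:

from tvm import relay

x = relay.var("x", shape=(4,), dtype="int8")
y = relay.var("x", shape=(4,), dtype="int8")  # structural twin, distinct object

assert x.same_as(x)      # the duplicated-operand case: one shared buffer
assert not x.same_as(y)  # look-alike operands still get separate buffers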
6 changes: 6 additions & 0 deletions src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
@@ -179,6 +179,12 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
auto new_body = VisitExpr(func->body);
Function new_func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
FreeTypeVars(new_body, mod_), func->attrs);

// Updating new_func may deduplicate its parameters, so the call arguments
// need to be realigned with the number of parameters new_func expects.
if (new_args[0].same_as(new_args[1])) {
new_args.erase(new_args.begin());
}
return Call(new_func, new_args);
}

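A hedged Python sketch of the realignment above: rebuilding the function through FreeVars collapses the duplicated parameter, so the call site must drop one argument to keep the arity in sync.

from tvm import relay

x = relay.var("x", shape=(8, 8))
params = relay.analysis.free_vars(x + x)  # deduplicated -> [x]
new_args = [x, x]                         # the original call passed x twice
if new_args[0].same_as(new_args[1]):
    new_args = new_args[1:]               # keep a single argument
assert len(new_args) == len(params)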
54 changes: 53 additions & 1 deletion tests/python/contrib/test_cmsisnn/test_binary_ops.py
@@ -101,7 +101,7 @@ def make_model(
def test_op_int8(
op, relu_type, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point
):
"""Tests QNN Conv2D operator for CMSIS-NN"""
"""Tests QNN binary operator for CMSIS-NN"""
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_USMP_CORSTONE300_RUNNER
@@ -145,6 +145,58 @@ def test_op_int8(
)


@skip_if_no_reference_system
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@pytest.mark.parametrize("relu_type", ["RELU", "NONE"])
def test_same_input_to_binary_op(op, relu_type):
"""Tests QNN binary operator for CMSIS-NN where both inputs are the same"""
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_USMP_CORSTONE300_RUNNER

dtype = "int8"
shape = [1, 16, 16, 3]
input_ = generate_variable("input")
input_scale = 0.256
input_zero_point = 33

model = make_model(
op,
input_,
input_,
input_scale,
input_zero_point,
input_scale,
input_zero_point,
relu_type,
)
orig_mod = make_module(model)

cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

# validate pattern matching
assert_partitioned_function(orig_mod, cmsisnn_mod)

# validate the output
in_min, in_max = get_range_for_dtype_str(dtype)
inputs = {
"input": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
}
output_list = generate_ref_data(orig_mod["main"], inputs)
compile_and_run(
AOTTestModel(
module=cmsisnn_mod,
inputs=inputs,
outputs=output_list,
output_tolerance=1,
),
test_runner,
interface_api,
use_unpacked_api,
)
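If the Corstone-300 reference system is available, the new case should be runnable in isolation with the usual pytest flow, e.g. pytest tests/python/contrib/test_cmsisnn/test_binary_ops.py -k test_same_input_to_binary_op (the skip markers above guard the environment requirements).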


def parameterize_for_constant_inputs(test):
"""Generates parameters in such a way so that at least one of the inputs is a constant,
both can't be variables, both can't be scalars.
34 changes: 34 additions & 0 deletions tests/python/contrib/test_cmsisnn/test_extract_constants.py
@@ -116,6 +116,40 @@ def test_nested_function():
relay.transform.InferType()(mod)


@tvm.testing.requires_cmsisnn
def test_internal_function_with_duplicate_arguments():
"""Tests the pass ExternConstants when a composite function
is present within global function with repeating arguments
to one of the binary ops.
"""
input0 = relay.var("input0", shape=(8, 8))
binary_op0 = input0 + input0
binary_op1 = binary_op0 * relay.const(5.0, "float32")
local_func = relay.Function([input0], binary_op1, relay.TensorType((8, 8), "float32"))
local_func = set_composite_func_attr(local_func, "cmsis-nn")

arg = relay.var("arg", shape=(8, 8))
call_local_func = relay.Call(local_func, [arg])
extern_func = relay.Function([arg], call_local_func, relay.TensorType((8, 8), "float32"))

global_arg = relay.var("global_var", shape=(8, 8))
global_var = relay.GlobalVar("external_function")
extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint)
call_extern_func = relay.Call(global_var, [global_arg])
main_func = relay.Function([global_arg], call_extern_func, relay.TensorType((8, 8), "float32"))
main_var = relay.GlobalVar("main")

mod = tvm.IRModule()
mod[global_var] = extern_func
mod[main_var] = main_func

mod = ExtractConstantsFromPartitionedFunction()(mod)
constant_verifier = CheckFunctionsForConstants()
constant_verifier.visit_function(mod[global_var])
constant_verifier.check_num_constants()
relay.transform.InferType()(mod)
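For context, the pass under test hoists constants out of partitioned functions; after ExtractConstantsFromPartitionedFunction runs, the relay.const(5.0, "float32") above should be lifted out of the composite body and resurface as a call argument one level up, which is what CheckFunctionsForConstants verifies on the external function.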


@tvm.testing.requires_cmsisnn
def test_multiple_functions():
"""Tests the pass ExternConstants when global function
41 changes: 41 additions & 0 deletions tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
@@ -256,6 +256,47 @@ def test_all_primary_operands_tensor_constants():
assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body)


@tvm.testing.requires_cmsisnn
def test_duplicate_constant_arguments():
"""Tests the pass when repeating operands are arguments to the binary op"""
dtype = "int8"
shape = (1, 3, 3, 32)
operand0 = generate_variable("operand0", shape, dtype)
operand1 = generate_variable("operand0", shape, dtype)
binary_op = make_binary_op(
relay.qnn.op.add,
operand0,
operand0,
input_0_scale=0.0128,
input_0_zero_point=32,
input_1_scale=0.256,
input_1_zero_point=-64,
)

local_func = relay.Function([operand0, operand1], binary_op, relay.TensorType(shape, dtype))
local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_add")

rng = np.random.default_rng(12345)
arg0 = relay.const(rng.integers(-128, high=127, size=shape, dtype=dtype))
call_local_func = relay.Call(local_func, [arg0, arg0])
extern_func = relay.Function([], call_local_func, relay.TensorType(shape, dtype))

global_var = relay.GlobalVar("external_function")
extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint)
call_extern_func = relay.Call(global_var, [])
main_func = relay.Function([], call_extern_func, relay.TensorType(shape, dtype))
main_var = relay.GlobalVar("main")

mod = tvm.IRModule()
mod[global_var] = extern_func
mod[main_var] = main_func

mod = relay.transform.InferType()(mod)
mod = ScalarToTensorConstants()(mod)
new_mod = relay.transform.InferType()(mod)
assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body)


@tvm.testing.requires_cmsisnn
def test_non_cmsisnn_ext_func():
"""Non CMSISNN functions should not be altered."""
