[microNPU] Add support for MEAN with uint8 ifm
ilyag-grovety committed Mar 21, 2023
1 parent fe3fa9d commit fdf95da
Showing 4 changed files with 59 additions and 45 deletions.
3 changes: 2 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1052,6 +1052,7 @@ def callback(
             eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0

             scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int16"), dtype="int16")
+            rounding_mode = "TRUNCATE" if params.ifm.dtype == "uint8" else "NATURAL"

             reduced_op = ethosu_ops.ethosu_binary_elementwise(
                 ifm=reduced_op,
@@ -1068,7 +1069,7 @@ def callback(
                 ifm2_channels=out_channels,
                 reversed_operands=False,
                 ofm_dtype="int8",
-                rounding_mode="NATURAL",
+                rounding_mode=rounding_mode,
             )
         elif (
             params.ifm.q_params.scale_f32 == params.ofm.q_params.scale_f32
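
A note on the change above: the MUL stage of the mean legalization now derives its rounding mode from the ifm dtype instead of hard-coding NATURAL. The sketch below is a rough standalone model of a shift-based rescale, assuming NATURAL rounds half up and TRUNCATE simply drops the fraction bits; the exact semantics are defined by the Ethos-U hardware, and scale_down is a made-up helper name for illustration.

    def scale_down(acc: int, shift: int, mode: str) -> int:
        # Rough model only: NATURAL ~ round half up, TRUNCATE ~ drop the
        # fraction bits. Not TVM or Vela API.
        if mode == "NATURAL":
            acc += 1 << (shift - 1)  # add half an LSB before shifting
        return acc >> shift

    print(scale_down(1000, 4, "TRUNCATE"))  # 62: 1000 / 16 = 62.5, fraction dropped
    print(scale_down(1000, 4, "NATURAL"))   # 63: 62.5 rounds up

The two modes differ by at most one LSB per rescale, which is exactly the kind of off-by-one that makes codegen output diverge from the reference if the wrong mode is chosen.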
15 changes: 11 additions & 4 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -1336,10 +1336,12 @@ def check_axis(num_dims, axis):
                 return axis in ([0], [1], [0, 1])
             return axis in ([1], [2], [1, 2])

-        tensor_params = [self.ifm, self.ofm]
-        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
+        if not check_valid_dtypes([self.ifm], [np.int8, np.uint8]):
             return False
-        if self.ifm.dtype != self.ofm.dtype:
+        if not check_valid_dtypes([self.ofm], [np.int8]):
             return False
+        # ifm dtype uint8 is only supported for conv2d + mul case
+        if self.ifm.dtype == "uint8" and not (self.axis == [1, 2] and self.keepdims):
+            return False
         if not len(self.ifm.shape) in [2, 3, 4]:
             return False
@@ -1355,7 +1357,12 @@ def check_axis(num_dims, axis):
             or self.ifm.q_params.zero_point != self.ofm.q_params.zero_point
         ) and input_size > 4096:
             return False
-        if self.axis == [1, 2] and self.keepdims and self.ifm.dtype == "int8" and input_size > 256:
+        if (
+            self.axis == [1, 2]
+            and self.keepdims
+            and (self.ifm.dtype == "int8" or self.ifm.dtype == "uint8")
+            and input_size > 256
+        ):
             return False
         # Large kernel height reshape only when axis is [1, 2]
         if self.axis != [1, 2] and self.height > 64:
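
Taken together, the new checks admit uint8 on the input side only, keep the output int8, and restrict a uint8 ifm to the mean pattern that legalizes to depthwise conv2d + MUL. A hypothetical flat restatement of the dtype/axis constraints (function name and signature are illustrative, not the actual MeanParams API):

    def mean_supported(ifm_dtype, ofm_dtype, axis, keepdims):
        # Hypothetical restatement of the checks added above, for illustration.
        if ifm_dtype not in ("int8", "uint8"):
            return False
        if ofm_dtype != "int8":  # ofm stays int8 even when the ifm is uint8
            return False
        # uint8 ifm is only handled by the depthwise conv2d + MUL legalization,
        # i.e. a mean over the spatial axes with keepdims=True
        if ifm_dtype == "uint8" and not (axis == [1, 2] and keepdims):
            return False
        return True

    assert mean_supported("uint8", "int8", [1, 2], True)
    assert not mean_supported("uint8", "int8", [2], True)
    assert not mean_supported("uint8", "uint8", [1, 2], True)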
33 changes: 18 additions & 15 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -397,25 +397,25 @@ def binary_elementwise(lhs, rhs):
     ACCEL_TYPES,
 )
 @pytest.mark.parametrize(
-    "ifm_shape, axis, keep_dims, use_same_quantization",
+    "ifm_shape, axis, keep_dims, use_same_quantization, ifm_dtype",
     [
         # mean to depthwise + multiply
-        [(1, 8, 16, 16), (1, 2), True, False],
-        [(1, 3, 4), (0, 1), True, False],
-        [(1, 65, 2, 1), (1, 2), True, False],  # special case when h > 64
+        [(1, 8, 16, 16), (1, 2), True, False, "int8"],
+        [(1, 8, 16, 16), (1, 2), True, True, "uint8"],
+        [(1, 3, 4), (0, 1), True, False, "int8"],
+        [(1, 65, 2, 1), (1, 2), True, False, "int8"],  # special case when h > 64
         # mean to average pool
-        [(1, 8, 16, 16), (2,), False, True],
-        [(3, 3, 4), (0,), True, True],
-        [(8, 5), (0,), False, True],
+        [(1, 8, 16, 16), (2,), False, True, "int8"],
+        [(3, 3, 4), (0,), True, True, "int8"],
+        [(8, 5), (0,), False, True, "int8"],
         # mean to depthwise
-        [(1, 8, 16, 16), (2,), True, False],
-        [(1, 8, 16, 16), (2, 1), False, False],
-        [(8, 4), (0,), False, False],
+        [(1, 8, 16, 16), (2,), True, False, "int8"],
+        [(1, 8, 16, 16), (2, 1), False, False, "int8"],
+        [(8, 4), (0,), False, False, "int8"],
     ],
 )
-def test_mean(accel_type, ifm_shape, axis, keep_dims, use_same_quantization):
+def test_mean(accel_type, ifm_shape, axis, keep_dims, use_same_quantization, ifm_dtype):
     np.random.seed(0)
-    dtype = "int8"

     def create_mod_from_tflite():
         class Model(tf.Module):
@@ -447,13 +447,13 @@ def representative_dataset():
         mod, _ = relay.frontend.from_tflite(
             tflite_model,
             shape_dict={"ifm": ifm_shape},
-            dtype_dict={"ifm": dtype},
+            dtype_dict={"ifm": ifm_dtype},
         )
         input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
         return mod, input_data, output_data

     def create_mod_from_relay():
-        ifm = relay.var("input", shape=ifm_shape, dtype=dtype)
+        ifm = relay.var("input", shape=ifm_shape, dtype=ifm_dtype)
         cast = relay.cast(ifm, dtype="int32")
         mean = relay.mean(cast, axis=axis, keepdims=keep_dims)
         requantize = relay.qnn.op.requantize(
@@ -467,7 +467,10 @@ def create_mod_from_relay():
         func = relay.Function(relay.analysis.free_vars(requantize), requantize)
         mod = tvm.IRModule.from_expr(func)

-        input_data = {"input": np.random.randint(low=-127, high=128, size=ifm_shape, dtype=dtype)}
+        low, high = (0, 256) if ifm_dtype == "uint8" else (-127, 128)
+        input_data = {
+            "input": np.random.randint(low=low, high=high, size=ifm_shape, dtype=ifm_dtype)
+        }
         output_data = generate_ref_data(mod, input_data)
         return mod, input_data, output_data

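
The input-generation change keeps the random data inside each dtype's valid quantized range. A minimal sketch of the same logic (random_ifm is an illustrative name, not test infra); note that np.random.randint treats high as exclusive, so uint8 spans [0, 255] while int8 keeps the test's original [-127, 127] range:

    import numpy as np

    def random_ifm(shape, dtype):
        # Mirrors the test's ranges; high is exclusive in np.random.randint.
        low, high = (0, 256) if dtype == "uint8" else (-127, 128)
        return np.random.randint(low=low, high=high, size=shape, dtype=dtype)

    x = random_ifm((1, 8, 16, 16), "uint8")
    assert x.min() >= 0 and x.max() <= 255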
53 changes: 28 additions & 25 deletions tests/python/contrib/test_ethosu/test_legalize.py
@@ -1536,31 +1536,31 @@ def representative_dataset():


 @pytest.mark.parametrize(
-    "ifm_shape, axis, keep_dims, use_same_quantization",
+    "ifm_shape, axis, keep_dims, use_same_quantization, ifm_dtype",
     [
         # mean to depthwise + multiply
-        [(1, 8, 16, 16), (1, 2), True, False],
-        [(1, 8, 16, 16), (2, 1), True, False],
-        [(1, 3, 4), (0, 1), True, False],
-        [(8, 5), (1, 0), True, False],
-        [(1, 65, 2, 1), (1, 2), True, False],  # special case when h > 64
+        [(1, 8, 16, 16), (1, 2), True, False, "int8"],
+        [(1, 8, 16, 16), (1, 2), True, True, "uint8"],
+        [(1, 8, 16, 16), (2, 1), True, False, "int8"],
+        [(1, 3, 4), (0, 1), True, False, "int8"],
+        [(8, 5), (1, 0), True, False, "int8"],
+        [(1, 65, 2, 1), (1, 2), True, False, "int8"],  # special case when h > 64
         # mean to average pool
-        [(1, 8, 16, 16), (1,), True, True],
-        [(1, 8, 16, 16), (2,), False, True],
-        [(1, 8, 16, 16), (1, 2), False, True],
-        [(3, 3, 4), (0,), True, True],
-        [(3, 3, 4), (1,), False, True],
-        [(8, 5), (0,), False, True],
-        [(8, 5), (1,), True, True],
+        [(1, 8, 16, 16), (1,), True, True, "int8"],
+        [(1, 8, 16, 16), (2,), False, True, "int8"],
+        [(1, 8, 16, 16), (1, 2), False, True, "int8"],
+        [(3, 3, 4), (0,), True, True, "int8"],
+        [(3, 3, 4), (1,), False, True, "int8"],
+        [(8, 5), (0,), False, True, "int8"],
+        [(8, 5), (1,), True, True, "int8"],
         # mean to depthwise
-        [(1, 8, 16, 16), (1,), True, False],
-        [(1, 8, 16, 16), (2,), True, False],
-        [(1, 8, 16, 16), (1, 2), False, False],
-        [(8, 4), (0,), False, False],
+        [(1, 8, 16, 16), (1,), True, False, "int8"],
+        [(1, 8, 16, 16), (2,), True, False, "int8"],
+        [(1, 8, 16, 16), (1, 2), False, False, "int8"],
+        [(8, 4), (0,), False, False, "int8"],
     ],
 )
-def test_mean(ifm_shape, axis, keep_dims, use_same_quantization):
-    dtype = "int8"
+def test_mean(ifm_shape, axis, keep_dims, use_same_quantization, ifm_dtype):

     def create_tflite_graph():
         class Model(tf.Module):
@@ -1592,12 +1592,12 @@ def representative_dataset():
         mod, _ = relay.frontend.from_tflite(
             tflite_model,
             shape_dict={"input": ifm_shape},
-            dtype_dict={"input": dtype},
+            dtype_dict={"input": ifm_dtype},
         )
         return mod

     def create_relay_graph_with_same_quantization():
-        ifm = relay.var("input", shape=ifm_shape, dtype=dtype)
+        ifm = relay.var("input", shape=ifm_shape, dtype=ifm_dtype)
         cast = relay.cast(ifm, dtype="int32")
         mean = relay.mean(cast, axis=axis, keepdims=keep_dims)
         requantize = relay.qnn.op.requantize(
@@ -1654,16 +1654,19 @@ def calculate_expected_output_shape():

     # check IFM
     assert tuple(in_var.checked_type.shape) == ifm_shape
-    assert in_var.checked_type.dtype == dtype
+    assert in_var.checked_type.dtype == ifm_dtype

     # check OFM
     assert tuple(out_var.checked_type.shape) == out_shape
-    assert out_var.checked_type.dtype == dtype
+    assert out_var.checked_type.dtype == "int8"

     # check expected legalization case
-    if axis in [(1, 2), (2, 1), (0, 1), (1, 0)] and keep_dims and dtype == "int8":
+    if axis in [(1, 2), (2, 1), (0, 1), (1, 0)] and keep_dims and ifm_dtype == "int8":
         assert depthwise_op and mul_op
-        assert mul_op.attrs.operator_type == "MUL"
+        assert mul_op.attrs.operator_type == "MUL" and mul_op.attrs.rounding_mode == "NATURAL"
+    elif axis == (1, 2) and keep_dims and ifm_dtype == "uint8":
+        assert depthwise_op and mul_op
+        assert mul_op.attrs.operator_type == "MUL" and mul_op.attrs.rounding_mode == "TRUNCATE"
     elif pooling_op:
         attrs = pooling_op.attrs
         assert (
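
Read together, the updated assertions form a small decision table for the expected legalization and rounding mode. A hypothetical summary (helper name and return values are illustrative only, restating the assertions above):

    def expected_legalization(axis, keep_dims, ifm_dtype):
        # Hypothetical restatement of the test's assertions, for illustration.
        if ifm_dtype == "int8" and keep_dims and axis in [(1, 2), (2, 1), (0, 1), (1, 0)]:
            return "depthwise_conv2d + MUL", "NATURAL"
        if ifm_dtype == "uint8" and keep_dims and axis == (1, 2):
            return "depthwise_conv2d + MUL", "TRUNCATE"
        return "avg_pool2d or depthwise_conv2d", None  # remaining cases

    assert expected_legalization((1, 2), True, "uint8") == ("depthwise_conv2d + MUL", "TRUNCATE")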
