diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py index 87af3a767c38..3340f835200b 100755 --- a/python/tvm/topi/hexagon/slice_ops/__init__.py +++ b/python/tvm/topi/hexagon/slice_ops/__init__.py @@ -19,5 +19,6 @@ from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule from .add_subtract_multiply import * +from .argmax import argmax_compute, argmax_schedule from .softmax_slice import * from .clip import * diff --git a/python/tvm/topi/hexagon/slice_ops/argmax.py b/python/tvm/topi/hexagon/slice_ops/argmax.py new file mode 100644 index 000000000000..4d34cb50a0b0 --- /dev/null +++ b/python/tvm/topi/hexagon/slice_ops/argmax.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Hexagon slice argmax compute and schedule""" + +from tvm import tir +from tvm import topi +from ..utils import get_layout_transform_fn + + +def argmax_compute(in_tensor, axis): + out_tensor = topi.argmax(in_tensor, axis) + return out_tensor + + +def argmax_stir_schedule_nhwc(func, in_layout, out_layout): + """Schedule for nhwc argmax""" + sch = tir.Schedule(func, debug_mask="all") + sch.transform_layout("A_red_temp", "A", in_layout) + sch.transform_layout("A_red", "A_red", out_layout) + return sch + + +def argmax_schedule(argmax_func, in_layout_str, out_layout_str): + """Schedule for argmax: top level function""" + if (in_layout_str == "nhwc-8h2w32c2w-2d") and (out_layout_str == "nhw-32h16w-2d"): + fp16_layout_transform = get_layout_transform_fn(in_layout_str) + int32_layout_transform = get_layout_transform_fn(out_layout_str) + tir_s = argmax_stir_schedule_nhwc( + argmax_func, fp16_layout_transform, int32_layout_transform + ) + return tir_s + raise RuntimeError(f"Unexpected input_layout, output_layout '{in_layout_str, out_layout_str}'") diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index 3efc48c4d04f..95b25cc5a73b 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -42,6 +42,11 @@ def nhwc_8h2w32c2w_1d(n, h, w, c): return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2] +def nhw_32h16w_2d(n, h, w): + """Return index map for nhw_32h16w 2d layout""" + return [n, h // 32, w // 16, te.AXIS_SEPARATOR, h % 32, w % 16] + + def nhwc_4h4w32c_1d(n, h, w, c): """Return index map for nhwc_4h4232c 1d layout""" return [n, h // 4, w // 4, c // 32, h % 4, w % 4, c % 32] @@ -72,6 +77,8 @@ def get_layout_transform_fn(layout): return n11c_1024c_2d if layout == "n11c-1024c-1d": return n11c_1024c_1d + if layout == "nhw-32h16w-2d": + return nhw_32h16w_2d if layout == "nhwc-4h4w32c-2d": return nhwc_4h4w32c_2d if layout == "nhwc-4h4w32c-1d": diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index 57a9dff8b424..c1d2b4046372 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -247,4 +247,12 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str): return arr_np.reshape([n, 1, 1, c // 1024, 1024]) raise RuntimeError(f"Unexpected new_layout '{new_layout}'") + + if current_layout == "nhw": + if new_layout in ["nhw-32h16w-2d"]: + n, h, w = arr_np.shape + return arr_np.reshape([n, h // 32, 32, w // 16, 16]).transpose(0, 1, 3, 2, 4) + + raise RuntimeError(f"Unexpected new_layout '{new_layout}'") + raise RuntimeError(f"Unexpected current_layout '{current_layout}'") diff --git a/tests/python/contrib/test_hexagon/topi/test_argmax_slice.py b/tests/python/contrib/test_hexagon/topi/test_argmax_slice.py new file mode 100644 index 000000000000..4cbd524f4abf --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_argmax_slice.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Tests for Hexagon slice argmax op """ +import numpy as np + +import tvm +import tvm.testing +from tvm import te +import tvm.topi.hexagon.slice_ops as sl +import tvm.contrib.hexagon +from ..infrastructure import allocate_hexagon_array, transform_numpy + + +class TestArgMaxSlice: + """Argmax Slice Op Tests""" + + ( + input_shape, + input_layout, + output_layout, + in_axis, + in_axis_sep, + out_axis_sep, + ) = tvm.testing.parameters( + ((1, 64, 64, 32), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", [3], [4], [3]), + ((3, 32, 16, 32), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", [3], [4], [3]), + ((1, 32, 32, 64), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", [3], [4], [3]), + ) + dtype = tvm.testing.parameter("float16") + working_scope = tvm.testing.parameter("global.vtcm") + + @tvm.testing.fixture + def input_np(self, input_shape, dtype): + return np.random.uniform(size=input_shape).astype(dtype) + + @tvm.testing.fixture + def transformed_input_np(self, input_np, input_layout): + return transform_numpy(input_np, "nhwc", input_layout) + + @tvm.testing.fixture + def expected_output_np(self, input_np, in_axis): + ref_np = np.argmax(input_np, *in_axis).astype("int32") + return ref_np + + @tvm.testing.fixture + def transformed_expected_output_np(self, expected_output_np, output_layout): + return transform_numpy(expected_output_np, "nhw", output_layout) + + @tvm.testing.requires_hexagon + def test_argmax_slice( + self, + input_shape, + dtype, + input_layout, + output_layout, + in_axis, + transformed_input_np, + transformed_expected_output_np, + in_axis_sep, + out_axis_sep, + hexagon_session, + working_scope, + ): + """Top level testing function for argmax""" + target_hexagon = tvm.target.hexagon("v69") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + argmax_input = te.placeholder(input_shape, name="A", dtype=dtype) + output = sl.argmax.argmax_compute(argmax_input, in_axis) + argmax_func = te.create_prim_func([argmax_input, output]) + tir_s = sl.argmax_schedule(argmax_func, input_layout, output_layout) + input_data = allocate_hexagon_array( + hexagon_session.device, + data=transformed_input_np, + axis_separators=in_axis_sep, + mem_scope=working_scope, + ) + output_data = allocate_hexagon_array( + hexagon_session.device, + tensor_shape=transformed_expected_output_np.shape, + dtype=transformed_expected_output_np.dtype, + axis_separators=out_axis_sep, + mem_scope=working_scope, + ) + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}): + tir_irm = tvm.lower(tir_s.mod, [argmax_input, output], name="argmax") + runtime_module = tvm.build( + tir_irm, [argmax_input, output], target=target, name="argmax" + ) + mod = hexagon_session.load_module(runtime_module) + + mod(input_data, output_data) + output_np = output_data.numpy() + tvm.testing.assert_allclose( + output_np, + transformed_expected_output_np, + 1e-3, + 1e-3, + ) + + +if __name__ == "__main__": + tvm.testing.main()