Skip to content

Commit

Permalink
[TOPI][Hexagon] Implement Argmax Slice Op (apache#11847)
Browse files Browse the repository at this point in the history
* [TOPI][Hexagon] Implement Argmax Slice Op

* run through black

* Address initial review comments

* Fix variable names in tests

* Fix lint issue

Co-authored-by: arangasa (generated by with_the_same_user script) <arangasa@hu-arangasa-hyd.qualcomm.com>
  • Loading branch information
2 people authored and Mikael Sevenier committed Jul 26, 2022
1 parent a9d81ae commit f40ac52
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/topi/hexagon/slice_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@

from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule
from .add_subtract_multiply import *
from .argmax import argmax_compute, argmax_schedule
from .softmax_slice import *
from .clip import *
46 changes: 46 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/argmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Hexagon slice argmax compute and schedule"""

from tvm import tir
from tvm import topi
from ..utils import get_layout_transform_fn


def argmax_compute(in_tensor, axis):
    """Compute argmax of in_tensor along the given axis.

    Thin wrapper over topi.argmax so the slice-op package exposes a
    uniform `<op>_compute` entry point.

    Parameters
    ----------
    in_tensor : tvm.te.Tensor
        Input tensor to reduce.
    axis : int or list of int
        Axis (or axes) along which the argmax is taken.

    Returns
    -------
    tvm.te.Tensor
        Tensor of int32 indices of the maxima.
    """
    return topi.argmax(in_tensor, axis)


def argmax_stir_schedule_nhwc(func, in_layout, out_layout):
    """Build the STIR schedule for an nhwc argmax PrimFunc.

    Applies the input layout transform to buffer "A" of block
    "A_red_temp" and the output layout transform to buffer "A_red"
    of block "A_red".
    """
    schedule = tir.Schedule(func, debug_mask="all")
    schedule.transform_layout("A_red_temp", "A", in_layout)
    schedule.transform_layout("A_red", "A_red", out_layout)
    return schedule


def argmax_schedule(argmax_func, in_layout_str, out_layout_str):
    """Top-level argmax schedule dispatch.

    Selects the STIR schedule matching the requested input/output
    layout strings; raises RuntimeError for unsupported combinations.
    """
    if in_layout_str == "nhwc-8h2w32c2w-2d" and out_layout_str == "nhw-32h16w-2d":
        input_transform = get_layout_transform_fn(in_layout_str)
        output_transform = get_layout_transform_fn(out_layout_str)
        return argmax_stir_schedule_nhwc(argmax_func, input_transform, output_transform)
    raise RuntimeError(f"Unexpected input_layout, output_layout '{in_layout_str, out_layout_str}'")
7 changes: 7 additions & 0 deletions python/tvm/topi/hexagon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def nhwc_8h2w32c2w_1d(n, h, w, c):
return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2]


def nhw_32h16w_2d(n, h, w):
    """Return index map for nhw_32h16w 2d layout"""
    # Outer (crouton) indices, then the axis separator splitting the
    # 2-d allocation, then the inner intra-crouton indices.
    outer = [n, h // 32, w // 16]
    inner = [h % 32, w % 16]
    return outer + [te.AXIS_SEPARATOR] + inner


def nhwc_4h4w32c_1d(n, h, w, c):
    """Return index map for nhwc_4h4w32c 1d layout.

    Maps a logical (n, h, w, c) index to the flat-allocation (1d, no
    axis separator) crouton order: outer block indices first
    (h // 4, w // 4, c // 32), then the intra-block offsets
    (h % 4, w % 4, c % 32).
    """
    return [n, h // 4, w // 4, c // 32, h % 4, w % 4, c % 32]
Expand Down Expand Up @@ -72,6 +77,8 @@ def get_layout_transform_fn(layout):
return n11c_1024c_2d
if layout == "n11c-1024c-1d":
return n11c_1024c_1d
if layout == "nhw-32h16w-2d":
return nhw_32h16w_2d
if layout == "nhwc-4h4w32c-2d":
return nhwc_4h4w32c_2d
if layout == "nhwc-4h4w32c-1d":
Expand Down
8 changes: 8 additions & 0 deletions tests/python/contrib/test_hexagon/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,12 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
return arr_np.reshape([n, 1, 1, c // 1024, 1024])

raise RuntimeError(f"Unexpected new_layout '{new_layout}'")

if current_layout == "nhw":
if new_layout in ["nhw-32h16w-2d"]:
n, h, w = arr_np.shape
return arr_np.reshape([n, h // 32, 32, w // 16, 16]).transpose(0, 1, 3, 2, 4)

raise RuntimeError(f"Unexpected new_layout '{new_layout}'")

raise RuntimeError(f"Unexpected current_layout '{current_layout}'")
116 changes: 116 additions & 0 deletions tests/python/contrib/test_hexagon/topi/test_argmax_slice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Tests for Hexagon slice argmax op """
import numpy as np

import tvm
import tvm.testing
from tvm import te
import tvm.topi.hexagon.slice_ops as sl
import tvm.contrib.hexagon
from ..infrastructure import allocate_hexagon_array, transform_numpy


class TestArgMaxSlice:
    """Argmax Slice Op Tests"""

    # Parameter tuples: (input shape, input layout tag, output layout tag,
    # reduction axis list, input axis separators, output axis separators).
    # All cases reduce over the channel axis (axis 3) of an NHWC tensor.
    (
        input_shape,
        input_layout,
        output_layout,
        in_axis,
        in_axis_sep,
        out_axis_sep,
    ) = tvm.testing.parameters(
        ((1, 64, 64, 32), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", [3], [4], [3]),
        ((3, 32, 16, 32), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", [3], [4], [3]),
        ((1, 32, 32, 64), "nhwc-8h2w32c2w-2d", "nhw-32h16w-2d", [3], [4], [3]),
    )
    dtype = tvm.testing.parameter("float16")
    # Tensors are allocated in Hexagon VTCM (tightly coupled memory).
    working_scope = tvm.testing.parameter("global.vtcm")

    @tvm.testing.fixture
    def input_np(self, input_shape, dtype):
        # Uniform random floats in [0, 1); ties are effectively impossible,
        # so numpy's argmax gives an unambiguous reference result.
        return np.random.uniform(size=input_shape).astype(dtype)

    @tvm.testing.fixture
    def transformed_input_np(self, input_np, input_layout):
        # Reference input rearranged into the croutonized device layout.
        return transform_numpy(input_np, "nhwc", input_layout)

    @tvm.testing.fixture
    def expected_output_np(self, input_np, in_axis):
        # numpy reference: argmax over the reduction axis, as int32 indices.
        ref_np = np.argmax(input_np, *in_axis).astype("int32")
        return ref_np

    @tvm.testing.fixture
    def transformed_expected_output_np(self, expected_output_np, output_layout):
        # Reference output rearranged into the device output layout.
        return transform_numpy(expected_output_np, "nhw", output_layout)

    @tvm.testing.requires_hexagon
    def test_argmax_slice(
        self,
        input_shape,
        dtype,
        input_layout,
        output_layout,
        in_axis,
        transformed_input_np,
        transformed_expected_output_np,
        in_axis_sep,
        out_axis_sep,
        hexagon_session,
        working_scope,
    ):
        """Top level testing function for argmax"""
        target_hexagon = tvm.target.hexagon("v69")
        target = tvm.target.Target(target_hexagon, host=target_hexagon)
        # Build the sliced compute + schedule under test.
        argmax_input = te.placeholder(input_shape, name="A", dtype=dtype)
        output = sl.argmax.argmax_compute(argmax_input, in_axis)
        argmax_func = te.create_prim_func([argmax_input, output])
        tir_s = sl.argmax_schedule(argmax_func, input_layout, output_layout)
        # Device buffers carry the layout's axis separators for the 2-d
        # (discontiguous) VTCM allocations.
        input_data = allocate_hexagon_array(
            hexagon_session.device,
            data=transformed_input_np,
            axis_separators=in_axis_sep,
            mem_scope=working_scope,
        )
        output_data = allocate_hexagon_array(
            hexagon_session.device,
            tensor_shape=transformed_expected_output_np.shape,
            dtype=transformed_expected_output_np.dtype,
            axis_separators=out_axis_sep,
            mem_scope=working_scope,
        )
        # NOTE(review): asserts are disabled so the scheduled (padded) layout
        # does not trip shape checks — presumably intentional; confirm.
        with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}):
            tir_irm = tvm.lower(tir_s.mod, [argmax_input, output], name="argmax")
            runtime_module = tvm.build(
                tir_irm, [argmax_input, output], target=target, name="argmax"
            )
        mod = hexagon_session.load_module(runtime_module)

        mod(input_data, output_data)
        output_np = output_data.numpy()
        # Outputs are integer indices; tolerances are effectively exactness.
        tvm.testing.assert_allclose(
            output_np,
            transformed_expected_output_np,
            1e-3,
            1e-3,
        )


if __name__ == "__main__":
    # Allow running this test file directly (outside pytest collection).
    tvm.testing.main()

0 comments on commit f40ac52

Please sign in to comment.