From 915c23b61b34604b19217759f320c84d3aa60605 Mon Sep 17 00:00:00 2001
From: abhikran-quic <63697863+abhikran-quic@users.noreply.github.com>
Date: Thu, 30 Jun 2022 20:06:27 +0530
Subject: [PATCH] [TOPI] [Hexagon] Batch flatten slice op initial version (#11522)

* [TOPI] [Hexagon] Batch flatten slice op initial version
* Fix lint errors
* Fix more lint errors
* Fix lint warnings
* Fix review comments
* Update tests to use util functions
* Update __init__.py
* Fix review comments
---
 python/tvm/topi/hexagon/slice_ops/__init__.py |   1 +
 .../topi/hexagon/slice_ops/batch_flatten.py   |  77 +++++++++++++
 python/tvm/topi/hexagon/utils.py              |  14 +++
 .../contrib/test_hexagon/infrastructure.py    |   6 ++
 .../test_hexagon/topi/test_batch_flatten.py   | 101 ++++++++++++++++++
 5 files changed, 199 insertions(+)
 create mode 100644 python/tvm/topi/hexagon/slice_ops/batch_flatten.py
 create mode 100644 tests/python/contrib/test_hexagon/topi/test_batch_flatten.py

diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py
index 3340f835200b..5b5c0b84214e 100755
--- a/python/tvm/topi/hexagon/slice_ops/__init__.py
+++ b/python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -20,5 +20,6 @@
 from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule
 from .add_subtract_multiply import *
 from .argmax import argmax_compute, argmax_schedule
+from .batch_flatten import batch_flatten_compute, batch_flatten_stir_schedule
 from .softmax_slice import *
 from .clip import *

diff --git a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py
new file mode 100644
index 000000000000..6dc0914e91b4
--- /dev/null
+++ b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Hexagon slice batch flatten compute and schedule"""
+from tvm import te, tir, topi
+from ..utils import get_layout_transform_fn
+
+
+def batch_flatten_compute(inp: te.Tensor) -> te.Tensor:
+    """Compute for slice batch flatten op for hexagon.
+
+    This op makes the following assumptions:
+    1. It implements a sliced batch flatten operation.
+    2. The input is in NHWC layout.
+
+    Parameters
+    ----------
+    inp : te.Tensor
+        Input activations, padded to the inner dimension size
+
+    Returns
+    -------
+    Output : te.Tensor
+        Output of applying the batch flatten operation to the input
+    """
+    return topi.nn.flatten(inp)
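+
+
+# Illustrative usage sketch (editorial comment only; the placeholder shape
+# and names below are hypothetical, not part of this patch):
+#
+#   act = te.placeholder((1, 2, 4, 2048), name="act", dtype="float16")
+#   out = batch_flatten_compute(act)
+#   # out.shape == (1, 2 * 4 * 2048): every axis after the batch axis is
+#   # collapsed into a single dimension.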
+
+
+def batch_flatten_stir_schedule(
+    out: te.Tensor,
+    inp: te.Tensor,
+    out_layout: str,
+    in_layout: str,
+) -> tir.Schedule:
+    """STIR schedule definition for the batch flatten compute.
+
+    Parameters
+    ----------
+    out : te.Tensor
+        The output tensor as returned by a call to batch_flatten_compute
+    inp : te.Tensor
+        Input tensor to batch_flatten
+    out_layout : str
+        The name of the expected output layout, resolved via get_layout_transform_fn
+    in_layout : str
+        The name of the input layout, resolved via get_layout_transform_fn
+
+    Returns
+    -------
+    sch : tvm.tir.Schedule
+        The STIR schedule for slice batch flatten compute
+    """
+
+    batch_flatten_func = te.create_prim_func([inp, out])
+    sch = tir.Schedule(batch_flatten_func, debug_mask="all")
+    compute = sch.get_block("compute")
+
+    sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
+    sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
+    i, j = sch.get_loops(compute)
+    jout, channel = sch.split(j, [None, inp.shape[3]])
+    height, width = sch.split(jout, [inp.shape[1], inp.shape[2]])
+    channelo, channeli = sch.split(channel, [None, 1024])
+    channelio, channelii = sch.split(channeli, [None, 64])
+    sch.reorder(i, height, width, channelo, channelio, channelii)
+    sch.vectorize(channelii)
+    return sch

diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
index 95b25cc5a73b..092bce87119a 100644
--- a/python/tvm/topi/hexagon/utils.py
+++ b/python/tvm/topi/hexagon/utils.py
@@ -67,6 +67,16 @@ def nc_512c_2d(n, c):
     return [n, c // 512, te.AXIS_SEPARATOR, c % 512]
 
 
+def nhwc_1024c_2d(n, h, w, c):
+    """Return index map for nhwc_1024c 2d layout"""
+    return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]
+
+
+def nc_1024_2d(n, c):
+    """Return index map for nc_1024 2d layout"""
+    return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024]
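+
+
+# Illustrative note (an editorial comment, not a library guarantee): these
+# index maps split the channel axis into 1024-element chunks, and
+# te.AXIS_SEPARATOR marks where the logical shape is folded into Hexagon's
+# 2-d physical allocation. For example, nc_1024_2d maps the logical index
+# (n, c) = (0, 3000) to [0, 2, te.AXIS_SEPARATOR, 952], since
+# 3000 // 1024 == 2 and 3000 % 1024 == 952.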
+
+
 def get_layout_transform_fn(layout):
     """Return index map function as per the layout string"""
     if layout == "nhwc-8h2w32c2w-2d":
@@ -77,6 +87,10 @@ def get_layout_transform_fn(layout):
         return n11c_1024c_2d
     if layout == "n11c-1024c-1d":
         return n11c_1024c_1d
+    if layout == "nhwc-1024c-2d":
+        return nhwc_1024c_2d
+    if layout == "nc-1024-2d":
+        return nc_1024_2d
     if layout == "nhw-32h16w-2d":
         return nhw_32h16w_2d
     if layout == "nhwc-4h4w32c-2d":

diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py
index c1d2b4046372..53351854a06a 100644
--- a/tests/python/contrib/test_hexagon/infrastructure.py
+++ b/tests/python/contrib/test_hexagon/infrastructure.py
@@ -245,6 +245,12 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
             n, h, w, c = arr_np.shape
             assert h == 1 and w == 1, "The size of h and w must be 1"
             return arr_np.reshape([n, 1, 1, c // 1024, 1024])
+        if new_layout == "nc-1024-2d":
+            N, C = arr_np.shape
+            return arr_np.reshape([N, C // 1024, 1024])
+        if new_layout == "nhwc-1024c-2d":
+            N, H, W, C = arr_np.shape
+            return arr_np.reshape([N, H, W, C // 1024, 1024])
         raise RuntimeError(f"Unexpected new_layout '{new_layout}'")

diff --git a/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py
new file mode 100644
index 000000000000..3a056116d45c
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for the Hexagon slice batch flatten op"""
+import numpy as np
+
+import tvm
+import tvm.testing
+import tvm.topi.hexagon.slice_ops as sl
+from tvm import te
+
+from ..infrastructure import allocate_hexagon_array, transform_numpy
+
+
+class BaseTestBatchFlatten:
+    input_shape = tvm.testing.parameter(
+        (1, 1, 1, 2048),
+        (1, 2, 4, 2048),
+        (1, 8, 8, 1024),
+        (2, 4, 8, 1024),
+        (2, 3, 5, 2048),
+    )
+    input_layout, input_axis_sep = tvm.testing.parameters(("nhwc-1024c-2d", [4]))
+    output_layout, output_axis_sep = tvm.testing.parameters(("nc-1024-2d", [2]))
+    data_type = tvm.testing.parameter("float16")
+
+
+class TestBatchFlatten(BaseTestBatchFlatten):
+    @tvm.testing.fixture
+    def output_shape(self, input_shape):
+        return input_shape[0], input_shape[1] * input_shape[2] * input_shape[3]
+
+    @tvm.testing.requires_hexagon
+    def test_batch_flatten(
+        self,
+        data_type,
+        input_shape,
+        input_layout,
+        input_axis_sep,
+        output_shape,
+        output_layout,
+        output_axis_sep,
+        hexagon_session,
+    ):
+        target_hexagon = tvm.target.hexagon("v69")
+        target = tvm.target.Target(target_hexagon, host=target_hexagon)
+        A = te.placeholder(input_shape, name="A", dtype=data_type)
+        D = sl.batch_flatten_compute(A)
+        tir_s = sl.batch_flatten_stir_schedule(
+            D,
+            A,
+            output_layout,
+            input_layout,
+        )
+        func_name = "batch_flatten"
+        with tvm.transform.PassContext(opt_level=3):
+            runtime_module = tvm.build(tir_s.mod, target=target, name=func_name)
+
+        mod = hexagon_session.load_module(runtime_module)
+
+        a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
+        ref = np.reshape(a_numpy, output_shape)
+
+        input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
+        ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
+
+        a_tvm = allocate_hexagon_array(
+            hexagon_session.device,
+            data=input_np_transformed,
+            axis_separators=input_axis_sep,
+            mem_scope="global.vtcm",
+        )
+        output = allocate_hexagon_array(
+            hexagon_session.device,
+            ref_np_transformed.shape,
+            data_type,
+            axis_separators=output_axis_sep,
+            mem_scope="global.vtcm",
+        )
+        mod(a_tvm, output)
+        np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
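
Usage note (editorial, not part of the patch): a minimal host-side sketch of
how the new op is driven, mirroring the test above. It is a sketch under the
assumption that a Hexagon toolchain is available; the shapes and the
"batch_flatten" kernel name are illustrative only.

    import tvm
    import tvm.topi.hexagon.slice_ops as sl
    from tvm import te

    # NHWC input whose channel count is a multiple of 1024, as the
    # nhwc-1024c-2d input layout requires.
    act = te.placeholder((1, 8, 8, 1024), name="act", dtype="float16")
    out = sl.batch_flatten_compute(act)
    sch = sl.batch_flatten_stir_schedule(out, act, "nc-1024-2d", "nhwc-1024c-2d")

    # Build for Hexagon v69; running the module additionally requires a
    # device session, as in the test above.
    target_hexagon = tvm.target.hexagon("v69")
    target = tvm.target.Target(target_hexagon, host=target_hexagon)
    with tvm.transform.PassContext(opt_level=3):
        mod = tvm.build(sch.mod, target=target, name="batch_flatten")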