From 64493cf920ef5f0b6eeb95b23db0f9757aa8b230 Mon Sep 17 00:00:00 2001
From: abhikran
Date: Fri, 1 Jul 2022 14:57:57 +0530
Subject: [PATCH 1/3] Reshape slice op.

This patch adds the initial Python implementation of the reshape slice op
for Hexagon.
---
 python/tvm/topi/hexagon/slice_ops/__init__.py |   1 +
 python/tvm/topi/hexagon/slice_ops/reshape.py  | 111 ++++++++++++++++++
 ...{test_batch_flatten.py => test_reshape.py} |   0
 3 files changed, 112 insertions(+)
 mode change 100755 => 100644 python/tvm/topi/hexagon/slice_ops/__init__.py
 create mode 100644 python/tvm/topi/hexagon/slice_ops/reshape.py
 rename tests/python/contrib/test_hexagon/topi/{test_batch_flatten.py => test_reshape.py} (100%)

diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py
old mode 100755
new mode 100644
index 5b5c0b84214e..39a5fdb64789
--- a/python/tvm/topi/hexagon/slice_ops/__init__.py
+++ b/python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -23,3 +23,4 @@
 from .batch_flatten import batch_flatten_compute, batch_flatten_stir_schedule
 from .softmax_slice import *
 from .clip import *
+from .reshape import reshape_compute, reshape_stir_schedule
diff --git a/python/tvm/topi/hexagon/slice_ops/reshape.py b/python/tvm/topi/hexagon/slice_ops/reshape.py
new file mode 100644
index 000000000000..374c20bb72df
--- /dev/null
+++ b/python/tvm/topi/hexagon/slice_ops/reshape.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Hexagon slice reshape compute and schedule"""
+from tvm import te, tir, topi
+from ..utils import get_layout_transform_fn
+
+
+def reshape_compute(inp: te.Tensor, new_shape: tuple) -> te.Tensor:
+    """Compute for the slice reshape op for hexagon.
+    This op makes the following assumptions:
+    1. It is written for a sliced reshape operation.
+    2. The input is assumed to be in NHWC layout.
+
+    Parameters
+    ----------
+    inp : te.Tensor
+        Input tensor
+    new_shape : tuple
+        Output shape
+    Returns
+    -------
+    Output : te.Tensor
+        Output of applying the reshape operation to the input
+    """
+    return topi.transform.reshape(inp, new_shape)
+
+
+def stir_schedule_nhwc_1024c(
+    out: te.Tensor,
+    inp: te.Tensor,
+    out_layout: str,
+    in_layout: str,
+) -> tir.Schedule:
+    """Schedule for output layout: nhwc-1024c-2d"""
+    reshape_func = te.create_prim_func([inp, out])
+    sch = tir.Schedule(reshape_func, debug_mask="all")
+    compute = sch.get_block("T_reshape")
+
+    sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
+    sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
+    # Recover the n, h, w, c iteration order from the flattened output loop.
+    i, j = sch.get_loops(compute)
+    jout, channel = sch.split(j, [None, inp.shape[3]])
+    height, width = sch.split(jout, [inp.shape[1], inp.shape[2]])
+    # Tile the channel axis into 1024-element blocks and 64-element chunks for HVX.
+    channelo, channeli = sch.split(channel, [None, 1024])
+    channelio, channelii = sch.split(channeli, [None, 64])
+    sch.reorder(i, height, width, channelo, channelio, channelii)
+    sch.vectorize(channelii)
+    return sch
+
+
+def stir_schedule_nhwc_8h2w32c2w(
+    out: te.Tensor,
+    inp: te.Tensor,
+    out_layout: str,
+    in_layout: str,
+) -> tir.Schedule:
+    """Schedule for input and output layout nhwc-8h2w32c2w"""
+    reshape_func = te.create_prim_func([inp, out])
+    sch = tir.Schedule(reshape_func, debug_mask="all")
+    compute = sch.get_block("T_reshape")
+
+    # Input and output share the same layout, so the layout transforms suffice.
+    sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
+    sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
+    return sch
+
+
+def reshape_stir_schedule(
+    out: te.Tensor,
+    inp: te.Tensor,
+    output_layout: str,
+    input_layout: str,
+) -> tir.Schedule:
+    """STIR schedule definition for the reshape compute.
+    Parameters
+    ----------
+    out : te.Tensor
+        The output tensor as returned by a call to reshape_compute
+    inp : te.Tensor
+        Input tensor to reshape
+    output_layout : str
+        The transformation function definition for the expected output layout
+    input_layout : str
+        The transformation function definition for the input layout
+    Returns
+    -------
+    sch : tvm.tir.Schedule
+        The STIR schedule for the slice reshape compute
+    """
+    if output_layout == "nhwc-8h2w32c2w-2d":
+        return stir_schedule_nhwc_8h2w32c2w(out, inp, output_layout, input_layout)
+    if output_layout == "nc-1024-2d":
+        return stir_schedule_nhwc_1024c(out, inp, output_layout, input_layout)
+    raise RuntimeError(f"Unexpected layout '{output_layout}'")
diff --git a/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py b/tests/python/contrib/test_hexagon/topi/test_reshape.py
similarity index 100%
rename from tests/python/contrib/test_hexagon/topi/test_batch_flatten.py
rename to tests/python/contrib/test_hexagon/topi/test_reshape.py

From eff9c41d65b3824b146e1fc3a7f1edbaa97855bf Mon Sep 17 00:00:00 2001
From: abhikran
Date: Fri, 1 Jul 2022 14:59:01 +0530
Subject: [PATCH 2/3] Add tests for reshape op

---
 .../contrib/test_hexagon/topi/test_reshape.py | 159 +++++++++++++-----
 1 file changed, 113 insertions(+), 46 deletions(-)

diff --git a/tests/python/contrib/test_hexagon/topi/test_reshape.py b/tests/python/contrib/test_hexagon/topi/test_reshape.py
index 3a056116d45c..2def86ad8339 100644
--- a/tests/python/contrib/test_hexagon/topi/test_reshape.py
+++ b/tests/python/contrib/test_hexagon/topi/test_reshape.py
@@ -28,74 +28,141 @@
 from ..infrastructure import allocate_hexagon_array, transform_numpy
 
 
-class BaseTestBatchFlatten:
-    input_shape = tvm.testing.parameter(
-        (1, 1, 1, 2048),
-        (1, 2, 4, 2048),
-        (1, 8, 8, 1024),
-        (2, 4, 8, 1024),
-        (2, 3, 5, 2048),
+def reshape_helper(
+    func,
+    fcompute,
+    fschedule,
+    data_type,
+    input_shape,
+    input_layout,
+    output_shape,
+    output_layout,
+    hexagon_session,
+):
+
+    target_hexagon = tvm.target.hexagon("v69")
+    target = tvm.target.Target(target_hexagon, host=target_hexagon)
+    A = te.placeholder(input_shape, name="A", dtype=data_type)
+    if func == "reshape":
+        D = fcompute(A, output_shape)
+    elif func == "batch_flatten":
+        D = fcompute(A)
+    else:
+        raise RuntimeError(f"Unexpected func '{func}'")
+    tir_s = fschedule(
+        D,
+        A,
+        output_layout,
+        input_layout,
     )
-    input_layout, input_axis_sep = tvm.testing.parameters(("nhwc-1024c-2d", [4]))
-    output_layout, output_axis_sep = tvm.testing.parameters(("nc-1024-2d", [2]))
-    data_type = tvm.testing.parameter("float16")
+    with tvm.transform.PassContext(opt_level=3):
+        print("output of tvm.lower", tvm.lower(tir_s.mod, name=func))
+        runtime_module = tvm.build(tir_s.mod, target=target, name=func)
+    mod = hexagon_session.load_module(runtime_module)
 
-class TestBatchFlatten(BaseTestBatchFlatten):
-    @tvm.testing.fixture
-    def output_shape(self, input_shape):
-        return input_shape[0], input_shape[1] * input_shape[2] * input_shape[3]
+    a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
+    ref = np.reshape(a_numpy, output_shape)
+
+    input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
+    ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
+    input_axis_sep = [4]
+    if output_layout == "nhwc-8h2w32c2w-2d":
+        output_axis_sep = [4]
+    elif output_layout == "nc-1024-2d":
+        output_axis_sep = [2]
+    else:
+        raise RuntimeError(f"Unexpected layout '{output_layout}'")
+    a_tvm = allocate_hexagon_array(
+        hexagon_session.device,
+        data=input_np_transformed,
+        axis_separators=input_axis_sep,
+        mem_scope="global.vtcm",
+    )
+    output = allocate_hexagon_array(
+        hexagon_session.device,
+        ref_np_transformed.shape,
+        data_type,
+        axis_separators=output_axis_sep,
+        mem_scope="global.vtcm",
+    )
+    mod(a_tvm, output)
+    np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)
+
+
+batch_flatten_tests = (
+    ([1, 1, 1, 2048], [1, 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
+    ([1, 2, 4, 2048], [1, 2 * 4 * 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
+    ([1, 8, 8, 1024], [1, 8 * 8 * 1024], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
+    ([2, 4, 8, 1024], [2, 4 * 8 * 1024], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
+)
+
+
+class BaseTestBatchFlatten:
+    (
+        input_shape,
+        output_shape,
+        input_layout,
+        output_layout,
+        data_type,
+    ) = tvm.testing.parameters(*batch_flatten_tests)
+
+
+class TestBatchFlatten(BaseTestBatchFlatten):
     @tvm.testing.requires_hexagon
     def test_batch_flatten(
         self,
         data_type,
         input_shape,
         input_layout,
-        input_axis_sep,
         output_shape,
         output_layout,
-        output_axis_sep,
         hexagon_session,
     ):
-        target_hexagon = tvm.target.hexagon("v69")
-        target = tvm.target.Target(target_hexagon, host=target_hexagon)
-        A = te.placeholder(input_shape, name="A", dtype=data_type)
-        D = sl.batch_flatten_compute(A)
-        tir_s = sl.batch_flatten_stir_schedule(
-            D,
-            A,
-            output_layout,
+        reshape_helper(
+            "batch_flatten",
+            sl.batch_flatten_compute,
+            sl.batch_flatten_stir_schedule,
+            data_type,
+            input_shape,
             input_layout,
+            output_shape,
+            output_layout,
+            hexagon_session,
         )
-        func_name = "batch_flatten"
-        with tvm.transform.PassContext(opt_level=3):
-            runtime_module = tvm.build(tir_s.mod, target=target, name=func_name)
-        mod = hexagon_session.load_module(runtime_module)
 
-        a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
-        ref = np.reshape(a_numpy, output_shape)
+class BaseTestReshape(BaseTestBatchFlatten):
+    (input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
+        *batch_flatten_tests,
+        ([1, 8, 4, 64], [1, 8, 8, 32], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+        ([1, 16, 8, 128], [1, 16, 16, 64], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
+    )
 
-        input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
-        ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
-        a_tvm = allocate_hexagon_array(
-            hexagon_session.device,
-            data=input_np_transformed,
-            axis_separators=input_axis_sep,
-            mem_scope="global.vtcm",
-        )
-        output = allocate_hexagon_array(
-            hexagon_session.device,
-            ref_np_transformed.shape,
+
+class TestReshape(BaseTestReshape):
+    @tvm.testing.requires_hexagon
+    def test_reshape(
+        self,
+        data_type,
+        input_shape,
+        input_layout,
+        output_shape,
+        output_layout,
+        hexagon_session,
+    ):
+        reshape_helper(
+            "reshape",
+            sl.reshape_compute,
+            sl.reshape_stir_schedule,
             data_type,
-            axis_separators=output_axis_sep,
-            mem_scope="global.vtcm",
+            input_shape,
+            input_layout,
+            output_shape,
+            output_layout,
+            hexagon_session,
         )
-        mod(a_tvm, output)
-        np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)
 
 
 if __name__ == "__main__":
-    tvm.testing.main(pytest.main(sys.argv))
+    tvm.testing.main()

From cc32d1343a7b934cf9b4d01a0208499e36be5649 Mon Sep 17 00:00:00 2001
From: abhikran
Date: Thu, 7 Jul 2022 10:08:10 +0530
Subject: [PATCH 3/3] Empty commit to trigger CI.
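
A minimal usage sketch of the op added by this series. The names, layout tags,
and the v69 target string are all taken from the patches above; the input shape
is illustrative, and executing the built module requires a Hexagon session, as
done via hexagon_session.load_module in reshape_helper:

    import tvm
    from tvm import te
    import tvm.topi.hexagon.slice_ops as sl

    # NHWC fp16 input, per the layout assumption documented in reshape_compute.
    a = te.placeholder((1, 8, 8, 1024), name="A", dtype="float16")
    # Flatten H, W and C into a single axis (the batch_flatten-style case).
    d = sl.reshape_compute(a, (1, 8 * 8 * 1024))
    # The output layout tag selects the schedule: "nc-1024-2d" dispatches to
    # stir_schedule_nhwc_1024c, "nhwc-8h2w32c2w-2d" to the shared-layout schedule.
    sch = sl.reshape_stir_schedule(d, a, "nc-1024-2d", "nhwc-1024c-2d")

    target_hexagon = tvm.target.hexagon("v69")
    target = tvm.target.Target(target_hexagon, host=target_hexagon)
    runtime_module = tvm.build(sch.mod, target=target, name="reshape")
    # runtime_module can then be loaded into a Hexagon session and called on
    # VTCM-backed arrays, exactly as the tests above do.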