diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py index f6c30c25004c..c178aeeb0ec6 100644 --- a/python/tvm/topi/hexagon/slice_ops/__init__.py +++ b/python/tvm/topi/hexagon/slice_ops/__init__.py @@ -32,3 +32,4 @@ from .conv2d import * from .reshape import reshape_compute, reshape_stir_schedule from .relu import relu_compute, relu_stir_schedule +from .tanh import tanh_te_compute, tanhf16_schedule diff --git a/python/tvm/topi/hexagon/slice_ops/tanh.py b/python/tvm/topi/hexagon/slice_ops/tanh.py new file mode 100644 index 000000000000..3e10ec599cda --- /dev/null +++ b/python/tvm/topi/hexagon/slice_ops/tanh.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name + +""" Hexagon tanh slice op compute and schedule """ +import tvm +from tvm import te, tir +from ..utils import get_layout_transform_fn + + +def tanh_te_compute(in_tensor): + out_tensor = te.compute( + in_tensor.shape, lambda n, h, w, c: tvm.tir.tanh(in_tensor[n, h, w, c]), name="tanhf16" + ) + return out_tensor + + +def tanhf16_stir_sched_nhwc(func, in_layout, out_layout, h_split_factor=8): + """Schedule for nhwc fp16 to nchw fp16 layout""" + sch = tir.Schedule(func, debug_mask="all") + block_name = "tanhf16" + n, h, w, c = sch.get_loops(sch.get_block(block_name)) + h_outer, h_inner = sch.split(h, [None, h_split_factor]) + w_outer, w_inner = sch.split(w, [None, 4]) + c_outer, c_inner = sch.split(c, [None, 32]) + w_inner_o, w_inner_i = sch.split(w_inner, [None, 2]) + sch.reorder(n, h_outer, w_outer, c_outer, h_inner, w_inner_o, c_inner, w_inner_i) + sch.transform_layout(block_name, "A", in_layout) + sch.transform_layout(block_name, block_name, out_layout) + fused = sch.fuse(c_inner, w_inner_i) + sch.vectorize(fused) + return sch + + +def tanhf16_schedule(tanh_func, in_layout_str, out_layout_str): + in_layout_transform_func = get_layout_transform_fn(in_layout_str) + out_layout_transform_func = get_layout_transform_fn(out_layout_str) + return tanhf16_stir_sched_nhwc( + tanh_func, + in_layout_transform_func, + out_layout_transform_func, + ) diff --git a/tests/python/contrib/test_hexagon/topi/test_tanh_slice.py b/tests/python/contrib/test_hexagon/topi/test_tanh_slice.py new file mode 100644 index 000000000000..b1e85971a2f9 --- /dev/null +++ b/tests/python/contrib/test_hexagon/topi/test_tanh_slice.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test for Hexagon slice tanh op """ +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm import te +import tvm.topi.hexagon.slice_ops as sl +import tvm.contrib.hexagon +from ..infrastructure import allocate_hexagon_array, transform_numpy + +# pylint: disable=invalid-name + + +class TestTanhSlice: + """For Testing Tanh fp16 op""" + + input_shape, orig_layout, input_layout, output_layout, axis_sep = tvm.testing.parameters( + ((1, 8, 4, 32), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]), + ((1, 16, 12, 64), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]), + ((1, 64, 64, 32), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]), + ) + dtype = tvm.testing.parameter("float16") + working_scope = tvm.testing.parameter("global.vtcm") + + @tvm.testing.fixture + def input_np(self, input_shape, dtype): + return np.random.uniform(size=input_shape).astype(dtype) + + @tvm.testing.fixture + def transformed_input_np(self, input_np, orig_layout, input_layout): + return transform_numpy(input_np, orig_layout, input_layout) + + @tvm.testing.fixture + def expected_output_np(self, input_np): + ref_np = np.tanh(input_np) + return ref_np + + @tvm.testing.fixture + def transformed_expected_output_np(self, expected_output_np, orig_layout, output_layout): + return transform_numpy(expected_output_np, orig_layout, output_layout) + + @tvm.testing.requires_hexagon + def test_tanh( + self, + input_shape, + dtype, + input_layout, + output_layout, + transformed_input_np, + transformed_expected_output_np, + axis_sep, + hexagon_session, + working_scope, + ): + """Top Level testing function for tanh fp16 op""" + + target_hexagon = tvm.target.hexagon("v69") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + A = te.placeholder(input_shape, name="A", dtype=dtype) + M = sl.tanh_te_compute(A) + tanhf16_func = te.create_prim_func([A, M]) + tir_s = sl.tanhf16_schedule(tanhf16_func, input_layout, output_layout) + A_data = allocate_hexagon_array( + hexagon_session.device, + data=transformed_input_np, + axis_separators=axis_sep, + mem_scope=working_scope, + ) + M_data = allocate_hexagon_array( + hexagon_session.device, + tensor_shape=transformed_expected_output_np.shape, + dtype=transformed_expected_output_np.dtype, + axis_separators=axis_sep, + mem_scope=working_scope, + ) + with tvm.transform.PassContext(opt_level=3): + tir_irm = tvm.lower(tir_s.mod, [A, M], name="tanhf16") + runtime_module = tvm.build(tir_irm, target=target, name="tanhf16") + mod = hexagon_session.load_module(runtime_module) + + mod(A_data, M_data) + output_np = M_data.numpy() + tvm.testing.assert_allclose( + output_np, + transformed_expected_output_np, + 1e-3, + 1e-3, + ) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv))