diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py
index 617aaed920d7..931b703d7313 100644
--- a/python/tvm/topi/hexagon/slice_ops/__init__.py
+++ b/python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -23,5 +23,11 @@
 from .batch_flatten import batch_flatten_compute, batch_flatten_stir_schedule
 from .softmax_slice import *
 from .clip import *
+from .cast import (
+    cast_f16_f32_compute,
+    cast_f16_f32_schedule,
+    cast_f32_f16_compute,
+    cast_f32_f16_schedule,
+)
 from .conv2d import *
 from .reshape import reshape_compute, reshape_stir_schedule
diff --git a/python/tvm/topi/hexagon/slice_ops/cast.py b/python/tvm/topi/hexagon/slice_ops/cast.py
new file mode 100644
index 000000000000..b4984763e0e0
--- /dev/null
+++ b/python/tvm/topi/hexagon/slice_ops/cast.py
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Hexagon slice cast op compute and schedule"""
+
+from tvm import te
+from tvm import tir
+from ..utils import get_layout_transform_fn
+
+
+def get_layout_transform_for_f32(f32_layout_string):
+    """
+    Given f32 layout string, return transform_layout function and
+    channel/height split factor to be used for scheduling
+    """
+    layout_transform_fn = get_layout_transform_fn(f32_layout_string)
+    if f32_layout_string == "nhwc-8h2w32c2w-2d":
+        return [layout_transform_fn, 8]
+    if f32_layout_string == "nhwc-4h2w32c2w-2d":
+        return [layout_transform_fn, 4]
+    if f32_layout_string == "nc-1024c-2d":
+        return [layout_transform_fn, 1024]
+    if f32_layout_string == "nc-512c-2d":
+        return [layout_transform_fn, 512]
+    raise RuntimeError(f"Unexpected f32_layout '{f32_layout_string}'")
+
+
+def cast_f16_f32_compute(in_tensor):
+    out_tensor = te.compute(
+        in_tensor.shape, lambda *indices: in_tensor[indices].astype("float32"), name="CastF16F32"
+    )
+    return out_tensor
+
+
+def cast_f16_f32_stir_schedule_nhwc(func, in_layout, out_layout, h_split_factor):
+    """Schedule for nhwc f16 to f32 cast: nhwc layout"""
+    sch = tir.Schedule(func, debug_mask="all")
+    block_name = "CastF16F32"
+    n_orig, h_orig, w_orig, c_orig = sch.get_loops(sch.get_block(block_name))
+    h_outer, h_inner = sch.split(h_orig, [None, h_split_factor])
+    w_outer, w_inner = sch.split(w_orig, [None, 4])
+    c_outer, c_inner = sch.split(c_orig, [None, 32])
+    w_inner_o, w_inner_i = sch.split(w_inner, [None, 2])
+    sch.reorder(n_orig, h_outer, w_outer, c_outer, h_inner, w_inner_o, c_inner, w_inner_i)
+    sch.transform_layout(block_name, "A", in_layout)
+    sch.transform_layout(block_name, block_name, out_layout)
+    fused = sch.fuse(c_inner, w_inner_i)
+    sch.vectorize(fused)
+    return sch
+
+
+def cast_f16_f32_stir_schedule_nc(func, in_layout, out_layout, c_split_factor):
+    """Schedule for nc f16 to f32 cast: nc layout"""
+    sch = tir.Schedule(func, debug_mask="all")
+    block_name = "CastF16F32"
+    _, c_orig = sch.get_loops(sch.get_block(block_name))
+    _, c_inner = sch.split(c_orig, [None, c_split_factor])
+    sch.transform_layout(block_name, "A", in_layout)
+    sch.transform_layout(block_name, block_name, out_layout)
+    sch.vectorize(c_inner)
+    return sch
+
+
+def cast_f16_f32_schedule(cast_func, in_layout_str, out_layout_str):
+    """Schedule for f16 to f32 cast: top level function"""
+    f32_layout_transform_func, split_factor = get_layout_transform_for_f32(out_layout_str)
+    f16_layout_transform_func = get_layout_transform_fn(in_layout_str)
+    if in_layout_str == "nhwc-8h2w32c2w-2d":
+        return cast_f16_f32_stir_schedule_nhwc(
+            cast_func,
+            f16_layout_transform_func,
+            f32_layout_transform_func,
+            split_factor,
+        )
+    if in_layout_str == "nc-1024c-2d":
+        return cast_f16_f32_stir_schedule_nc(
+            cast_func, f16_layout_transform_func, f32_layout_transform_func, split_factor
+        )
+    raise RuntimeError(f"Unexpected input_layout, output_layout '{in_layout_str, out_layout_str}'")
+
+
+def cast_f32_f16_compute(in_tensor):
+    out_tensor = te.compute(
+        in_tensor.shape, lambda *indices: in_tensor[indices].astype("float16"), name="CastF32F16"
+    )
+    return out_tensor
+
+
+def cast_f32_f16_stir_schedule_nhwc(func, in_layout, out_layout, h_split_factor):
+    """Schedule for nhwc f32 to f16 cast: nhwc layout"""
+    sch = tir.Schedule(func, debug_mask="all")
+    block_name = "CastF32F16"
+    n_orig, h_orig, w_orig, c_orig = sch.get_loops(sch.get_block(block_name))
+    h_outer, h_inner = sch.split(h_orig, [None, h_split_factor])
+    w_outer, w_inner = sch.split(w_orig, [None, 4])
+    c_outer, c_inner = sch.split(c_orig, [None, 32])
+    w_inner_o, w_inner_i = sch.split(w_inner, [None, 2])
+    sch.reorder(n_orig, h_outer, w_outer, c_outer, h_inner, w_inner_o, c_inner, w_inner_i)
+    sch.transform_layout(block_name, "A", in_layout)
+    sch.transform_layout(block_name, block_name, out_layout)
+    fused = sch.fuse(c_inner, w_inner_i)
+    sch.vectorize(fused)
+    return sch
+
+
+def cast_f32_f16_stir_schedule_nc(func, in_layout, out_layout, c_split_factor):
+    """Schedule for nc f32 to f16 cast: nc layout"""
+    sch = tir.Schedule(func, debug_mask="all")
+    block_name = "CastF32F16"
+    _, c_orig = sch.get_loops(sch.get_block(block_name))
+    _, c_inner = sch.split(c_orig, [None, c_split_factor])
+    sch.transform_layout(block_name, "A", in_layout)
+    sch.transform_layout(block_name, block_name, out_layout)
+    sch.vectorize(c_inner)
+    return sch
+
+
+def cast_f32_f16_schedule(cast_func, in_layout_str, out_layout_str):
+    """Schedule for f32 to f16 cast: top level function"""
+    f32_layout_transform_func, split_factor = get_layout_transform_for_f32(in_layout_str)
+    f16_layout_transform_func = get_layout_transform_fn(out_layout_str)
+    if out_layout_str == "nhwc-8h2w32c2w-2d":
+        return cast_f32_f16_stir_schedule_nhwc(
+            cast_func, f32_layout_transform_func, f16_layout_transform_func, split_factor
+        )
+    if out_layout_str == "nc-1024c-2d":
+        return cast_f32_f16_stir_schedule_nc(
+            cast_func, f32_layout_transform_func, f16_layout_transform_func, split_factor
+        )
+    raise RuntimeError(f"Unexpected input_layout, output_layout '{in_layout_str, out_layout_str}'")
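For review convenience, a minimal host-side sketch (illustration only, not part of the patch) of how the new compute/schedule entry points above are meant to be driven. The shape and layout strings are taken from the tests below, and no Hexagon device is needed just to inspect the scheduled TIR:

    import tvm
    from tvm import te
    import tvm.topi.hexagon.slice_ops as sl

    # Build the fp16 -> fp32 cast as a TE compute, wrap it in a PrimFunc,
    # then apply the blocked-layout schedule defined in cast.py above.
    inp = te.placeholder((1, 16, 12, 64), name="A", dtype="float16")
    out = sl.cast_f16_f32_compute(inp)
    prim = te.create_prim_func([inp, out])
    sched = sl.cast_f16_f32_schedule(prim, "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d")
    # The printed loop nest should show the h/w/c splits, the reorder, and the
    # fused (c_inner, w_inner_i) loop marked as vectorized.
    print(sched.mod.script())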
diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
index 58792fc3294f..4458c55e6273 100644
--- a/python/tvm/topi/hexagon/utils.py
+++ b/python/tvm/topi/hexagon/utils.py
@@ -67,6 +67,16 @@ def nc_512c_2d(n, c):
     return [n, c // 512, te.AXIS_SEPARATOR, c % 512]
 
 
+def nc_1024c_2d(n, c):
+    """Return index map for nc_1024c 2d layout"""
+    return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024]
+
+
+def nhwc_4h2w32c2w_2d(n, h, w, c):
+    """Return index map for nhwc_4h2w32c2w 2d layout"""
+    return [n, h // 4, w // 4, c // 32, te.AXIS_SEPARATOR, h % 4, (w % 4) // 2, c % 32, w % 2]
+
+
 def nhwc_1024c_2d(n, h, w, c):
     """Return index map for nhwc_1024 2d layout"""
     return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]
@@ -113,6 +123,10 @@ def get_layout_transform_fn(layout):
         return nc_512c_2d
     if layout == "nc-512c-1d":
         return nc_512c_1d
+    if layout == "nhwc-4h2w32c2w-2d":
+        return nhwc_4h2w32c2w_2d
+    if layout == "nc-1024c-2d":
+        return nc_1024c_2d
     if layout == "iohw-16i32o2i-1d":
         return iohw_16i32o2i_1d
     raise RuntimeError(f"Unexpected layout '{layout}'")
diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py
index 53351854a06a..a1fbfdefcdbd 100644
--- a/tests/python/contrib/test_hexagon/infrastructure.py
+++ b/tests/python/contrib/test_hexagon/infrastructure.py
@@ -241,6 +241,11 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
             return arr_np.reshape([n, h // 8, 8, w // 4, 2, 2, c // 32, 32]).transpose(
                 0, 1, 3, 6, 2, 4, 7, 5
             )
+        if new_layout in ["nhwc-4h2w32c2w-2d"]:
+            n, h, w, c = arr_np.shape
+            return arr_np.reshape([n, h // 4, 4, w // 4, 2, 2, c // 32, 32]).transpose(
+                0, 1, 3, 6, 2, 4, 7, 5
+            )
         if new_layout in ["n11c-1024c-2d", "n11c-1024c-1d"]:
             n, h, w, c = arr_np.shape
             assert h == 1 and w == 1, "The size of h and w must be 1"
@@ -251,7 +256,14 @@
         if new_layout == "nhwc-1024c-2d":
             N, H, W, C = arr_np.shape
             return arr_np.reshape([N, H, W, C // 1024, 1024])
-
+        raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
+
+    if current_layout == "nc":
+        n, c = arr_np.shape
+        if new_layout in ["nc-1024c-2d"]:
+            return arr_np.reshape([n, c // 1024, 1024])
+        if new_layout in ["nc-512c-2d"]:
+            return arr_np.reshape([n, c // 512, 512])
         raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
 
     if current_layout == "nhw":
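As a sanity check on the two additions above, the reshape/transpose that transform_numpy performs for "nhwc-4h2w32c2w-2d" should place each element exactly where the nhwc_4h2w32c2w_2d index map says it belongs (axis separator aside). A small numpy spot check (a sketch only, assuming it is run from tests/python/contrib/test_hexagon so that infrastructure is importable):

    import numpy as np
    from infrastructure import transform_numpy

    n, h, w, c = 1, 8, 8, 64
    arr = np.arange(n * h * w * c, dtype="float32").reshape(n, h, w, c)
    packed = transform_numpy(arr, "nhwc", "nhwc-4h2w32c2w-2d")
    # Index map (separator dropped): [n, h//4, w//4, c//32, h%4, (w%4)//2, c%32, w%2]
    hh, ww, cc = 5, 6, 37
    got = packed[0, hh // 4, ww // 4, cc // 32, hh % 4, (ww % 4) // 2, cc % 32, ww % 2]
    assert arr[0, hh, ww, cc] == got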
diff --git a/tests/python/contrib/test_hexagon/topi/test_cast_slice.py b/tests/python/contrib/test_hexagon/topi/test_cast_slice.py
new file mode 100644
index 000000000000..30ea4c94b8b1
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_cast_slice.py
@@ -0,0 +1,200 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Tests for Hexagon slice cast ops """
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te
+import tvm.topi.hexagon.slice_ops as sl
+from ..infrastructure import allocate_hexagon_array, transform_numpy
+
+
+class TestCastF16F32Slice2d:
+    """
+    For testing Cast F16 to F32 Slice ops
+    """
+
+    input_shape, orig_layout, input_layout, output_layout, axis_sep = tvm.testing.parameters(
+        ((1, 16, 12, 64), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
+        ((1, 64, 64, 32), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
+        ((1, 16, 12, 64), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-4h2w32c2w-2d", [4]),
+        ((1, 64, 64, 32), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-4h2w32c2w-2d", [4]),
+        ((1, 1024), "nc", "nc-1024c-2d", "nc-1024c-2d", [2]),
+        ((1, 1024), "nc", "nc-1024c-2d", "nc-512c-2d", [2]),
+    )
+    dtype = tvm.testing.parameter("float16")
+    working_scope = tvm.testing.parameter("global.vtcm")
+
+    @tvm.testing.fixture
+    def input_np(self, input_shape, dtype):
+        return np.random.uniform(size=input_shape).astype(dtype)
+
+    @tvm.testing.fixture
+    def transformed_input_np(self, input_np, orig_layout, input_layout):
+        return transform_numpy(input_np, orig_layout, input_layout)
+
+    @tvm.testing.fixture
+    def expected_output_np(self, input_np):
+        ref_np = input_np.astype("float32")
+        return ref_np
+
+    @tvm.testing.fixture
+    def transformed_expected_output_np(self, expected_output_np, orig_layout, output_layout):
+        return transform_numpy(expected_output_np, orig_layout, output_layout)
+
+    @tvm.testing.requires_hexagon
+    def test_cast_fp16_fp32_slice(
+        self,
+        input_shape,
+        dtype,
+        input_layout,
+        output_layout,
+        transformed_input_np,
+        transformed_expected_output_np,
+        axis_sep,
+        hexagon_session,
+        working_scope,
+    ):
+        """
+        Top level testing function for cast fp16 to fp32
+        """
+        if hexagon_session._launcher._serial_number != "simulator":
+            pytest.skip(msg="Due to https://github.com/apache/tvm/issues/11957")
+        target_hexagon = tvm.target.hexagon("v68")
+        target = tvm.target.Target(target_hexagon, host=target_hexagon)
+        cast_input = te.placeholder(input_shape, name="A", dtype=dtype)
+        cast_output = sl.cast_f16_f32_compute(cast_input)
+        cast_func = te.create_prim_func([cast_input, cast_output])
+        tir_s = sl.cast_f16_f32_schedule(cast_func, input_layout, output_layout)
+        input_data = allocate_hexagon_array(
+            hexagon_session.device,
+            data=transformed_input_np,
+            axis_separators=axis_sep,
+            mem_scope=working_scope,
+        )
+        output_data = allocate_hexagon_array(
+            hexagon_session.device,
+            tensor_shape=transformed_expected_output_np.shape,
+            dtype=transformed_expected_output_np.dtype,
+            axis_separators=axis_sep,
+            mem_scope=working_scope,
+        )
+        with tvm.transform.PassContext(opt_level=3):
+            tir_irm = tvm.lower(tir_s.mod, [cast_input, cast_output], name="cast_f16_f32")
+            runtime_module = tvm.build(tir_irm, target=target, name="cast_f16_f32")
+        mod = hexagon_session.load_module(runtime_module)
+
+        mod(input_data, output_data)
+        output_np = output_data.numpy()
+        tvm.testing.assert_allclose(
+            output_np,
+            transformed_expected_output_np,
+            1e-3,
+            1e-3,
+        )
+
+
+class TestCastF32F16Slice2d:
+    """
+    For testing Cast F32 to F16 Slice ops
+    """
+
+    (input_shape, orig_layout, input_layout, output_layout, axis_sep,) = tvm.testing.parameters(
+        ((1, 16, 12, 64), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
+        ((1, 64, 64, 32), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
+        ((1, 16, 12, 64), "nhwc", "nhwc-4h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
+        ((1, 64, 64, 32), "nhwc", "nhwc-4h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
+        ((1, 1024), "nc", "nc-1024c-2d", "nc-1024c-2d", [2]),
+        ((1, 1024), "nc", "nc-512c-2d", "nc-1024c-2d", [2]),
+    )
+    dtype = tvm.testing.parameter("float32")
+    working_scope = tvm.testing.parameter("global.vtcm")
+
+    @tvm.testing.fixture
+    def input_np(self, input_shape, dtype):
+        return np.random.uniform(size=input_shape).astype(dtype)
+
+    @tvm.testing.fixture
+    def transformed_input_np(self, input_np, orig_layout, input_layout):
+        return transform_numpy(input_np, orig_layout, input_layout)
+
+    @tvm.testing.fixture
+    def expected_output_np(self, input_np):
+        ref_np = input_np.astype("float16")
+        return ref_np
+
+    @tvm.testing.fixture
+    def transformed_expected_output_np(self, expected_output_np, orig_layout, output_layout):
+        return transform_numpy(expected_output_np, orig_layout, output_layout)
+
+    @tvm.testing.requires_hexagon
+    def test_cast_fp32_fp16_slice(
+        self,
+        input_shape,
+        dtype,
+        input_layout,
+        output_layout,
+        transformed_input_np,
+        transformed_expected_output_np,
+        axis_sep,
+        hexagon_session,
+        working_scope,
+    ):
+        """
+        Top level testing function for cast fp32 to fp16
+        """
+        if hexagon_session._launcher._serial_number != "simulator":
+            pytest.skip(msg="Due to https://github.com/apache/tvm/issues/11957")
+
+        target_hexagon = tvm.target.hexagon("v68")
+        target = tvm.target.Target(target_hexagon, host=target_hexagon)
+        cast_input = te.placeholder(input_shape, name="A", dtype=dtype)
+        cast_output = sl.cast_f32_f16_compute(cast_input)
+        cast_func = te.create_prim_func([cast_input, cast_output])
+        tir_s = sl.cast_f32_f16_schedule(cast_func, input_layout, output_layout)
+        input_data = allocate_hexagon_array(
+            hexagon_session.device,
+            data=transformed_input_np,
+            axis_separators=axis_sep,
+            mem_scope=working_scope,
+        )
+        output_data = allocate_hexagon_array(
+            hexagon_session.device,
+            tensor_shape=transformed_expected_output_np.shape,
+            dtype=transformed_expected_output_np.dtype,
+            axis_separators=axis_sep,
+            mem_scope=working_scope,
+        )
+        with tvm.transform.PassContext(opt_level=3):
+            tir_irm = tvm.lower(tir_s.mod, [cast_input, cast_output], name="cast_f32_f16")
+            runtime_module = tvm.build(tir_irm, target=target, name="cast_f32_f16")
+        mod = hexagon_session.load_module(runtime_module)
+
+        mod(input_data, output_data)
+        output_np = output_data.numpy()
+        tvm.testing.assert_allclose(
+            output_np,
+            transformed_expected_output_np,
+            1e-3,
+            1e-3,
+        )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()