diff --git a/python/tvm/topi/hexagon/slice_ops/max_pool2d.py b/python/tvm/topi/hexagon/slice_ops/max_pool2d.py index a825508883886..2ff978e14065a 100644 --- a/python/tvm/topi/hexagon/slice_ops/max_pool2d.py +++ b/python/tvm/topi/hexagon/slice_ops/max_pool2d.py @@ -35,6 +35,18 @@ from tvm import tir from ..utils import get_layout_transform_fn +from typing import * +from inspect import currentframe, getframeinfo + + +def get_src_loc() -> str: + info = getframeinfo(currentframe().f_back) + return f"{info[0]}:{info[1]}" + + +import tvm +from tvm.support import dump_for_debug + def validate_out_shape(out_shape, in_shape, kernel, stride, dilation): """Validate output shape""" @@ -68,14 +80,23 @@ def max_pool2d_compute(A, out_shape, kernel, stride, dilation): ), name="max", ) - # breakpoint() return Max def STIR_schedule_nhwc_8h2w32c2w(outs, ins, output_layout: str, input_layout: str): """Schedule for input and output layout nhwc-8h2w32c2w""" func = te.create_prim_func([ins, outs]) + print() + print("-" * 80) + print(get_src_loc(), ":") + print(dump_for_debug(func, "func", io_tensors=[ins, outs])) + s = tir.Schedule(func) + print() + print("-" * 80) + print(get_src_loc(), ":") + print(dump_for_debug(s, "s", io_tensors=[ins, outs])) + Max = s.get_block("max") input_transform_fn = get_layout_transform_fn(input_layout) @@ -83,7 +104,10 @@ def STIR_schedule_nhwc_8h2w32c2w(outs, ins, output_layout: str, input_layout: st s.transform_layout(Max, ("read", 0), input_transform_fn) s.transform_layout(Max, ("write", 0), output_transform_fn) - # TODO + print() + print("-" * 80) + print(get_src_loc(), ":") + print(dump_for_debug(s, "s", io_tensors=[ins, outs])) ## Schedule 'Avg' # n, h, w, c = s.get_loops(Avg) diff --git a/tests/python/contrib/test_hexagon/topi/test_max_pool2d_slice.py b/tests/python/contrib/test_hexagon/topi/test_max_pool2d_slice.py index 5cb8feadd3f46..ccd3fedda1a97 100644 --- a/tests/python/contrib/test_hexagon/topi/test_max_pool2d_slice.py +++ b/tests/python/contrib/test_hexagon/topi/test_max_pool2d_slice.py @@ -28,6 +28,7 @@ from ..infrastructure import allocate_hexagon_array, transform_numpy from ..pytest_util import ( get_numpy_dtype_info, + get_test_id, get_multitest_ids, create_populated_numpy_ndarray, TensorContentConstant, @@ -38,11 +39,6 @@ ) -input_layout = tvm.testing.parameter( - "nhwc-8h2w32c2w-2d", -) - - @tvm.testing.fixture def input_np(input_shape, dtype: str, input_tensor_populator): return create_populated_numpy_ndarray(input_shape, dtype, input_tensor_populator) @@ -144,7 +140,7 @@ class TestmaxPool2dSlice: False, True, "nhwc-8h2w32c2w-2d", - "float16", + "floGat16", TensorContentRandom(), ), # Test non-zero padding @@ -249,6 +245,10 @@ class TestmaxPool2dSlice: _param_ids = get_multitest_ids(_multitest_params, _param_descs) + input_layout = tvm.testing.parameter( + "nhwc-8h2w32c2w-2d", + ) + # NOTE: input_layout is always assumed to be "nhwc-8h2w32c2w-2d" ( output_shape, @@ -344,6 +344,8 @@ def test_max_pool2d_slice( dtype, dilation, padding, + ceil_mode, # only needed for manually obtaining the test id string + input_tensor_populator, # only needed for manually obtaining the test id string count_include_pad, input_layout, output_layout, @@ -399,6 +401,23 @@ def test_max_pool2d_slice( mem_scope="global.vtcm", ) + current_test_id = get_test_id( + output_shape, + kernel, + stride, + dilation, + padding, + ceil_mode, + count_include_pad, + output_layout, + dtype, + input_tensor_populator, + test_param_descs=self._param_descs, + ) + so_filename = f"/tmpx/max_pool2d-{current_test_id}.so" + print("Saving .so to file: '{so_filename}'") + func.save(so_filename) + # breakpoint() mod = hexagon_session.load_module(func) mod(input_arr, output_arr)