From 7c1a83bb6145c80286fcf6e38a80d1932dd9d589 Mon Sep 17 00:00:00 2001 From: abhikran Date: Fri, 24 Jun 2022 12:33:58 +0530 Subject: [PATCH] Fix review comments. --- python/tvm/topi/hexagon/slice_ops/batch_flatten.py | 4 ++-- python/tvm/topi/hexagon/utils.py | 8 ++++---- tests/python/contrib/test_hexagon/infrastructure.py | 2 +- .../contrib/test_hexagon/topi/test_batch_flatten.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py index 07230296412eb..b4fce141e0ec2 100644 --- a/python/tvm/topi/hexagon/slice_ops/batch_flatten.py +++ b/python/tvm/topi/hexagon/slice_ops/batch_flatten.py @@ -43,8 +43,8 @@ def batch_flatten_compute(inp: te.Tensor) -> te.Tensor: def batch_flatten_stir_schedule( out: te.Tensor, inp: te.Tensor, - out_layout: typing.Callable, - in_layout: typing.Callable, + out_layout: str, + in_layout: str, ) -> tir.Schedule: """STIR schedule definition for the compute of batch flatten compute. Parameters diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index 1ceeb186ab87b..68b6c76e11a02 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -44,8 +44,8 @@ def nhwc_1024c_1d(n, h, w, c): return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024] -def nc_1024_1d(n, c): - """Return index map for nc_1024 1d layout""" +def nc_1024_2d(n, c): + """Return index map for nc_1024 2d layout""" return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024] @@ -61,6 +61,6 @@ def get_layout_transform_fn(layout): return n11c_1024c_1d if layout == "nhwc-1024c-1d": return nhwc_1024c_1d - if layout == "nc-1d": - return nc_1024_1d + if layout == "nc-1024-2d": + return nc_1024_2d raise RuntimeError(f"Unexpected layout '{layout}'") diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index 5d031871509be..1f2c180aec675 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -245,7 +245,7 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str): n, h, w, c = arr_np.shape assert h == 1 and w == 1, "The size of h and w must be 1" return arr_np.reshape([n, 1, 1, c // 1024, 1024]) - if new_layout == "nc-1d": + if new_layout == "nc-1024-2d": N, C = arr_np.shape return arr_np.reshape([N, C // 1024, 1024]) if new_layout == "nhwc-1024c-1d": diff --git a/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py index cd7a9ec515914..58032e6f4c5a0 100644 --- a/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py +++ b/tests/python/contrib/test_hexagon/topi/test_batch_flatten.py @@ -69,7 +69,7 @@ def test_batch_flatten( input_layout, ) func_name = "batch_flatten" - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}): + with tvm.transform.PassContext(opt_level=3): runtime_module = tvm.build(tir_s.mod, target=target, name=func_name) mod = hexagon_session.load_module(runtime_module) @@ -98,4 +98,4 @@ def test_batch_flatten( if __name__ == "__main__": - sys.exit(pytest.main(sys.argv)) + tvm.testing.main(pytest.main(sys.argv))