Skip to content

Commit

Permalink
Fix review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
abhikran-quic committed Jun 24, 2022
1 parent c4db7e3 commit 7c1a83b
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions python/tvm/topi/hexagon/slice_ops/batch_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def batch_flatten_compute(inp: te.Tensor) -> te.Tensor:
def batch_flatten_stir_schedule(
out: te.Tensor,
inp: te.Tensor,
out_layout: typing.Callable,
in_layout: typing.Callable,
out_layout: str,
in_layout: str,
) -> tir.Schedule:
"""STIR schedule definition for the compute of batch flatten compute.
Parameters
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/topi/hexagon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def nhwc_1024c_1d(n, h, w, c):
return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]


def nc_1024_1d(n, c):
"""Return index map for nc_1024 1d layout"""
def nc_1024_2d(n, c):
"""Return index map for nc_1024 2d layout"""
return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024]


Expand All @@ -61,6 +61,6 @@ def get_layout_transform_fn(layout):
return n11c_1024c_1d
if layout == "nhwc-1024c-1d":
return nhwc_1024c_1d
if layout == "nc-1d":
return nc_1024_1d
if layout == "nc-1024-2d":
return nc_1024_2d
raise RuntimeError(f"Unexpected layout '{layout}'")
2 changes: 1 addition & 1 deletion tests/python/contrib/test_hexagon/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
n, h, w, c = arr_np.shape
assert h == 1 and w == 1, "The size of h and w must be 1"
return arr_np.reshape([n, 1, 1, c // 1024, 1024])
if new_layout == "nc-1d":
if new_layout == "nc-1024-2d":
N, C = arr_np.shape
return arr_np.reshape([N, C // 1024, 1024])
if new_layout == "nhwc-1024c-1d":
Expand Down
4 changes: 2 additions & 2 deletions tests/python/contrib/test_hexagon/topi/test_batch_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_batch_flatten(
input_layout,
)
func_name = "batch_flatten"
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}):
with tvm.transform.PassContext(opt_level=3):
runtime_module = tvm.build(tir_s.mod, target=target, name=func_name)

mod = hexagon_session.load_module(runtime_module)
Expand Down Expand Up @@ -98,4 +98,4 @@ def test_batch_flatten(


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))
tvm.testing.main(pytest.main(sys.argv))

0 comments on commit 7c1a83b

Please sign in to comment.