Skip to content

Commit

Permalink
wip progress
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian Convey committed Jul 26, 2022
1 parent c824720 commit 3971bde
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 8 deletions.
28 changes: 26 additions & 2 deletions python/tvm/topi/hexagon/slice_ops/max_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@
from tvm import tir
from ..utils import get_layout_transform_fn

from typing import *
from inspect import currentframe, getframeinfo


def get_src_loc() -> str:
info = getframeinfo(currentframe().f_back)
return f"{info[0]}:{info[1]}"


import tvm
from tvm.support import dump_for_debug


def validate_out_shape(out_shape, in_shape, kernel, stride, dilation):
"""Validate output shape"""
Expand Down Expand Up @@ -68,22 +80,34 @@ def max_pool2d_compute(A, out_shape, kernel, stride, dilation):
),
name="max",
)
# breakpoint()
return Max


def STIR_schedule_nhwc_8h2w32c2w(outs, ins, output_layout: str, input_layout: str):
"""Schedule for input and output layout nhwc-8h2w32c2w"""
func = te.create_prim_func([ins, outs])
print()
print("-" * 80)
print(get_src_loc(), ":")
print(dump_for_debug(func, "func", io_tensors=[ins, outs]))

s = tir.Schedule(func)
print()
print("-" * 80)
print(get_src_loc(), ":")
print(dump_for_debug(s, "s", io_tensors=[ins, outs]))

Max = s.get_block("max")

input_transform_fn = get_layout_transform_fn(input_layout)
output_transform_fn = get_layout_transform_fn(output_layout)
s.transform_layout(Max, ("read", 0), input_transform_fn)
s.transform_layout(Max, ("write", 0), output_transform_fn)

# TODO
print()
print("-" * 80)
print(get_src_loc(), ":")
print(dump_for_debug(s, "s", io_tensors=[ins, outs]))

## Schedule 'Avg'
# n, h, w, c = s.get_loops(Avg)
Expand Down
31 changes: 25 additions & 6 deletions tests/python/contrib/test_hexagon/topi/test_max_pool2d_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ..infrastructure import allocate_hexagon_array, transform_numpy
from ..pytest_util import (
get_numpy_dtype_info,
get_test_id,
get_multitest_ids,
create_populated_numpy_ndarray,
TensorContentConstant,
Expand All @@ -38,11 +39,6 @@
)


input_layout = tvm.testing.parameter(
"nhwc-8h2w32c2w-2d",
)


@tvm.testing.fixture
def input_np(input_shape, dtype: str, input_tensor_populator):
return create_populated_numpy_ndarray(input_shape, dtype, input_tensor_populator)
Expand Down Expand Up @@ -144,7 +140,7 @@ class TestmaxPool2dSlice:
False,
True,
"nhwc-8h2w32c2w-2d",
"float16",
"floGat16",
TensorContentRandom(),
),
# Test non-zero padding
Expand Down Expand Up @@ -249,6 +245,10 @@ class TestmaxPool2dSlice:

_param_ids = get_multitest_ids(_multitest_params, _param_descs)

input_layout = tvm.testing.parameter(
"nhwc-8h2w32c2w-2d",
)

# NOTE: input_layout is always assumed to be "nhwc-8h2w32c2w-2d"
(
output_shape,
Expand Down Expand Up @@ -344,6 +344,8 @@ def test_max_pool2d_slice(
dtype,
dilation,
padding,
ceil_mode, # only needed for manually obtaining the test id string
input_tensor_populator, # only needed for manually obtaining the test id string
count_include_pad,
input_layout,
output_layout,
Expand Down Expand Up @@ -399,6 +401,23 @@ def test_max_pool2d_slice(
mem_scope="global.vtcm",
)

current_test_id = get_test_id(
output_shape,
kernel,
stride,
dilation,
padding,
ceil_mode,
count_include_pad,
output_layout,
dtype,
input_tensor_populator,
test_param_descs=self._param_descs,
)
so_filename = f"/tmpx/max_pool2d-{current_test_id}.so"
print("Saving .so to file: '{so_filename}'")
func.save(so_filename)

# breakpoint()
mod = hexagon_session.load_module(func)
mod(input_arr, output_arr)
Expand Down

0 comments on commit 3971bde

Please sign in to comment.