[TOPI] [Hexagon] Batch flatten slice op initial version (#11522)
* [TOPI] [Hexagon] Batch flatten slice op initial version

* Fix lint errors

* Fix more lint errors

* Fix lint warnings

* Fix review comments

* Update tests to use util functions

* Update __init__.py

* Fix review comments
abhikran-quic authored Jun 30, 2022
1 parent 80a0c6c commit 915c23b
Showing 5 changed files with 199 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -20,5 +20,6 @@
from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule
from .add_subtract_multiply import *
from .argmax import argmax_compute, argmax_schedule
from .batch_flatten import batch_flatten_compute, batch_flatten_stir_schedule
from .softmax_slice import *
from .clip import *
77 changes: 77 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/batch_flatten.py
@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Hexagon slice batch flatten compute and schedule"""
from tvm import te, tir, topi
from ..utils import get_layout_transform_fn


def batch_flatten_compute(inp: te.Tensor) -> te.Tensor:
"""Compute for slice batch flatten op for hexagon.
This op makes the following assumptions:
1. This op is written for a sliced batch flatten operation.
2. The input is assumed to be in NHWC layout.
Parameters
----------
Input : te.Tensor
Input activations padded for inner dimension size
Returns
-------
Output : te.Tensor
Output of applying batch flatten operation on input
"""
    return topi.nn.flatten(inp)


def batch_flatten_stir_schedule(
    out: te.Tensor,
    inp: te.Tensor,
    out_layout: str,
    in_layout: str,
) -> tir.Schedule:
"""STIR schedule definition for the compute of batch flatten compute.
Parameters
----------
outputs : te.Tensor
The output tensor as returned by a call to batch_flatten_compute
input : te.Tensor
Input tensor to batch_flatten
out_layout: typing.Callable
The transformation function definition for the expected output layout
in_layout: typing.Callable
The transformation function definition for the input layout
Returns
-------
sch : tvm.tir.Schedule
The STIR schedule for slice batch flatten compute
"""

    batch_flatten_func = te.create_prim_func([inp, out])
    sch = tir.Schedule(batch_flatten_func, debug_mask="all")
    compute = sch.get_block("compute")

    # Transform the logical NHWC input and NC output into their physical
    # Hexagon layouts before scheduling the loop nest.
    sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
    sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
    i, j = sch.get_loops(compute)
    # Recover the h/w/c structure of the flattened axis, then split channels
    # into 1024-element chunks with a 64-element inner axis for vectorization
    # (for fp16 data, 64 lanes fill one 128-byte HVX vector).
    jout, channel = sch.split(j, [None, inp.shape[3]])
    height, width = sch.split(jout, [inp.shape[1], inp.shape[2]])
    channelo, channeli = sch.split(channel, [None, 1024])
    channelio, channelii = sch.split(channeli, [None, 64])
    sch.reorder(i, height, width, channelo, channelio, channelii)
    sch.vectorize(channelii)
    return sch
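
For reference, a minimal driver sketch for the two functions above, mirroring the test added below; the input shape, dtype, and v69 target are illustrative assumptions, not part of this file:

    import tvm
    import tvm.topi.hexagon.slice_ops as sl
    from tvm import te

    # Build the sliced batch-flatten compute and its Hexagon STIR schedule.
    inp = te.placeholder((1, 8, 8, 1024), name="A", dtype="float16")
    out = sl.batch_flatten_compute(inp)
    sch = sl.batch_flatten_stir_schedule(out, inp, "nc-1024-2d", "nhwc-1024c-2d")

    # Compile the scheduled PrimFunc for Hexagon.
    target_hexagon = tvm.target.hexagon("v69")
    target = tvm.target.Target(target_hexagon, host=target_hexagon)
    module = tvm.build(sch.mod, target=target, name="batch_flatten")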
14 changes: 14 additions & 0 deletions python/tvm/topi/hexagon/utils.py
@@ -67,6 +67,16 @@ def nc_512c_2d(n, c):
    return [n, c // 512, te.AXIS_SEPARATOR, c % 512]


def nhwc_1024c_2d(n, h, w, c):
    """Return index map for nhwc-1024c-2d layout"""
    return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]


def nc_1024_2d(n, c):
    """Return index map for nc-1024-2d layout"""
    return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024]


def get_layout_transform_fn(layout):
"""Return index map function as per the layout string"""
if layout == "nhwc-8h2w32c2w-2d":
@@ -77,6 +87,10 @@ def get_layout_transform_fn(layout):
        return n11c_1024c_2d
    if layout == "n11c-1024c-1d":
        return n11c_1024c_1d
    if layout == "nhwc-1024c-2d":
        return nhwc_1024c_2d
    if layout == "nc-1024-2d":
        return nc_1024_2d
    if layout == "nhw-32h16w-2d":
        return nhw_32h16w_2d
    if layout == "nhwc-4h4w32c-2d":
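
As a quick, hedged illustration of the two new index maps (calling them directly with plain Python ints; the values follow from 2500 // 1024 == 2 and 2500 % 1024 == 452):

    from tvm import te
    from tvm.topi.hexagon.utils import get_layout_transform_fn

    # The innermost 1024 channels become a separate physical axis;
    # te.AXIS_SEPARATOR marks where the flat buffer is split in two.
    fn = get_layout_transform_fn("nhwc-1024c-2d")
    print(fn(0, 2, 3, 2500))  # [0, 2, 3, 2, <axis separator>, 452]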
6 changes: 6 additions & 0 deletions tests/python/contrib/test_hexagon/infrastructure.py
@@ -245,6 +245,12 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
            n, h, w, c = arr_np.shape
            assert h == 1 and w == 1, "The size of h and w must be 1"
            return arr_np.reshape([n, 1, 1, c // 1024, 1024])
        if new_layout == "nc-1024-2d":
            n, c = arr_np.shape
            return arr_np.reshape([n, c // 1024, 1024])
        if new_layout == "nhwc-1024c-2d":
            n, h, w, c = arr_np.shape
            return arr_np.reshape([n, h, w, c // 1024, 1024])

    raise RuntimeError(f"Unexpected new_layout '{new_layout}'")

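
At shape level, the two new branches simply split the innermost 1024 channels off into a trailing axis; a standalone sanity sketch (not part of the test file):

    import numpy as np

    # nhwc-1024c-2d: (N, H, W, C) -> (N, H, W, C // 1024, 1024)
    a = np.zeros((1, 8, 8, 2048), dtype="float16")
    assert a.reshape([1, 8, 8, 2048 // 1024, 1024]).shape == (1, 8, 8, 2, 1024)

    # nc-1024-2d: (N, C) -> (N, C // 1024, 1024)
    b = np.zeros((2, 4096), dtype="float16")
    assert b.reshape([2, 4096 // 1024, 1024]).shape == (2, 4, 1024)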
101 changes: 101 additions & 0 deletions tests/python/contrib/test_hexagon/topi/test_batch_flatten.py
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np

import tvm
import tvm.testing
import tvm.topi.hexagon.slice_ops as sl
from tvm import te

from ..infrastructure import allocate_hexagon_array, transform_numpy


class BaseTestBatchFlatten:
    input_shape = tvm.testing.parameter(
        (1, 1, 1, 2048),
        (1, 2, 4, 2048),
        (1, 8, 8, 1024),
        (2, 4, 8, 1024),
        (2, 3, 5, 2048),
    )
    # The axis separators ([4] for input, [2] for output) match the position of
    # te.AXIS_SEPARATOR in the corresponding index maps in hexagon/utils.py.
    input_layout, input_axis_sep = tvm.testing.parameters(("nhwc-1024c-2d", [4]))
    output_layout, output_axis_sep = tvm.testing.parameters(("nc-1024-2d", [2]))
    data_type = tvm.testing.parameter("float16")


class TestBatchFlatten(BaseTestBatchFlatten):
    @tvm.testing.fixture
    def output_shape(self, input_shape):
        return input_shape[0], input_shape[1] * input_shape[2] * input_shape[3]

    @tvm.testing.requires_hexagon
    def test_batch_flatten(
        self,
        data_type,
        input_shape,
        input_layout,
        input_axis_sep,
        output_shape,
        output_layout,
        output_axis_sep,
        hexagon_session,
    ):
        target_hexagon = tvm.target.hexagon("v69")
        target = tvm.target.Target(target_hexagon, host=target_hexagon)
        A = te.placeholder(input_shape, name="A", dtype=data_type)
        D = sl.batch_flatten_compute(A)
        tir_s = sl.batch_flatten_stir_schedule(
            D,
            A,
            output_layout,
            input_layout,
        )
        func_name = "batch_flatten"
        with tvm.transform.PassContext(opt_level=3):
            runtime_module = tvm.build(tir_s.mod, target=target, name=func_name)

        mod = hexagon_session.load_module(runtime_module)

        a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
        ref = np.reshape(a_numpy, output_shape)

        # transform_numpy dispatches on new_layout inside its "nhwc" branch, so
        # it also handles the already-2-D reference for the "nc-1024-2d" layout.
        input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
        ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)

        a_tvm = allocate_hexagon_array(
            hexagon_session.device,
            data=input_np_transformed,
            axis_separators=input_axis_sep,
            mem_scope="global.vtcm",
        )
        output = allocate_hexagon_array(
            hexagon_session.device,
            ref_np_transformed.shape,
            data_type,
            axis_separators=output_axis_sep,
            mem_scope="global.vtcm",
        )
        mod(a_tvm, output)
        # Batch flatten is pure data movement, so the result must match exactly.
        np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)


if __name__ == "__main__":
    tvm.testing.main()
