[TOPI] [Hexagon] Reshape slice op #11983

Merged · 4 commits · Jul 7, 2022
1 change: 1 addition & 0 deletions python/tvm/topi/hexagon/slice_ops/__init__.py
100755 → 100644
@@ -24,3 +24,4 @@
from .softmax_slice import *
from .clip import *
from .conv2d import *
from .reshape import reshape_compute, reshape_stir_schedule
108 changes: 108 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/reshape.py
@@ -0,0 +1,108 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Hexagon slice reshape compute and schedule"""
from tvm import te, tir, topi
from ..utils import get_layout_transform_fn


def reshape_compute(inp: te.Tensor, new_shape: tuple) -> te.Tensor:
"""Compute for slice reshape op for hexagon.
This op makes the following assumptions:
1. This op is written for a sliced reshape operation.
2. The input is assumed to be in NHWC layout.

Parameters
----------
Input : te.Tensor
Input tensor
New Shape: tuple
Output shape
Returns
-------
Output : te.Tensor
Output of applying reshape operation on input
"""
return topi.transform.reshape(inp, new_shape)
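
# For illustration only (sample shapes, not from this PR): the compute is a
# thin wrapper over topi.transform.reshape, e.g.
#   a = te.placeholder((1, 8, 4, 64), name="A", dtype="float16")
#   o = reshape_compute(a, (1, 8, 8, 32))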


def stir_schedule_nhwc_1024c(
out: te.Tensor,
inp: te.Tensor,
out_layout: str,
in_layout: str,
) -> tir.Schedule:
"""Schedule for output layout: nhwc-1024c-2d"""
reshape_func = te.create_prim_func([inp, out])
sch = tir.Schedule(reshape_func, debug_mask="all")
compute = sch.get_block("T_reshape")

sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
i, j = sch.get_loops(compute)
jout, channel = sch.split(j, [None, inp.shape[3]])
height, width = sch.split(jout, [inp.shape[1], inp.shape[2]])
channelo, channeli = sch.split(channel, [None, 1024])
channelio, channelii = sch.split(channeli, [None, 64])
sch.reorder(i, height, width, channelo, channelio, channelii)
sch.vectorize(channelii)
return sch


def stir_schedule_nhwc_8h2w32c2w(
out: te.Tensor,
inp: te.Tensor,
out_layout: str,
in_layout: str,
) -> tir.Schedule:
"""Schedule for input and output layout nhwc-8h2w32c2w"""
    reshape_func = te.create_prim_func([inp, out])
    sch = tir.Schedule(reshape_func, debug_mask="all")
    compute = sch.get_block("T_reshape")

    # Input and output share the same tiled layout, so rewriting the two
    # buffers into that layout is all the scheduling this op needs.
    sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
    sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
    return sch


def reshape_stir_schedule(
out: te.Tensor,
inp: te.Tensor,
output_layout: str,
input_layout: str,
) -> tir.Schedule:
"""STIR schedule definition for the compute of reshape compute.
Parameters
----------
outputs : te.Tensor
The output tensor as returned by a call to reshape_compute
input : te.Tensor
Input tensor to reshape
out_layout: str
The transformation function definition for the expected output layout
in_layout: str
The transformation function definition for the input layout
Returns
-------
sch : tvm.tir.Schedule
The STIR schedule for slice reshape compute
"""
if output_layout == "nhwc-8h2w32c2w-2d":
return stir_schedule_nhwc_8h2w32c2w(out, inp, output_layout, input_layout)
if output_layout == "nc-1024-2d":
return stir_schedule_nhwc_1024c(out, inp, output_layout, input_layout)
raise RuntimeError(f"Unexpected layout '{output_layout}'")
101 changes: 0 additions & 101 deletions tests/python/contrib/test_hexagon/topi/test_batch_flatten.py

This file was deleted.

168 changes: 168 additions & 0 deletions tests/python/contrib/test_hexagon/topi/test_reshape.py
@@ -0,0 +1,168 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np
import pytest

import tvm
import tvm.testing
import tvm.topi.hexagon.slice_ops as sl
from tvm import te, topi
from tvm.contrib.hexagon.build import HexagonLauncher
from tvm.topi import testing

from ..infrastructure import allocate_hexagon_array, transform_numpy


def reshape_helper(
func,
fcompute,
fschedule,
data_type,
input_shape,
input_layout,
output_shape,
output_layout,
hexagon_session,
):
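    # Build the TE graph, apply the requested STIR schedule, compile for
    # Hexagon, and compare the device output (still in its tiled layout)
    # against numpy.reshape computed on the host.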

target_hexagon = tvm.target.hexagon("v69")
target = tvm.target.Target(target_hexagon, host=target_hexagon)
A = te.placeholder(input_shape, name="A", dtype=data_type)
if func == "reshape":
D = fcompute(A, output_shape)
elif func == "batch_flatten":
D = fcompute(A)
else:
raise RuntimeError(f"Unexpected func'{func}'")
tir_s = fschedule(
D,
A,
output_layout,
input_layout,
)
with tvm.transform.PassContext(opt_level=3):
print("output of tvm.lower", tvm.lower(tir_s.mod, name=func))
runtime_module = tvm.build(tir_s.mod, target=target, name=func)

mod = hexagon_session.load_module(runtime_module)

a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
ref = np.reshape(a_numpy, output_shape)

    # Convert the NHWC reference arrays into the tiled physical layouts the
    # kernel operates on; axis separators mark where each buffer splits into
    # a 2-D (discontiguous) VTCM allocation.
    input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
    ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
    input_axis_sep = [4]
if output_layout == "nhwc-8h2w32c2w-2d":
output_axis_sep = [4]
elif output_layout == "nc-1024-2d":
output_axis_sep = [2]
else:
raise RuntimeError(f"Unexpected layout '{output_layout}'")
a_tvm = allocate_hexagon_array(
hexagon_session.device,
data=input_np_transformed,
axis_separators=input_axis_sep,
mem_scope="global.vtcm",
)
output = allocate_hexagon_array(
hexagon_session.device,
ref_np_transformed.shape,
data_type,
axis_separators=output_axis_sep,
mem_scope="global.vtcm",
)
mod(a_tvm, output)
np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)


batch_flatten_tests = (
([1, 1, 1, 2048], [1, 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
([1, 2, 4, 2048], [1, 2 * 4 * 2048], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
([1, 8, 8, 1024], [1, 8 * 8 * 1024], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
([2, 4, 8, 1024], [2, 4 * 8 * 1024], "nhwc-1024c-2d", "nc-1024-2d", "float16"),
)


class BaseTestBatchFlatten:
(
input_shape,
output_shape,
input_layout,
output_layout,
data_type,
) = tvm.testing.parameters(*batch_flatten_tests)


class TestBatchFlatten(BaseTestBatchFlatten):
@tvm.testing.requires_hexagon
def test_batch_flatten(
self,
data_type,
input_shape,
input_layout,
output_shape,
output_layout,
hexagon_session,
):
reshape_helper(
"batch_flatten",
sl.batch_flatten_compute,
sl.batch_flatten_stir_schedule,
data_type,
input_shape,
input_layout,
output_shape,
output_layout,
hexagon_session,
)


class BaseTestReshape(BaseTestBatchFlatten):
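    # A superset of the batch-flatten cases above, plus true 4-D -> 4-D
    # reshapes that stay within the nhwc-8h2w32c2w-2d layout.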
(input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
*batch_flatten_tests,
([1, 8, 4, 64], [1, 8, 8, 32], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
([1, 16, 8, 128], [1, 16, 16, 64], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
)


class TestReshape(BaseTestReshape):
@tvm.testing.requires_hexagon
def test_reshape(
self,
data_type,
input_shape,
input_layout,
output_shape,
output_layout,
hexagon_session,
):
reshape_helper(
"reshape",
sl.reshape_compute,
sl.reshape_stir_schedule,
data_type,
input_shape,
input_layout,
output_shape,
output_layout,
hexagon_session,
)


if __name__ == "__main__":
tvm.testing.main()