[Topi][Hexagon] Implement Cast F32ToF16 and F16ToF32 Slice Op (apache…
arangasa authored and Mikael Sevenier committed Jul 26, 2022
1 parent 0456870 commit daf081f
Showing 5 changed files with 374 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -23,5 +23,11 @@
from .batch_flatten import batch_flatten_compute, batch_flatten_stir_schedule
from .softmax_slice import *
from .clip import *
from .cast import (
cast_f16_f32_compute,
cast_f16_f32_schedule,
cast_f32_f16_compute,
cast_f32_f16_schedule,
)
from .conv2d import *
from .reshape import reshape_compute, reshape_stir_schedule
143 changes: 143 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/cast.py
@@ -0,0 +1,143 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Hexagon slice cast op compute and schedule"""

from tvm import te
from tvm import tir
from ..utils import get_layout_transform_fn


def get_layout_transform_for_f32(f32_layout_string):
"""
Given f32 layout string, return transform_layout function and
channel/height split factor to be used for scheduling
"""
layout_transform_fn = get_layout_transform_fn(f32_layout_string)
if f32_layout_string == "nhwc-8h2w32c2w-2d":
return [layout_transform_fn, 8]
if f32_layout_string == "nhwc-4h2w32c2w-2d":
return [layout_transform_fn, 4]
if f32_layout_string == "nc-1024c-2d":
return [layout_transform_fn, 1024]
if f32_layout_string == "nc-512c-2d":
return [layout_transform_fn, 512]
raise RuntimeError(f"Unexpected f32_layout '{f32_layout_string}'")


def cast_f16_f32_compute(in_tensor):
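    """Compute for the f16 to f32 cast: element-wise astype to float32"""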
out_tensor = te.compute(
in_tensor.shape, lambda *indices: in_tensor[indices].astype("float32"), name="CastF16F32"
)
return out_tensor


def cast_f16_f32_stir_schedule_nhwc(func, in_layout, out_layout, h_split_factor):
"""Schedule for nhwc f16 to f32 cast: nhwc layout"""
sch = tir.Schedule(func, debug_mask="all")
block_name = "CastF16F32"
n_orig, h_orig, w_orig, c_orig = sch.get_loops(sch.get_block(block_name))
h_outer, h_inner = sch.split(h_orig, [None, h_split_factor])
w_outer, w_inner = sch.split(w_orig, [None, 4])
c_outer, c_inner = sch.split(c_orig, [None, 32])
w_inner_o, w_inner_i = sch.split(w_inner, [None, 2])
sch.reorder(n_orig, h_outer, w_outer, c_outer, h_inner, w_inner_o, c_inner, w_inner_i)
sch.transform_layout(block_name, "A", in_layout)
sch.transform_layout(block_name, block_name, out_layout)
fused = sch.fuse(c_inner, w_inner_i)
sch.vectorize(fused)
return sch


def cast_f16_f32_stir_schedule_nc(func, in_layout, out_layout, c_split_factor):
"""Schedule for nc f16 to f32 cast: nc layout"""
sch = tir.Schedule(func, debug_mask="all")
block_name = "CastF16F32"
_, c_orig = sch.get_loops(sch.get_block(block_name))
_, c_inner = sch.split(c_orig, [None, c_split_factor])
sch.transform_layout(block_name, "A", in_layout)
sch.transform_layout(block_name, block_name, out_layout)
sch.vectorize(c_inner)
return sch


def cast_f16_f32_schedule(cast_func, in_layout_str, out_layout_str):
"""Schedule for f16 to f32 cast: top level function"""
f32_layout_transform_func, split_factor = get_layout_transform_for_f32(out_layout_str)
f16_layout_transform_func = get_layout_transform_fn(in_layout_str)
if in_layout_str == "nhwc-8h2w32c2w-2d":
return cast_f16_f32_stir_schedule_nhwc(
cast_func,
f16_layout_transform_func,
f32_layout_transform_func,
split_factor,
)
if in_layout_str == "nc-1024c-2d":
return cast_f16_f32_stir_schedule_nc(
cast_func, f16_layout_transform_func, f32_layout_transform_func, split_factor
)
    raise RuntimeError(f"Unexpected input_layout, output_layout '{in_layout_str, out_layout_str}'")


def cast_f32_f16_compute(in_tensor):
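    """Compute for the f32 to f16 cast: element-wise astype to float16"""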
out_tensor = te.compute(
in_tensor.shape, lambda *indices: in_tensor[indices].astype("float16"), name="CastF32F16"
)
return out_tensor


def cast_f32_f16_stir_schedule_nhwc(func, in_layout, out_layout, h_split_factor):
"""Schedule for nhwc f32 to f16 cast: nhwc layout"""
sch = tir.Schedule(func, debug_mask="all")
block_name = "CastF32F16"
n_orig, h_orig, w_orig, c_orig = sch.get_loops(sch.get_block(block_name))
h_outer, h_inner = sch.split(h_orig, [None, h_split_factor])
w_outer, w_inner = sch.split(w_orig, [None, 4])
c_outer, c_inner = sch.split(c_orig, [None, 32])
w_inner_o, w_inner_i = sch.split(w_inner, [None, 2])
sch.reorder(n_orig, h_outer, w_outer, c_outer, h_inner, w_inner_o, c_inner, w_inner_i)
sch.transform_layout(block_name, "A", in_layout)
sch.transform_layout(block_name, block_name, out_layout)
fused = sch.fuse(c_inner, w_inner_i)
sch.vectorize(fused)
return sch


def cast_f32_f16_stir_schedule_nc(func, in_layout, out_layout, c_split_factor):
"""Schedule for nc f32 to f16 cast: nc layout"""
sch = tir.Schedule(func, debug_mask="all")
block_name = "CastF32F16"
_, c_orig = sch.get_loops(sch.get_block(block_name))
_, c_inner = sch.split(c_orig, [None, c_split_factor])
sch.transform_layout(block_name, "A", in_layout)
sch.transform_layout(block_name, block_name, out_layout)
sch.vectorize(c_inner)
return sch


def cast_f32_f16_schedule(cast_func, in_layout_str, out_layout_str):
"""Schedule for f32 to f16 cast: top level function"""
f32_layout_transform_func, split_factor = get_layout_transform_for_f32(in_layout_str)
f16_layout_transform_func = get_layout_transform_fn(out_layout_str)
if out_layout_str == "nhwc-8h2w32c2w-2d":
return cast_f32_f16_stir_schedule_nhwc(
cast_func, f32_layout_transform_func, f16_layout_transform_func, split_factor
)
if out_layout_str == "nc-1024c-2d":
return cast_f32_f16_stir_schedule_nc(
cast_func, f32_layout_transform_func, f16_layout_transform_func, split_factor
)
raise RuntimeError(f"Unexpected input_layout, output_layout '{in_layout_str, out_layout_str}'")
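
For orientation, here is a minimal usage sketch of the nc variant of the new op (not part of the commit; the tensor size and the nc-1024c-2d input / nc-512c-2d output pairing are assumptions chosen so both layout strings are accepted by the schedule functions):

from tvm import te
from tvm.topi.hexagon.slice_ops import cast_f16_f32_compute, cast_f16_f32_schedule

# 2048 f16 elements per batch row, an assumed size that divides evenly by the split factors
a = te.placeholder((1, 2048), dtype="float16", name="A")
m = cast_f16_f32_compute(a)            # te.compute that widens f16 to f32
func = te.create_prim_func([a, m])     # lower the TE dataflow to a TIR PrimFunc
# input laid out as nc-1024c-2d, output as nc-512c-2d (assumed pairing)
sch = cast_f16_f32_schedule(func, "nc-1024c-2d", "nc-512c-2d")
print(sch.mod.script())                # inspect the scheduled TIR
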
14 changes: 14 additions & 0 deletions python/tvm/topi/hexagon/utils.py
@@ -67,6 +67,16 @@ def nc_512c_2d(n, c):
return [n, c // 512, te.AXIS_SEPARATOR, c % 512]


def nc_1024c_2d(n, c):
"""Return index map for nc_1024c 2d layout"""
return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024]


def nhwc_4h2w32c2w_2d(n, h, w, c):
"""Return index map for nhwc_4h2w32c2w 2d layout"""
return [n, h // 4, w // 4, c // 32, te.AXIS_SEPARATOR, h % 4, (w % 4) // 2, c % 32, w % 2]


def nhwc_1024c_2d(n, h, w, c):
"""Return index map for nhwc_1024 2d layout"""
return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]
@@ -113,6 +123,10 @@ def get_layout_transform_fn(layout):
return nc_512c_2d
if layout == "nc-512c-1d":
return nc_512c_1d
if layout == "nhwc-4h2w32c2w-2d":
return nhwc_4h2w32c2w_2d
if layout == "nc-1024c-2d":
return nc_1024c_2d
if layout == "iohw-16i32o2i-1d":
return iohw_16i32o2i_1d
raise RuntimeError(f"Unexpected layout '{layout}'")
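
To make the new index maps concrete, a small illustrative check (not from the commit; the sample indices are arbitrary):

from tvm.topi.hexagon.utils import nc_1024c_2d, nhwc_4h2w32c2w_2d

# logical index (n=0, c=1500) maps to [0, 1, <axis separator>, 476]
# because 1500 // 1024 == 1 and 1500 % 1024 == 476
print(nc_1024c_2d(0, 1500))

# logical index (n=0, h=5, w=7, c=40) maps to [0, 1, 1, 1, <axis separator>, 1, 1, 8, 1]
# since h//4 == 1, w//4 == 1, c//32 == 1, h%4 == 1, (w%4)//2 == 1, c%32 == 8, w%2 == 1
print(nhwc_4h2w32c2w_2d(0, 5, 7, 40))
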
12 changes: 12 additions & 0 deletions tests/python/contrib/test_hexagon/infrastructure.py
@@ -241,6 +241,11 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
return arr_np.reshape([n, h // 8, 8, w // 4, 2, 2, c // 32, 32]).transpose(
0, 1, 3, 6, 2, 4, 7, 5
)
if new_layout in ["nhwc-4h2w32c2w-2d"]:
n, h, w, c = arr_np.shape
return arr_np.reshape([n, h // 4, 4, w // 4, 2, 2, c // 32, 32]).transpose(
0, 1, 3, 6, 2, 4, 7, 5
)
if new_layout in ["n11c-1024c-2d", "n11c-1024c-1d"]:
n, h, w, c = arr_np.shape
assert h == 1 and w == 1, "The size of h and w must be 1"
@@ -251,7 +256,14 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
if new_layout == "nhwc-1024c-2d":
N, H, W, C = arr_np.shape
return arr_np.reshape([N, H, W, C // 1024, 1024])
raise RuntimeError(f"Unexpected new_layout '{new_layout}'")

if current_layout == "nc":
n, c = arr_np.shape
if new_layout in ["nc-1024c-2d"]:
return arr_np.reshape([n, c // 1024, 1024])
if new_layout in ["nc-512c-2d"]:
return arr_np.reshape([n, c // 512, 512])
raise RuntimeError(f"Unexpected new_layout '{new_layout}'")

if current_layout == "nhw":
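
To make the new transform_numpy branches concrete, a hedged sketch follows (not part of the commit; the array sizes are chosen only for illustration):

import numpy as np

# nhwc-4h2w32c2w-2d: blocks of 4 in h, 4 in w (itself split 2 x 2), and 32 in c,
# using exactly the reshape/transpose added above
n, h, w, c = 1, 8, 8, 32
a = np.arange(n * h * w * c, dtype="float32").reshape(n, h, w, c)
packed = a.reshape([n, h // 4, 4, w // 4, 2, 2, c // 32, 32]).transpose(0, 1, 3, 6, 2, 4, 7, 5)
assert packed.shape == (1, 2, 2, 1, 4, 2, 32, 2)

# nc-1024c-2d: the "nc" branch just blocks the channel axis into chunks of 1024
b = np.zeros((1, 2048), dtype="float16")
assert b.reshape([1, 2048 // 1024, 1024]).shape == (1, 2, 1024)
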
