
[hexagon][topi] add sliced max_pool2d #12169

Merged 1 commit on Aug 9, 2022
1 change: 1 addition & 0 deletions python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -18,6 +18,7 @@
""" Computes and Schedules for Hexagon slice ops. """

from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule
from .max_pool2d import max_pool2d_compute, max_pool2d_STIR_schedule
from .add_subtract_multiply import *
from .argmax import argmax_compute, argmax_schedule
from .batch_flatten import batch_flatten_compute, batch_flatten_stir_schedule
196 changes: 196 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/max_pool2d.py
@@ -0,0 +1,196 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument, too-many-locals

""" Compute and schedule for max_pool2d slice op

Please note the following assumptions made by the implementation:

1) The input must be padded in advance to account for 'padding'. In addition,
both input and output must be padded as per the physical buffer layout.

2) The current implementation assumes 'count_include_pad' to be 'True'. It can be
modified to support the 'False' case, but the element count for the pooling window
must be pre-computed and provided as an input to reduce the run-time overhead.

3) 'padding' is ignored. It must be handled outside of the sliced op.

4) This implementation will not work if the output includes any physical layout
related padding, as it can result in out-of-bounds accesses to the input.
"""

from tvm import te
from tvm import tir
from ..utils import get_layout_transform_fn


def validate_out_shape(out_shape, in_shape, kernel, stride, dilation):
"""Validate output shape"""
_, oh, ow, _ = out_shape
_, ih, iw, _ = in_shape
kh, kw = kernel
sh, sw = stride
dh, dw = dilation
if ih < (oh - 1) * sh + dh * (kh - 1) + 1:
raise RuntimeError("Output height is too large")
if iw < (ow - 1) * sw + dw * (kw - 1) + 1:
raise RuntimeError("Output width is too large")

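# A worked example of the check above (illustrative numbers): with input height
# ih = 16, kernel kh = 3, stride sh = 2, and dilation dh = 1, the constraint
# ih >= (oh - 1) * sh + dh * (kh - 1) + 1 allows oh <= 7; requesting oh = 8 would
# raise "Output height is too large".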

def max_pool2d_compute(A, out_shape, kernel, stride, dilation):
"""max_pool2d compute"""
kh, kw = kernel
rh = te.reduce_axis((0, kh), name="rh")
rw = te.reduce_axis((0, kw), name="rw")
ob, oh, ow, oc = out_shape
if isinstance(ob, int):
validate_out_shape(out_shape, A.shape, kernel, stride, dilation)

sh, sw = stride
dh, dw = dilation

Max = te.compute(
out_shape,
lambda b, h, w, c: te.max(
A[b, h * sh + dh * rh, w * sw + dw * rw, c].astype(A.dtype), axis=[rh, rw]
),
name="max",
)
return Max

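# A minimal usage sketch for the compute above (shapes and dtype are illustrative
# assumptions; the input is presumed to be pre-padded as described in the module
# docstring):
#
#   A = te.placeholder((1, 16, 16, 64), name="A", dtype="float16")
#   M = max_pool2d_compute(A, [1, 7, 7, 64], [3, 3], [2, 2], [1, 1])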

def STIR_schedule_nhwc_8h2w32c2w(outs, ins, output_layout: str, input_layout: str):
"""Schedule for input and output layout nhwc-8h2w32c2w"""
func = te.create_prim_func([ins, outs])
s = tir.Schedule(func)

# NOTE!!! This scheduling logic is a work in progress.
# It is not known to ultimately result in near-optimal Hexagon performance.
# The schedule below strives to implement these heuristics:
#
# (1) For mathematical operations on tensor values, prefer HVX SIMD operations
# over per-element scalar operations.
#
# (2) Minimize the number of memory transfers used to operate on tensor values:
# host-memory <--> Hexagon DDR <--> VTCM <--> HVX registers
#
# As a consequence of (1) + (2), prefer TIR schedules that load each value
# into an HVX SIMD tensor exactly once.

Max = s.get_block("max")

input_transform_fn = get_layout_transform_fn(input_layout)
output_transform_fn = get_layout_transform_fn(output_layout)

s.transform_layout(Max, ("read", 0), input_transform_fn)
s.transform_layout(Max, ("write", 0), output_transform_fn)

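# NOTE: the "nhwc-8h2w32c2w" layout shorthand is assumed here to denote a two-level
# axis mapping along the lines of (mirroring the "n11c-1024c" note in
# STIR_schedule_n11c_1024c below):
#
#   lambda n, h, w, c: [n, h // 8, w // 4, c // 32, te.AXIS_SEPARATOR,
#                       h % 8, (w % 4) // 2, c % 32, w % 2]
#
# The splits performed below (h -> 8, w -> 4 -> 2, c -> 32) mirror that structure.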
# pylint: disable=line-too-long
#
# Restructure the loop nestings to have this overall structure:
# (loop over different 128-byte output-tensor chunks) : n, ho, wo, co }- the first level of a two-level tensor layout
# (loop within one 128-byte output-tensor chunk) : hi, wio, ci, wii }- the second level of a two-level tensor layout
# (loop over reduction axes) : rh, rw }- loop over multiple elements of the input tensor
#
# Note: This schedule is a work in progress. We *expect* that it's
# crucially important for the loops to have this relative ordering:
# n ... ho ... wo ... co ... hi ... wio ... ci ... wii
# because it lets us visit each of the 128-byte output chunks precisely once.

(
n,
h,
w,
c,
rh,
rw,
) = s.get_loops(Max)

# Restructure the loops from NHWC to nhwc_8h2w32c2w, with loops for 'max's reduction
# axes at the very end.
ho, hi = s.split(h, [None, 8])
wo, wi = s.split(w, [None, 4])
wio, wii = s.split(wi, [None, 2])
co, ci = s.split(c, [None, 32])
s.reorder(n, ho, wo, co, hi, wio, ci, wii, rh, rw)

# TODO: Enable vectorization.
# Hexagon v69's HVX units support SIMD operations on 64-element float16 vectors.
#
# TVM's 'vectorize' schedule primitive is the idiomatic way to encourage lower layers of the
# compiler to generate this kind of SIMD object code.
#
# Several requirements must be met to use 'vectorize':
#
# 1) It can only be applied to a schedule's innermost loop variable.
#
# 2) Any block-iterator(s) bound to that innermost loop variable must be
# *data-parallel* block iterators.
#
# 3) Ideally, the innermost loop variable will iterate only over the output
# tensor's fastest-changing indices and nothing else. But in our case,
#    our two innermost loops correspond to the max operator's reduction axes.
#
# Finding a good way to satisfy all of these requirements at the same time is
# left for future work.

# ci_wii_rh_rw = s.fuse(ci, wii, rh, rw)
# s.vectorize(ci_wii_rh_rw)

return s


def STIR_schedule_n11c_1024c(outs, ins, output_layout: str, input_layout: str):
"""Schedule for output layout: n11c-1024c, input layout: nhwc-8h2w32c2w"""

# NOTE: This function is a variation of the STIR_schedule_nhwc_8h2w32c2w
# function. Most of that function's code comments apply to this function
# as well, but are omitted for brevity.

# NOTE: the "n11c-1024c" output layout is shorthand for this axis mapping:
# [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]
func = te.create_prim_func([ins, outs])

s = tir.Schedule(func)
Max = s.get_block("max")

input_transform_fn = get_layout_transform_fn(input_layout)
output_transform_fn = get_layout_transform_fn(output_layout)
s.transform_layout(Max, ("read", 0), input_transform_fn)
s.transform_layout(Max, ("write", 0), output_transform_fn)

(
n,
h,
w,
c,
rh,
rw,
) = s.get_loops(Max)
co, ci = s.split(c, [None, 1024])
# s.vectorize(ci)

return s


def max_pool2d_STIR_schedule(outs, ins, output_layout: str, input_layout: str):
"""STIR based schedule"""
if output_layout == "nhwc-8h2w32c2w-2d":
return STIR_schedule_nhwc_8h2w32c2w(outs, ins, output_layout, input_layout)
if output_layout == "n11c-1024c-2d":
return STIR_schedule_n11c_1024c(outs, ins, output_layout, input_layout)
raise RuntimeError(f"Unexpected layout '{output_layout}'")