From 60a6db261f5124d31679e129f0016542e4a5d36f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 25 Mar 2022 11:19:46 -0700 Subject: [PATCH] [TIR][Analysis] Add SuggestIndexMap for layout rewriting (#10732) This PR added an analysis function `SuggestIndexMap` to analyze buffer access pattern and suggest index map for layout transformations. Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou --- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/function.py | 15 ++ python/tvm/tir/schedule/__init__.py | 2 + python/tvm/tir/schedule/analysis.py | 58 +++++ src/tir/ir/index_map.cc | 2 + src/tir/schedule/analysis.h | 14 ++ src/tir/schedule/analysis/layout.cc | 212 ++++++++++++++++++ .../unittest/test_tir_schedule_analysis.py | 107 +++++++++ 8 files changed, 411 insertions(+), 1 deletion(-) create mode 100644 python/tvm/tir/schedule/analysis.py create mode 100644 src/tir/schedule/analysis/layout.cc create mode 100644 tests/python/unittest/test_tir_schedule_analysis.py diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 17f9aa3d9c604..2d201bb0dab65 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -42,7 +42,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize -from .function import PrimFunc, TensorIntrin +from .function import PrimFunc, TensorIntrin, IndexMap from .op import call_packed, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 98af3b4720305..643bbca8eebdb 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -295,3 +295,18 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None): final_indices = mapping_function(*args) return IndexMap(args, final_indices) + + def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]: + """Apply the index map to a set of indices + + Parameters + ---------- + indices : List[PriExpr] + The indices to be mapped + + Returns + ------- + result : List[PrimExpr] + The mapped indices + """ + return _ffi_api.IndexMapMapIndices(self, indices) diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 5f0e169c43e33..66ac7b9d772b1 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -22,3 +22,5 @@ from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError from .state import ScheduleDebugMask, ScheduleState from .trace import Trace + +from . import analysis diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py new file mode 100644 index 0000000000000..f2fb7c4f3d1d2 --- /dev/null +++ b/python/tvm/tir/schedule/analysis.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Analysis used in TensorIR scheduling""" +from typing import List, Optional + +from ..buffer import Buffer +from ..stmt import For +from ..expr import PrimExpr +from ..function import IndexMap + +from . import _ffi_api + + +def suggest_index_map( + buffer: Buffer, + indices: List[PrimExpr], + loops: List[For], + predicate: PrimExpr, +) -> Optional[IndexMap]: + """Provided the access pattern to a buffer, suggest one of the possible layout + transformation to maximize the locality of the access pattern. + + Parameters + ---------- + buffer : Buffer + The buffer to be transformed. + indices : List[PrimExpr] + The access pattern to the buffer. + loops : List[For] + The loops above the buffer. + predicate : PrimExpr + The predicate of the access. + + Returns + ------- + index_map : Optional[IndexMap] + The suggested index map. None if no transformation is suggested. + """ + return _ffi_api.SuggestIndexMap( # type: ignore # pylint: disable=no-member + buffer, + indices, + loops, + predicate, + ) diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index b58b602591ab6..1e54f54a066f1 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -201,5 +201,7 @@ TVM_REGISTER_GLOBAL("tir.IndexMap") return IndexMap(initial_indices, final_indices); }); +TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices").set_body_method(&IndexMapNode::MapIndices); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 4deadcf3eb52e..e74b9ea264845 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -520,6 +521,19 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops); +/*! + * \brief Provided the access pattern to a buffer, suggest one of the possible layout + * transformation to minimize the locality of the access pattern. + * \param buffer The buffer to be transformed + * \param indices The access pattern to the buffer + * \param loops The loops above the buffer + * \param predicate The predicate of the access + * \param analyzer Arithmetic analyzer + */ +Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, + const Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer); + /*! * \brief Checks if the given AST contains the specific operators * \param stmt The AST statement to be checked diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc new file mode 100644 index 0000000000000..144b3a55a4675 --- /dev/null +++ b/src/tir/schedule/analysis/layout.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Calculate the strides of the buffer + * \param buffer The buffer + * \return The strides + */ +Array GetStrides(const Buffer& buffer) { + if (!buffer->strides.empty()) { + ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); + return buffer->strides; + } + int ndim = buffer->shape.size(); + if (ndim == 0) { + return {}; + } + Array strides(ndim, PrimExpr{nullptr}); + PrimExpr stride = make_const(buffer->DefaultIndexType(), 1); + for (int i = ndim - 1; i >= 0; --i) { + strides.Set(i, stride); + stride = stride * buffer->shape[i]; + } + return strides; +} + +/*! + * \brief Auxiliary class that collects the IterSplitExpr in the indexing pattern + * to help decision making in layout transformation + */ +class SplitExprCollector { + public: + /*! + * \brief The corresponding IterSplitExpr, simplified for our case + * The pattern is `source // lower_factor % extent * scale` + */ + struct SplitExpr { + /*! \brief The source variable */ + Var source; + /*! \brief The lower factor of the split expression */ + int64_t lower_factor; + /*! \brief The extent of the split expression */ + int64_t extent; + }; + + /*! + * \brief Collect the split expressions in the indexing pattern + * \param index The indexing pattern + * \param input_iters The input iterators' domain + * \param predicate The predicate of the affine map + * \param require_bijective Whether the affine map is required to be bijective + * \param analyzer The analyzer + * \return The collected split expressions + */ + static std::vector Collect(const PrimExpr& index, + const Map& input_iters, // + const PrimExpr& predicate, // + bool require_bijective, // + arith::Analyzer* analyzer) { + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); + Array iter_sum_exprs = arith::DetectIterMap( + {analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer, diag_ctx); + if (iter_sum_exprs.empty()) { + return {}; + } + ICHECK_EQ(iter_sum_exprs.size(), 1); + if (iter_sum_exprs[0]->args.size() == 0) { + return {}; + } + SplitExprCollector collector; + collector.Visit(iter_sum_exprs[0]); + if (collector.failed_) { + return {}; + } + return std::move(collector.exprs_); + } + + private: + void Visit(const arith::IterSplitExpr& expr) { + if (const auto* var = expr->source->source.as()) { + const int64_t* lower_factor = as_const_int(expr->lower_factor); + const int64_t* extent = as_const_int(expr->extent); + if (lower_factor == nullptr || extent == nullptr) { + failed_ = true; + return; + } + exprs_.push_back(SplitExpr{GetRef(var), *lower_factor, *extent}); + } else if (const auto* iter_sum_expr = expr->source->source.as()) { + Visit(GetRef(iter_sum_expr)); + } else { + ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey(); + } + } + + void Visit(const arith::IterSumExpr& expr) { + for (const arith::IterSplitExpr& arg : expr->args) { + Visit(arg); + } + } + + /*! \brief Whether the analysis failed */ + bool failed_ = false; + /*! \brief The collected split expressions */ + std::vector exprs_; +}; + +Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, + const Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer) { + int ndim = buffer->shape.size(); + int n_loops = loops.size(); + // Step 1. Collect the domains and indices of loop variables + Map input_iters; + std::unordered_map var2id; + var2id.reserve(n_loops); + for (int i = 0; i < n_loops; ++i) { + const For& loop = loops[i]; + input_iters.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + var2id.emplace(loop->loop_var.get(), i); + } + // Step 2. Calculate a functor that flattens a multi-dimensional index + auto f_flatten_index = [ndim, strides = GetStrides(buffer), dtype = buffer->DefaultIndexType()]( + const Array& indices) -> PrimExpr { + PrimExpr flatten_index = make_const(dtype, 0); + for (int i = 0; i < ndim; ++i) { + flatten_index = flatten_index + strides[i] * indices[i]; + } + return flatten_index; + }; + // Step 3. Detect the IterSplitExpr of the indexing pattern + std::vector split_exprs = SplitExprCollector::Collect( + /*index=*/f_flatten_index(indices), input_iters, predicate, + /*require_bijective=*/false, analyzer); + if (split_exprs.empty()) { + return NullOpt; + } + // Step 4. Sort the order of the split expressions + std::vector order(split_exprs.size(), 0); + std::generate(order.begin(), order.end(), [n = 0]() mutable { return n++; }); + std::sort(order.begin(), order.end(), [&split_exprs, &var2id](int _a, int _b) -> bool { + const SplitExprCollector::SplitExpr& a = split_exprs[_a]; + const SplitExprCollector::SplitExpr& b = split_exprs[_b]; + int a_var_id = var2id.at(a.source.get()); + int b_var_id = var2id.at(b.source.get()); + if (a_var_id != b_var_id) { + return a_var_id < b_var_id; + } + return a.lower_factor > b.lower_factor; + }); + // Step 5. Create the indexing mapping + auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), // + split_exprs = std::move(split_exprs), // + order = std::move(order), // + shape = buffer->shape, // + analyzer // + ](Array indices) -> Array { + ICHECK_EQ(indices.size(), shape.size()); + for (int i = 0, n = indices.size(); i < n; ++i) { + analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i])); + } + PrimExpr index = f_flatten_index({indices.begin(), indices.end()}); + int ndim = split_exprs.size(); + // Step 5.1. Split the flattened index according to `split_exprs` + std::vector split; + split.reserve(ndim); + for (int i = ndim - 1; i >= 0; --i) { + index = analyzer->Simplify(index); + int64_t extent = split_exprs[i].extent; + split.push_back(analyzer->Simplify(floormod(index, extent))); + index = floordiv(index, extent); + } + std::reverse(split.begin(), split.end()); + // Step 5.2. Reorder the indexing pattern according to `order` + Array results; + results.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + results.push_back(split[order[i]]); + } + return results; + }; + return IndexMap::FromFunc(ndim, f_alter_layout); +} + +TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap") + .set_body_typed([](Buffer buffer, Array indices, Array loops, + PrimExpr predicate) { + arith::Analyzer analyzer; + return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); + }); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py new file mode 100644 index 0000000000000..760b412ac8045 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from typing import List + +from tvm.tir import ( + Evaluate, + For, + ForKind, + IndexMap, + Var, + decl_buffer, + floordiv, + floormod, +) +from tvm.tir.analysis import expr_deep_equal +from tvm.tir.schedule.analysis import suggest_index_map + + +def _make_vars(*args: str) -> List[Var]: + return [Var(arg, dtype="int32") for arg in args] + + +def _make_loops(loop_vars: List[Var], extents: List[int]) -> List[For]: + assert len(loop_vars) == len(extents) + return [ + For( + loop_var=loop_var, + min_val=0, + extent=extent, + kind=ForKind.SERIAL, + body=Evaluate(0), + ) + for loop_var, extent in zip(loop_vars, extents) + ] + + +def _assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: + iters_1 = map1.map_indices(map2.initial_indices) + iters_2 = map2.final_indices + assert len(iters_1) == len(iters_2) + for iter1, iter2 in zip(iters_1, iters_2): + assert expr_deep_equal(iter1, iter2) + + +def test_suggest_index_map_simple(): + i, j = _make_vars("i", "j") + index_map = suggest_index_map( + buffer=decl_buffer(shape=[8, 256]), + indices=[ + floordiv(i, 16) * 4 + floordiv(j, 16), + floormod(i, 16) * 16 + floormod(j, 16), + ], + loops=_make_loops( + loop_vars=[i, j], + extents=[32, 64], + ), + predicate=True, + ) + expected_index_map = IndexMap.from_func( + lambda x, y: [ + floordiv(x, 4), + floordiv(y, 16), + floormod(x, 4), + floormod(y, 16), + ], + ) + _assert_equal_index_map(index_map, expected_index_map) + + +def test_suggest_index_map_bijective(): + i, j = _make_vars("i", "j") + index_map = suggest_index_map( + buffer=decl_buffer(shape=[8]), + indices=[floormod(j, 4) * 2 + i], + loops=_make_loops( + loop_vars=[i, j], + extents=[2, 32], + ), + predicate=True, + ) + expected_index_map = IndexMap.from_func( + lambda x: [ + floormod(x, 2), + floordiv(x, 2), + ], + ) + _assert_equal_index_map(index_map, expected_index_map) + + +if __name__ == "__main__": + test_suggest_index_map_simple() + test_suggest_index_map_bijective()