diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 4cf6f086d1ed..2c0e5e92997a 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -259,53 +259,29 @@ class IterSumExpr : public IterMapExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode); }; +/*! \brief Mapping level for iterators. */ +enum IterMapLevel { + // Require the mapping to be bijective. + Bijective = 0, + // Require the mapping to be surjective. + Surjective = 1, + // No mapping safety check. + NoCheck = 3 +}; + /*! - * \brief Detect if indices can be written as - * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] - * - * Here y = some-quasi-affine-iter-map(input_iters) - * and c are symbolic constants. - * - * We also requires that y_i and y_j to be independent for i != j. - * - * For returned value rv, the following is always true: - * - rv[i]->args.size() <=1: only one iterator per element. - * - * \param indices The indices to detect pattern for. - * \param input_iters Map from variable to iterator's range. - * \param predicate The predicate constraints on the input iterators - * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. - * \param analyzer Analyzer used to get context information. - * \param simplify_trivial_iterators If true, iterators with extent of - * 1 will be replaced with a constant value. - * - * \return The detected pattern if a match exists, - * otherwise return an empty array. + * \brief Result of DetectIterMap. */ -Array DetectIterMap(const Array& indices, const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); +class IterMapResultNode : public Object { + public: + // The detected pattern if a match exists. + Array indices; -/*! 
\brief A utility struct for return values from DetectPaddedIterMap - */ -struct PaddedIterMapResult { // Any errors that occurred while converting the input indices. If // the array is empty, the conversion was successful. Array errors; - // The detected pattern if a match exists. - Array indices; - - /* \brief Boolean expression indicating if padding was required - * - * `requires_padding` evaluates to true if the returned indices - * contain padding relative to the provided expressions, and false - * otherwise. If `input_iters` contains a variable extent, this - * expression may be in terms of those variables. - */ - PrimExpr requires_padding; - - /* \brief Boolean expression indicating if a specific value w + /*! \brief Boolean expression indicating if a specific value w * * `padding_predicate` evaluates to true for a set of indices that * are outside the bounds of the provided index iterators, but @@ -314,43 +290,57 @@ struct PaddedIterMapResult { * `input_iters`. */ PrimExpr padding_predicate; + + // overrides + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("errors", &errors); + v->Visit("indices", &indices); + v->Visit("padding_predicate", &padding_predicate); + } + + static constexpr const char* _type_key = "arith.IterMapResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object); +}; + +/*! + * \brief Managed reference to IterMapResultNode. + * \sa IterMapResultNode + */ +class IterMapResult : public ObjectRef { + public: + // constructor + IterMapResult() { data_ = make_object(); } + + /*! \return mutable pointers to the node. */ + IterMapResultNode* operator->() const { return static_cast(get_mutable()); } }; /*! * \brief Detect if indices can be written as * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] * - * Here y = some-quasi-affine-iter-map(input_iters) and c are - * symbolic constants. The y_i iterators may be padded to fit this - * representation. + * Here y = some-quasi-affine-iter-map(input_iters) + * and c are symbolic constants. 
* * We also requires that y_i and y_j to be independent for i != j. * * For returned value rv, the following is always true: - * - rv.indices[i]->args.size() <=1: only one iterator per element. + * - rv->indices[i]->args.size() <=1: only one iterator per element. * * \param indices The indices to detect pattern for. - * * \param input_iters Map from variable to iterator's range. - * * \param predicate The predicate constraints on the input iterators - * - * \param require_bijective A boolean flag that indicates whether the - * mapping should be bijective. If true, no padding may be - * introduced. - * + * \param check_level The iter mapping checking level. * \param analyzer Analyzer used to get context information. - * * \param simplify_trivial_iterators If true, iterators with extent of * 1 will be replaced with a constant value. * - * \return An instance of PaddedIterMapResult. + * \return The detected iteration result. + * The return object's .indices is empty on failure. */ -PaddedIterMapResult DetectPaddedIterMap(const Array& indices, - const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, - bool simplify_trivial_iterators = true); +IterMapResult DetectIterMap(const Array& indices, const Map& input_iters, + const PrimExpr& predicate, IterMapLevel check_level, + arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); /*! * \brief Use IterVarMap detector to rewrite and simplify the indices @@ -358,12 +348,12 @@ PaddedIterMapResult DetectPaddedIterMap(const Array& indices, * \param indices The indices to detect pattern for. * \param input_iters Map from variable to iterator's range. * \param input_pred The predicate constraints on the input iterators - * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. + * \param check_level The iter mapping checking level. 
* * \return The indices after rewrite */ Array IterMapSimplify(const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, bool require_bijective); + const PrimExpr& input_pred, IterMapLevel check_level); /*! * \brief Apply the inverse of the affine transformation to the outputs. @@ -403,7 +393,7 @@ Map InverseAffineIterMap(const Array& iter_map, * \param input_iters Map from variable to iterator's range. * \param sub_iters Iterators of subspace. * \param predicate The predicate constraints on the input iterators - * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. + * \param check_level The iter mapping checking level. * \param analyzer Analyzer used to get context information. * * \return The result list has length len(bindings) + 1 @@ -416,7 +406,7 @@ Map InverseAffineIterMap(const Array& iter_map, Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer); + IterMapLevel check_level, arith::Analyzer* analyzer); /*! * \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr. diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 2be939a12277..77d6f418b853 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
""" Iterator (quasi)affine mapping patterns.""" +from enum import IntEnum import tvm._ffi from tvm.runtime import Object from tvm.ir import PrimExpr @@ -88,11 +89,35 @@ def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) +class IterMapLevel(IntEnum): + """Possible kinds of iter mapping check level.""" + + Bijective = 0 + Surjective = 1 + NoCheck = 3 + + @staticmethod + def from_str(name: str): + """Helper to create level enum from string""" + if name is None: + return IterMapLevel.NoCheck + name = name.lower() + if name == "bijective": + check_level = IterMapLevel.Bijective + elif name == "surjective": + check_level = IterMapLevel.Surjective + elif name == "nocheck": + check_level = IterMapLevel.NoCheck + else: + raise ValueError(f"Unknown check level {name}") + return check_level + + def detect_iter_map( indices, input_iters, predicate=True, - require_bijective=False, + check_level=IterMapLevel.Surjective, simplify_trivial_iterators=True, ): """Detect if indices can be written as mapped iters from input iters @@ -108,8 +133,8 @@ def detect_iter_map( predicate : PrimExpr The predicate constraints on the input iterators - require_bijective : bool - A boolean flag that indicates whether the mapping should be bijective + check_level : Union[str, IterMapLevel] + Checking level of iteration mapping simplify_trivial_iterators: bool If true, iterators with extent of 1 will be replaced with a @@ -117,13 +142,17 @@ def detect_iter_map( Returns ------- - results : List[IterSumExpr] + results : IterMapResult The iter map matching result. - Empty array if no match can be found. + The result's .indices is empty array if no match can be found. 
""" + if isinstance(check_level, str): + check_level = IterMapLevel.from_str(check_level) + elif check_level is None: + check_level = IterMapLevel.NoCheck return _ffi_api.DetectIterMap( - indices, input_iters, predicate, require_bijective, simplify_trivial_iterators + indices, input_iters, predicate, check_level, simplify_trivial_iterators ) @@ -143,7 +172,9 @@ def normalize_iter_map_to_expr(expr): return _ffi_api.NormalizeIterMapToExpr(expr) -def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bijective=False): +def subspace_divide( + bindings, input_iters, sub_iters, predicate=True, check_level=IterMapLevel.Surjective +): """Detect if bindings can be written as [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n] where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters) @@ -172,8 +203,8 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi predicate : PrimExpr The predicate constraints on the input iterators - require_bijective : bool - A boolean flag that indicates whether the bindings should be bijective + check_level : Union[str, IterMapLevel] + Checking level of iteration mapping Returns ------- @@ -185,7 +216,9 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi len(bindings): the predicate of outer space and inner space Empty array if no match can be found. 
""" - return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective) + if isinstance(check_level, str): + check_level = IterMapLevel.from_str(check_level) + return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, check_level) def inverse_affine_iter_map(iter_map, outputs): diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index a3fa879afa27..48fae479b042 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -867,9 +867,10 @@ Optional> EstimateRegionLowerBound(const Array& region, for (const Range& range : region) { affine_indices.push_back(range->min); } - iter_sum_exprs = DetectIterMap( + auto res = DetectIterMap( /*indices=*/affine_indices, /*input_iters=*/var_dom, - /*predicate=*/predicate, /*require_bijective=*/false, analyzer); + /*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer); + iter_sum_exprs = res->indices; } if (iter_sum_exprs.empty()) { return NullOpt; diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 9fad3b2816a1..cce826fedca6 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -178,10 +178,7 @@ class IterMapRewriter : public ExprMutator { explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, bool simplify_trivial_iterators, Array* errors) - : analyzer_(analyzer), - errors_(*errors), - requires_padding_(const_false()), - padding_predicate_(const_false()) { + : analyzer_(analyzer), errors_(*errors), padding_predicate_(const_false()) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; @@ -202,16 +199,17 @@ class IterMapRewriter : public ExprMutator { } PrimExpr padding_predicate() const { return padding_predicate_; } - PrimExpr requires_padding() const { return requires_padding_; } + bool requires_padding() const { return requires_padding_; } IterSumExpr Rewrite(const PrimExpr& expr) { return 
NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); } - void UpdatePadding(const PrimExpr& expr) { + IterSumExpr RewriteAndUpdatePadding(const PrimExpr& expr) { update_iterator_padding_ = true; - DirectMutate(expr); + auto res = Rewrite(expr); update_iterator_padding_ = false; + return res; } IterSumExpr RewriteIterConstraint(const PrimExpr& expr, @@ -222,7 +220,7 @@ class IterMapRewriter : public ExprMutator { } /*! - * \brief If require_bijective is true, this function checks two conditions: + * \brief If require bijective mapping, this function checks two conditions: * - C0: Each iter mark should be fully covered by non-overlapping splits. * - C1: All of the input iterators are used. * Example: given x in [0, 8) y in [0, 6) @@ -232,7 +230,7 @@ class IterMapRewriter : public ExprMutator { * contribute two non-overlapping splits that covers x. * - bindings = [x / 4, x % 4] won't pass because y is not used. * - * If require_bijective is false, this function checks one condition: + * If only require surjective mapping, this function checks one condition: * - C0: Each iter mark has a chance to be fully covered by non-overlapping splits. * Example: given x in [0, 8) y in [0, 6) * - bindings = [x / 4] will pass because x / 4 can be one split of x @@ -241,7 +239,7 @@ class IterMapRewriter : public ExprMutator { * - bindings = [x / 3] will not pass because x / 3 can not be one split of x * \return whether the bindings are valid */ - bool CheckMapping(const Array& bindings, bool require_bijective) { + bool CheckMapping(const Array& bindings, IterMapLevel check_level) { IterMarkSplitCollector collector; // We can check that for each iter mark: // All the splits that refers to the iter_mark covers its extent. 
@@ -249,11 +247,11 @@ class IterMapRewriter : public ExprMutator { collector.Collect(bindings); for (const IterMark& mark : collector.visited_) { - if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) { + if (TryNormalizeSplits(mark, collector.mark2splits_[mark], check_level).empty()) { return false; } } - if (require_bijective) { + if (check_level == IterMapLevel::Bijective) { // all input marks must be visited for (const IterMark& mark : input_marks_) { if (collector.visited_.count(mark) == 0 && !is_one(mark->extent)) { @@ -375,13 +373,14 @@ class IterMapRewriter : public ExprMutator { }; struct IterPaddingInfo { - // Used and collected during first pass - std::vector divisors; + // Approximate LCM of the padding factors collected during first pass + PrimExpr padding_factor{1}; + + PrimExpr left_pad{0}; + PrimExpr right_pad{0}; - // Defined on first encounter in second pass - IterSplitExpr padded; - PrimExpr left_pad; - PrimExpr right_pad; + // Padded form of original iter mark + IterMark padded; }; // temp hash for de-duplication purposes. @@ -427,41 +426,30 @@ class IterMapRewriter : public ExprMutator { // input iter marks std::vector input_marks_; - // Map from a normal PrimExpr to the padded iterator information for + // Map from an iter mark to the padded iterator information for // it. This is necessary for introducing the same padding in all // usage of an input iterator. (e.g. (i-1) occurring in the // expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be // left-padded by 31 for each occurrence.) - std::unordered_map padded_iter_map_; + std::unordered_map padded_iter_map_; + + // Map from padded iter mark to its origin mark + std::unordered_map padded_origin_map_; - /* If allow_padding_ is true, allow the extents of the IterMap to be + /* If update_iterator_padding_ is true, allow the extents of the IterMap to be * padded beyond the original iterators. 
* - * For example, if allow_padding_ is true, the expressions i//4 and + * For example, if update_iterator_padding_ is true, the expressions i//4 and * i%4, where i is on the range [0,18), would be represented as * IterSplit(i, lower_factor=4, extent=5) and IterSplit(i, extent=4). - * This representation would be forbidden if allow_padding_ is false, + * This representation would be forbidden if update_iterator_padding_ is false, * because lower_factor=4 does not evenly divide the original extent of * 18. */ bool update_iterator_padding_{false}; - /* A boolean expression that is true if any padding has been introduced - * by the transformation, and false otherwise. - * - * Example: [i//4, i%4], i in range [0,16) - * requires_padding_ will be false - * - * Example: [i//4, i%4], i in range [0,18) - * requires_padding_ will be true - * - * Example: [i//4, i%4], i in range [0,N) - * requires_padding_ will be the expression N%4==0 - */ - PrimExpr requires_padding_; - /* A boolean expression that is true for any padding that has been - * introduced, and false otherwise. If allow_padding_ is false, + * introduced, and false otherwise. If update_iterator_padding_ is false, * padding_predicate_ will always be false. * * Example: [i//4, i%4], i in range [0,16) @@ -475,6 +463,11 @@ class IterMapRewriter : public ExprMutator { */ PrimExpr padding_predicate_; + /* A boolean flag denotes there are padding iterations detected + * in the first round of indices rewriting. + */ + bool requires_padding_{false}; + // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly // an extra offset) // Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) @@ -538,13 +531,12 @@ class IterMapRewriter : public ExprMutator { * If not, return an empty array. * \param mark The iterator of interest. * \param splits The splits to be verified. - * \param require_bijective A boolean flag that indicates whether the bindings should be - * bijective. 
+ * \param check_level Iteration mapping's check level. * \return The normalized splits. */ Array TryNormalizeSplits(const IterMark& mark, const std::vector& splits, - bool require_bijective) { + IterMapLevel check_level) { std::vector used(splits.size(), false); std::vector iters; PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); @@ -559,7 +551,7 @@ class IterMapRewriter : public ExprMutator { } if (j == splits.size()) { // we do not allow incomplete split if the bindings should be bijective - if (require_bijective) { + if (check_level == IterMapLevel::Bijective) { return Array(); } // look for the next split skipping this lower factor @@ -578,18 +570,64 @@ class IterMapRewriter : public ExprMutator { expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; } + // Extract iteration mark info before padding + auto pad_mark_it = padded_origin_map_.find(mark); + bool has_padding = pad_mark_it != padded_origin_map_.end(); + + bool match_full_iter = analyzer_->CanProveEqual(expected_lower_factor, mark->extent); + bool match_iter_divisor = + match_full_iter || CanProveDivisible(mark->extent, expected_lower_factor); + // Case 1. bijective is required. - // We check the extent we calculate is consistent with the extent of the mark - // Case 2. bijective is not required. + // We check the extent we calculate is consistent with the extent of the mark and + // iteration mark's padding is not allowed. + // + // Case 2. bijective is not required and there is no padding. // We check the extent we calculate is a factor of the extent of the mark // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not. - if (require_bijective) { - if (!analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) { - return Array(); + // + // Case 3. bijective is not required and there exists padding. 
We check either + // (3.1) The extent we calculate is consistent with the extent of the padded mark and it is + // the single split for the iter mark. + // For example, padded iter p in [0, 24), [(p / 12)] is valid because it is surjective + // according to how we pad the original iteration mark. + // (3.2) The extent we calculate is a factor of the extent of the padded mark, and the extent + // before padding is greater or equal than the extent we calculate. + // For example, the original extent is 14, [(p % 12)] is valid, with p padded to 24. + // + if (check_level == IterMapLevel::Bijective) { + if (has_padding) { + ErrorLogger(this) << "Bijectvie mapping should not take iter paddings"; + return {}; + } else if (!match_full_iter) { + ErrorLogger(this) << "The iterations do not traverse full iter space"; + return {}; } - } else { - if (!CanProveDivisible(mark->extent, expected_lower_factor)) { - return Array(); + } else if (!has_padding) { + if (!match_iter_divisor) { + ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent"; + return {}; + } + } else if (check_level == IterMapLevel::Surjective) { + PrimExpr extent_before_padding = pad_mark_it->second->extent; + if (match_full_iter) { + if (splits.size() != 1) { + ErrorLogger(this) << "Dependent iterations on padding iter space"; + return Array(); + } else if (analyzer_->CanProveEqual(splits[0]->extent, expected_lower_factor) && + !analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { + ErrorLogger(this) << "Split on padding iteration is not surjective " + << "if the split extent equals to the full iter space extent"; + return Array(); + } + } else if (match_iter_divisor) { + if (!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { + ErrorLogger(this) << "The extent before padding is less than lower factor"; + return Array(); + } + } else { + ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent"; + return {}; } } 
return Array(iters.rbegin(), iters.rend()); @@ -1018,39 +1056,23 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { return true; } -Array DetectIterMap(const Array& indices, const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, bool simplify_trivial_iterators) { - auto padded_result = DetectPaddedIterMap(indices, input_iters, predicate, require_bijective, - analyzer, simplify_trivial_iterators); - if (padded_result.errors.size()) { - return Array(); - } - if (!analyzer->CanProve(!padded_result.requires_padding)) { - return Array(); - } - return padded_result.indices; -} - -PaddedIterMapResult DetectPaddedIterMap(const Array& indices, - const Map& input_iters, - const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer, - bool simplify_trivial_iterators) { - PaddedIterMapResult result; +IterMapResult DetectIterMap(const Array& indices, const Map& input_iters, + const PrimExpr& predicate, IterMapLevel check_level, + arith::Analyzer* analyzer, bool simplify_trivial_iterators) { + IterMapResult result; // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. if (!IterRangeSanityCheck(input_iters)) { - result.errors.push_back("Invalid iterators. Iterators may not be expressions of each other."); + result->errors.push_back("Invalid iterators. 
Iterators may not be expressions of each other."); return result; } Map constrained_input_iters = input_iters; std::vector constraints; if (!is_one(predicate) && !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) { - result.errors.push_back("Could not parse predicate as constraints on the input iterators."); + result->errors.push_back("Could not parse predicate as constraints on the input iterators."); return result; } // We have to make sure when we visit an iterator, all the constraints related with its successors @@ -1065,58 +1087,65 @@ PaddedIterMapResult DetectPaddedIterMap(const Array& indices, [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; }); IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators, - &result.errors); + &result->errors); // Step0.0: rewrite constraints in the order from size-small ones to size-big ones for (const IterConstraint& constraint : constraints) { auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, constraint.upper_bound); - if (result.errors.size()) { + if (result->errors.size() > 0) { return result; } } if (!rewriter.CheckConstraints()) { - result.errors.push_back("Invalid constraints."); + result->errors.push_back("Invalid constraints."); return result; } - // Step0.1: Check each index to determine required padding - bool allow_padding = !require_bijective; + // Step0.1: Rewrite indices and determine required padding, + // if there is no padding, it should be the final result. 
+ Array rewrite_indices; + rewrite_indices.reserve(indices.size()); + bool allow_padding = check_level != IterMapLevel::Bijective; if (allow_padding) { for (PrimExpr value : indices) { - rewriter.UpdatePadding(value); + rewrite_indices.push_back(rewriter.RewriteAndUpdatePadding(value)); + if (result->errors.size() > 0) { + return result; + } } } - // Step0.2: rewrite indices - for (PrimExpr value : indices) { - result.indices.push_back(rewriter.Rewrite(value)); - if (result.errors.size()) { - return result; + // Step0.2: Rewrite indices in the second round. + if (!allow_padding || rewriter.requires_padding()) { + rewrite_indices.clear(); + for (PrimExpr value : indices) { + rewrite_indices.push_back(rewriter.Rewrite(value)); + if (result->errors.size() > 0) { + return result; + } } } - - result.requires_padding = rewriter.requires_padding(); - result.padding_predicate = rewriter.padding_predicate(); + result->padding_predicate = rewriter.padding_predicate(); // Step1: IterIndependenceChecker checks if the iterator are independent. 
- if (!rewriter.CheckMapping(result.indices, require_bijective)) { - if (require_bijective) { - result.errors.push_back("Index mapping does not form a bijective transform."); + if (!rewriter.CheckMapping(rewrite_indices, check_level)) { + if (check_level == IterMapLevel::Bijective) { + result->errors.push_back("Index mapping does not form a bijective transform."); } else { - result.errors.push_back("Mapped indices are not independent."); + result->errors.push_back("Mapped indices are not independent."); } return result; } - + result->indices = rewrite_indices; return result; } TVM_REGISTER_GLOBAL("arith.DetectIterMap") .set_body_typed([](const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, bool is_bijective, + const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { arith::Analyzer ana; - return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana, + return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, simplify_trivial_iterators); }); @@ -1246,15 +1275,17 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o auto split = Downcast(dividend); return IterSumExpr({split}, make_zero(split.dtype())); } else if (dividend->IsInstance()) { - auto opt_fused = TryFuseIters(Downcast(dividend)); + auto sum = Downcast(dividend); + if (sum->args.size() <= 1) { + return sum; + } + auto opt_fused = TryFuseIters(sum); if (!opt_fused) { ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend) << ", can't be written as a single fused IterSum"; return IterSumExpr(); } - IterSumExpr fused = opt_fused.value(); - ICHECK_EQ(fused->args.size(), 1U); return fused; } else { @@ -1263,140 +1294,159 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o } } +/*! \brief Find approximate least common multiplier. 
*/ +PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyzer* analyzer) { + auto fsplit = [](const PrimExpr& e) -> std::pair { + if (const IntImmNode* imm = e.as()) { + return {1, imm->value}; + } + PVar pv; + PVar pc; + if ((pv * pc).Match(e) || (pc * pv).Match(e)) { + return {pv.Eval(), pc.Eval()->value}; + } else { + return {e, 1}; + } + }; + auto p1 = fsplit(a); + auto p2 = fsplit(b); + auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second)); + if (analyzer->CanProveEqual(p1.first, p2.first)) { + return p1.first * const_lcm; + } else if (analyzer->CanProveEqual(floormod(p1.first, p2.first), 0)) { + return p1.first * const_lcm; + } else if (analyzer->CanProveEqual(floormod(p2.first, p1.first), 0)) { + return p2.first * const_lcm; + } else { + return (p1.first * p2.first) * const_lcm; + } +} + std::pair IterMapRewriter::PadDividendToDivisor(IterSplitExpr split, PrimExpr base, PrimExpr divisor) { // If FloorDiv: (((source//lower_factor) % extent) + base) // divisor // If FloorMod: (((source//lower_factor) % extent) + base) % divisor - PrimExpr lookup_key = split; - - auto modified_divisor = [&]() { - if (update_iterator_padding_) { - return divisor; - } - - auto it = padded_iter_map_.find(lookup_key); - if (it == padded_iter_map_.end()) { - return divisor; - } - - const std::vector& divisors = it->second.divisors; - PrimExpr largest_divisor = divisor; - for (const auto& other : divisors) { - if (CanProveDivisible(other, largest_divisor)) { - // New one is bigger, use it - largest_divisor = other; - } else if (CanProveDivisible(largest_divisor, other)) { - // Current is bigger, keep it - } else { - ErrorLogger(this) << "Iterator appears in multiple terms with incompatible divisors " - << tvm::PrettyPrint(largest_divisor) << " and " - << tvm::PrettyPrint(other); - } - } - return largest_divisor; - }(); - - divisor = modified_divisor; - // First, adding any padding that is on the lower side of a - // FloorDiv/FloorMod, such that 
floormod(iter-left_pad,divisor) == 0 - // when iter==0. - - PrimExpr left_pad; - - if (is_zero(base)) { - // Padding on the left is unnecessary if base is known to be zero. - left_pad = make_zero(base->dtype); - } else { - left_pad = analyzer_->Simplify(floormod(base, divisor)); - } + // FloorDiv/FloorMod, such that floormod(split - left_pad, divisor) == 0 + // when iter == 0. + PrimExpr left_pad = analyzer_->Simplify(floormod(base, divisor)); // Next, adding any padding that is on the upper side of a - // FloorDiv/FloorMod, such that floormod(left_pad + iter + right_pad, divisor) == 0 - // when iter==extent. - + // FloorDiv/FloorMod, such that floormod(left_pad + split + right_pad, divisor) == 0 + // when iter == extent. PrimExpr right_edge = left_pad + split->extent; PrimExpr right_pad; - if (CanProveDivisible(right_edge, divisor)) { - // Padding on the right is unnecessary if the extent is a multiple of - // the divisor. right_pad = 0; } else { right_pad = analyzer_->Simplify(floormod(-right_edge, divisor)); } - if (is_zero(left_pad) && is_zero(right_pad)) { - return {split, left_pad}; - } - + const IterMark& mark = split->source; if (update_iterator_padding_) { // In the first pass, the primary goal is to collect all the divisors - // that may be used for padding. These will impact the divisor used - // to determine padding in the second pass. - IterPaddingInfo& info = padded_iter_map_[lookup_key]; - - info.divisors.push_back(divisor); - - PrimExpr padded_extent = left_pad + split->extent + right_pad; - - IterSumExpr as_sum({split}, left_pad); - IterMark mark(as_sum, padded_extent); - IterSplitExpr new_split(mark); - - return {new_split, left_pad}; + // that may be used for padding. These will impact the divisor used + // to determine padding in the second pass. We try add padding to + // split's source iteraton mark thus all splits under the same mark will + // share the same padded source iteration. 
+ auto& info = padded_iter_map_[mark]; + info.padding_factor = + ApproxLeastCommonMultiple(info.padding_factor, divisor * split->lower_factor, analyzer_); + + // If the split itself require no padding, return directly. + if (is_zero(left_pad) && is_zero(right_pad)) { + return {split, 0}; + } + + // Update padding requirement on the lower side of the source iter mark. + // In the second pass, all splits would check whether the maximum left pading + // on the iter mark is compatible with it's own left padding. + requires_padding_ = true; + PrimExpr mark_left_pad = left_pad * split->lower_factor; + info.left_pad = max(info.left_pad, mark_left_pad); + + // Since we only care the extent in the first pass's result + // we just create result of compatible padded extent, ignoring + // possible relations between different padded iters. + PrimExpr padded_extent = analyzer_->Simplify(left_pad + split->extent + right_pad); + split.CopyOnWrite()->extent = padded_extent; + return {split, left_pad}; } - // Any padding that is required during parsing should have been found - // during the first pass that determines the GCD. - auto it = padded_iter_map_.find(lookup_key); + // In the second pass, update iteration mark's to padded form + auto it = padded_iter_map_.find(mark); if (it == padded_iter_map_.end()) { - ErrorLogger(this) << "Dividend has extent " << tvm::PrettyPrint(split->extent) << " and offset " - << tvm::PrettyPrint(base) << ", which requires padding for divisor " - << tvm::PrettyPrint(divisor) << "."; - return {IterSplitExpr(), left_pad}; + return {split, left_pad}; } - IterPaddingInfo& info = it->second; - - if (info.padded.defined()) { - // A previous visit already applied padding to this iterator. - // (e.g. Visiting `(i+1)//4`, then visiting `(i+1)%4`). 
- ICHECK(analyzer_->CanProveEqual(info.left_pad, left_pad)); - ICHECK(analyzer_->CanProveEqual(info.right_pad, right_pad)); - - return {info.padded, left_pad}; + auto& info = it->second; + if (is_zero(info.left_pad) && CanProveDivisible(mark->extent, info.padding_factor)) { + // the iter mark requires no padding + return {split, left_pad}; } - // This is the first encounter with the iterator during the second pass. - IterSumExpr as_sum({split}, left_pad); - IterMark mark(as_sum, left_pad + split->extent + right_pad); - info.padded = IterSplitExpr(mark); - info.left_pad = left_pad; - info.right_pad = right_pad; - - auto left_padding_introduced = (left_pad != 0); - // Equivalent to (0 <= split < left_pad), but easier to simplify in - // terms of the transformed variables. - auto left_padding_predicate = - left_padding_introduced && (floordiv(info.padded, divisor) == floordiv(base, divisor) && - floormod(info.padded, divisor) < left_pad); - - PrimExpr nparts = ceildiv(right_edge, divisor); - - auto right_padding_introduced = (right_pad != 0); - - // Equivalent to (right_edge <= split < right_edge+right_pad), but - // easier to simplify in terms of the transformed variables. 
- auto right_padding_predicate = right_padding_introduced && - (floordiv(info.padded, divisor) == floordiv(right_edge, divisor) && - floormod(info.padded, divisor) >= floormod(right_edge, divisor)); - - requires_padding_ = requires_padding_ || (left_padding_introduced || right_padding_introduced); - padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate); + // check that padding factor is compatible with current split and divisor + ICHECK(CanProveDivisible(info.padding_factor, split->lower_factor)) + << "The padding factor " << info.padding_factor << " is not divisible by " + << split->lower_factor << " for the split " << split; + ICHECK(CanProveDivisible(info.padding_factor, divisor)) + << "The padding factor " << info.padding_factor << " is not divisible by " << divisor + << " for the split " << split; + + if (!info.padded.defined()) { + // the first time encounter the iter mark to pad, update the padded mark. + PrimExpr mark_left_pad = info.left_pad; + if (CanProveDivisible(mark_left_pad, split->lower_factor)) { + // correct current split's left padding + // (mark_left_pad + iter) // lower_factor % extent => + // (left_pad * lower_factor + mark) // lower_factor % extent => + // (left_pad + mark // lower_factor) % extent => + // left_pad + (mark // lower_factor % extent) => + // left_pad + split + // since the extent covers the full padding range. 
+ left_pad = floordiv(mark_left_pad, split->lower_factor); + } else { + ErrorLogger(this) << "Detect incompatible left padding on " + << tvm::PrettyPrint(NormalizeIterMapToExpr(split)) + << ", the iter mark is left padded with " << mark_left_pad; + return {IterSplitExpr(), PrimExpr()}; + } - return {info.padded, left_pad}; + PrimExpr right_edge = mark->extent + mark_left_pad; + PrimExpr mark_right_pad; + if (CanProveDivisible(right_edge, info.padding_factor)) { + mark_right_pad = 0; + } else { + mark_right_pad = floormod(-right_edge, info.padding_factor); + } + PrimExpr padded_extent = analyzer_->Simplify(right_edge + mark_right_pad); + info.right_pad = mark_right_pad; + info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad), padded_extent); + padded_origin_map_[info.padded] = mark; + + auto left_padding_introduced = (mark_left_pad != 0); + + // Equivalent to (0 <= split < left_pad), but easier to simplify in + // terms of the transformed variables. + auto left_padding_predicate = + left_padding_introduced && + (floordiv(info.padded->source, info.padding_factor) == 0 && + floormod(info.padded->source, info.padding_factor) < mark_left_pad); + auto right_padding_introduced = (mark_right_pad != 0); + + // Equivalent to (right_edge <= split < right_edge + right_pad), but + // easier to simplify in terms of the transformed variables. 
+ auto right_padding_predicate = + right_padding_introduced && (floordiv(info.padded->source, info.padding_factor) == + floordiv(right_edge, info.padding_factor) && + floormod(info.padded->source, info.padding_factor) >= + floormod(right_edge, info.padding_factor)); + padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate); + } + split.CopyOnWrite()->source = info.padded; + split.CopyOnWrite()->extent = floordiv(info.padded->extent, split->lower_factor); + return {split, left_pad}; } PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) { @@ -1462,7 +1512,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)), /* scale = */ padded->scale); - auto new_base = floordiv(base - left_pad, rhs); + auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6); if (is_zero(new_base)) { return std::move(new_split); } else { @@ -1540,7 +1590,6 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P // We handle scale!=1 in above code, hence we only consider floormod(x, rhs) below // where x=floormod(floordiv(iter, lower_factor), extent) + base - auto pair = PadDividendToDivisor(lhs, base, rhs); IterSplitExpr padded = pair.first; if (!padded.defined()) { @@ -1671,19 +1720,20 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr); Array IterMapSimplify(const Array& indices, const Map& input_iters, - const PrimExpr& input_pred, bool require_bijective) { + const PrimExpr& input_pred, IterMapLevel check_level) { if (!IterRangeSanityCheck(input_iters)) return indices; Analyzer analyzer; - Array rewrite = - DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer); + auto res = DetectIterMap(indices, input_iters, input_pred, check_level, &analyzer); + Array 
rewrite = res->indices; + if (rewrite.empty()) { return indices; } - Array res; - res.reserve(rewrite.size()); + Array simplified; + simplified.reserve(rewrite.size()); IterMapToExprNormalizer converter(&analyzer); - for (const auto& expr : rewrite) res.push_back(converter.Convert(expr)); - return res; + for (const auto& expr : rewrite) simplified.push_back(converter.Convert(expr)); + return simplified; } /*! @@ -1963,10 +2013,10 @@ class SubspaceDivider { Array> SubspaceDivide(const Array& bindings, const Map& input_iters, const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective, arith::Analyzer* analyzer) { + IterMapLevel check_level, arith::Analyzer* analyzer) { if (!IterRangeSanityCheck(input_iters)) return Array>(); - const Array& maps = - DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer); + auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer); + const Array& maps = res->indices; if (maps.empty()) return {}; std::unordered_set inner_iter_set; @@ -1993,10 +2043,10 @@ Array> SubspaceDivide(const Array& bindings, TVM_REGISTER_GLOBAL("arith.SubspaceDivide") .set_body_typed([](const Array& bindings, const Map& root_iters, - const Array& sub_iters, const PrimExpr& predicate, - bool require_bijective) { + const Array& sub_iters, const PrimExpr& predicate, int check_level) { arith::Analyzer ana; - return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana); + return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level), + &ana); }); class InverseAffineIterMapTransformer { @@ -2128,5 +2178,7 @@ Map InverseAffineIterMap(const Array& iter_map, TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap); +TVM_REGISTER_NODE_TYPE(IterMapResultNode); + } // namespace arith } // namespace tvm diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 7d1f315b3cb3..6abcc728fc8d 100644 --- 
a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -203,6 +203,8 @@ class PVar : public Pattern> { return value_; } + T EvalOr(const T& default_value) const { return filled_ ? value_ : default_value; } + protected: /*! \brief The matched value */ mutable T value_; diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index dab78c77a0a1..f9e38dee48e5 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -776,26 +776,32 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1) + c2, c3), floordiv(x + c1 * c2, c1 * c3), c1.Eval()->value > 0 && c3.Eval()->value > 0); - if (floordiv(x * c1, c2).Match(ret)) { + if (floordiv(x * c1 + y, c2).Match(ret) || floordiv(x * c1, c2).Match(ret) || + floordiv(y + x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - if (c1val > 0 && c2val > 0) { - if (c1val % c2val == 0) return (x * floordiv(c1, c2)).Eval(); - if (c2val % c1val == 0) return floordiv(x, floordiv(c2, c1)).Eval(); + PrimExpr yval = y.EvalOr(Integer(0)); + if (c2val == 0) return ret; + + // try eliminate residue part + PrimExpr residue = + floordiv(x.Eval() * floormod(c1.Eval(), c2val) + floormod(yval, c2val), c2val); + PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 
0 : floordiv(yval, c2val); + auto bound = analyzer_->const_int_bound(residue); + if (bound.defined() && bound->max_value == bound->min_value) { + return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value)); } - } - if (floordiv(x * c1 + c2, c3).Match(ret)) { - int64_t c1val = c1.Eval()->value; - int64_t c2val = c2.Eval()->value; - int64_t c3val = c3.Eval()->value; - if (c1val > 0 && c3val > 0 && c3val % c1val == 0 && floormod(c2val, c3val) < c1val) { - // assume c3 == a * c1, x == a * y + b, c2 = d * c3 + e then - // (x * c1 + c2) // c3 - // ==> ((a * y + b) * c1 + d * a * c1 + e) // (a * c1) - // ==> y + d + (b * c1 + e) // c3 - // ==> y + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1 - // ==> x // (c3 // c1) + (c2 // c3) - return (floordiv(x, floordiv(c3, c1)) + floordiv(c2, c3)).Eval(); + + // try simplify divisor + if (c1val > 0 && c2val > 0 && c2val % c1val == 0 && + CanProveLess(floormod(yval, c2val), c1val)) { + // assume c2 == a * c1, x == a * x' + b, y = d * c2 + e then + // (x * c1 + y) // c2 + // ==> ((a * x' + b) * c1 + d * a * c1 + e) // (a * c1) + // ==> x' + d + (b * c1 + e) // c2 + // ==> x' + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1 + // ==> x // (c2 // c1) + (y // c2) + return floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div; } } @@ -804,28 +810,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE(floordiv(c1 * x, x), c1); // Rules involving 2-operands. 
- TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x, floordiv(c2, c1)), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c2.Eval()->value % c1.Eval()->value == 0 && - CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); - TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), max(x * floordiv(c1, c2), floordiv(y, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(x, floordiv(c2, c1)), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c2.Eval()->value % c1.Eval()->value == 0 && - CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); - TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -878,6 +868,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + z * x, z), floordiv(y, z) + x, CanProveGreaterEqual(z.Eval(), 0)); + + TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c1), c1), floordiv(x, c1), c1.Eval()->value != 0); } return ret; } @@ -930,22 +922,22 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (IsIndexType(op->dtype)) { // Be-aware of the division rules: we use floordiv/floormod here - TVM_TRY_REWRITE_IF(floormod(x * c1, c2), ZeroWithTypeLike(x), - c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(y, c2), - c2.Eval()->value 
> 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2), + c2.Eval()->value != 0); TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y, c1.Eval()->value > 0 && c2.Eval()->value > 0 && c2.Eval()->value % c1.Eval()->value == 0 && CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); + TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), + c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2), + c2.Eval()->value > 0); TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 258f833a7b21..202b9209da6d 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -110,6 +110,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) { return analyzer_->CanProveGreaterEqual(x, val); } + // Whether x < val + bool CanProveLess(const PrimExpr& x, int64_t val) { return analyzer_->CanProveLess(x, val); } // Whether x == val bool CanProveEqual(const PrimExpr& x, int64_t val) { // TODO(tqchen) refer back to super-analyzer. diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index ccf186634b8a..dffb8b499285 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -75,13 +75,15 @@ inline std::vector ExprSplitAddition(const PrimExpr& expr) { } // Searches for the following types of expr: -// mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki -// mod_l_expr = c +// mult_expr = (a1 + a2 + ... 
+ aj + c1 / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki +// mod_l_expr = c2 // mod_r_expr = k1 * k2 * ... * ki -// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c) +// where c1 ~= c2 mod k1 * k2 * ... * ki +// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c1) // Currently the we will not search the add/mult combinations exhaustively // as it will take too much computation. -inline std::pair MergeMulModInner(const PrimExpr& mult_expr, +inline std::pair MergeMulModInner(arith::Analyzer* analyzer, + const PrimExpr& mult_expr, const PrimExpr& mod_l_expr, const PrimExpr& mod_r_expr) { using namespace tir; @@ -119,9 +121,10 @@ inline std::pair MergeMulModInner(const PrimExpr& mult_expr, } else if (inner_div_ptr) { PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer; if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) && - expr_equal(inner_div_ptr->a, mod_l_expr)) { + analyzer->CanProveEqual(floormod(inner_div_ptr->a - mod_l_expr, mod_r_expr), 0)) { // Found! - PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr; + PrimExpr ret = + no_opt_sum.get() ? 
no_opt_sum * mult_outer + inner_div_ptr->a : inner_div_ptr->a; return std::make_pair(true, ret); } else { return std::make_pair(false, PrimExpr()); @@ -204,7 +207,7 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { bool inner_find_opt = false; while (mult_it != mult_exprs.end()) { std::pair ret = - MergeMulModInner(*mult_it, search_mod_it->first, search_mod_it->second); + MergeMulModInner(analyzer, *mult_it, search_mod_it->first, search_mod_it->second); if (ret.first) { inner_find_opt = true; auto temp_mod_it = search_mod_it; diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 77678d829a8e..ba329676b1c3 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -76,17 +76,16 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia // Unpack the output indices into linear combinations of the initial // indices. arith::Analyzer analyzer; - auto padded_iter_map = - DetectPaddedIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, - /* require_bijective = */ false, &analyzer, - /* simplify_trivial_iterators = */ false); - CHECK(padded_iter_map.errors.empty()) << "Could not parse mapping as sum of iterators. " - << "Error: " << padded_iter_map.errors[0]; + auto padded_iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, + /*check_level=*/arith::IterMapLevel::NoCheck, &analyzer, + /*simplify_trivial_iterators=*/false); + CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. " + << "Error: " << padded_iter_map->errors[0]; // Determine expressions for the input variables, in terms of the // output variables. Map inverse_exprs_map = InverseAffineIterMap( - padded_iter_map.indices, Array(output_vars.begin(), output_vars.end())); + padded_iter_map->indices, Array(output_vars.begin(), output_vars.end())); // Unpack the map to an array, maintaining the same parameter order. 
Array inverse_exprs; @@ -94,7 +93,7 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia inverse_exprs.push_back(inverse_exprs_map.at(index)); } - PrimExpr padding_predicate = padded_iter_map.padding_predicate; + PrimExpr padding_predicate = padded_iter_map->padding_predicate; padding_predicate = arith::NormalizeIterMapToExpr(padding_predicate); padding_predicate = Substitute(padding_predicate, inverse_exprs_map); @@ -141,14 +140,14 @@ IndexMap IndexMap::Inverse(Array initial_ranges) const { // indices. arith::Analyzer analyzer; auto iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, - /* require_bijective = */ true, &analyzer, + /* check_level = */ arith::IterMapLevel::Bijective, &analyzer, /* simplify_trivial_iterators = */ false); - CHECK(iter_map.size()) << "Index transformation was not bijective."; + CHECK(iter_map->indices.size()) << "Index transformation was not bijective."; // Determine expressions for the input variables, in terms of the // output variables. - Map inverse_exprs_map = - InverseAffineIterMap(iter_map, Array(output_vars.begin(), output_vars.end())); + Map inverse_exprs_map = InverseAffineIterMap( + iter_map->indices, Array(output_vars.begin(), output_vars.end())); // Unpack the map to an array, maintaining the same parameter order. 
Array inverse_exprs; diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index c4719015daa4..83ef6adae3b2 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -533,16 +533,16 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va if (loop_var_ranges.empty()) { return true; } - Array results = arith::DetectIterMap( + auto res = arith::DetectIterMap( /*indices=*/realize->iter_values, /*input_iters=*/loop_var_ranges, /*predicate=*/realize->predicate, - /*require_bijective=*/false, + /*check_level=*/arith::IterMapLevel::Surjective, /*analyzer=*/analyzer); - if (results.empty()) { + if (res->indices.empty()) { return false; } - for (const arith::IterSumExpr& sum_expr : results) { + for (const arith::IterSumExpr& sum_expr : res->indices) { const Array& args = sum_expr->args; if (!args.empty() && !is_one(args[0]->scale)) { return false; diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index 993557f8be2f..b0cafac3151f 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/tir/schedule/analysis/layout.cc @@ -68,17 +68,18 @@ class SplitExprCollector { * \param index The indexing pattern * \param input_iters The input iterators' domain * \param predicate The predicate of the affine map - * \param require_bijective Whether the affine map is required to be bijective + * \param check_level The iter mapping checking level * \param analyzer The analyzer * \return The collected split expressions */ static std::vector Collect(const PrimExpr& index, const Map& input_iters, // const PrimExpr& predicate, // - bool require_bijective, // + arith::IterMapLevel check_level, // arith::Analyzer* analyzer) { - Array iter_sum_exprs = arith::DetectIterMap( - {analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer); + arith::IterMapResult res = arith::DetectIterMap({analyzer->Simplify(index)}, input_iters, + predicate, 
check_level, analyzer); + const auto& iter_sum_exprs = res->indices; if (iter_sum_exprs.empty()) { return {}; } @@ -149,7 +150,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& // Step 3. Detect the IterSplitExpr of the indexing pattern std::vector split_exprs = SplitExprCollector::Collect( /*index=*/f_flatten_index(indices), input_iters, predicate, - /*require_bijective=*/false, analyzer); + /*check_level=*/arith::IterMapLevel::Surjective, analyzer); if (split_exprs.empty()) { return NullOpt; } diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 7ed80a1c5b8f..4ede2dd90da8 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -258,10 +258,9 @@ Array> CheckSubspaceDivisible(const IRModule& mod, arith::Analyzer* analyzer) { const Block& block = block_realize->block; - Array> division = - arith::SubspaceDivide(block_realize->iter_values, collector.loop_var_domain, - collector.inner_loop_vars, block_realize->predicate, - /*require_bijective=*/false, analyzer); + Array> division = arith::SubspaceDivide( + block_realize->iter_values, collector.loop_var_domain, collector.inner_loop_vars, + block_realize->predicate, arith::IterMapLevel::Surjective, analyzer); if (division.empty()) { // If we can't do perfect subspace division, check if it is a trivial case of subspace division. 
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 2a349f8fe61e..7f1d74ac2021 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -244,7 +244,7 @@ class ScopeReconstructor : private StmtMutator { if (preserve_unit_loops || !is_one(iter_dom->extent)) { Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(32)); loop_vars.push_back(var); - loop_extents.push_back(iter_dom->extent); + loop_extents.push_back(analyzer->Simplify(iter_dom->extent)); iter_values.push_back(iter_dom->min + var); analyzer->Bind(var, Range::FromMinExtent(0, iter_dom->extent)); } else { diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 452f72e7228f..ad15e06e285a 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -552,13 +552,14 @@ class ReverseComputeInliner : public BaseInliner { } } - buffer_load_iter_map_ = arith::DetectIterMap( + auto res = arith::DetectIterMap( /*indices=*/buffer_load_indices_, /*input_iters=*/consumer_iter_doms, /*predicate=*/true, - /*require_bijective=*/true, + /*check_level=*/arith::IterMapLevel::Bijective, /*analyzer=*/&analyzer, /*simplify_trivial_iterators=*/false); + buffer_load_iter_map_ = res->indices; if (buffer_load_iter_map_.empty()) { // Failure: indices of BufferLoad are not bijective affine return false; diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 6da796fc955f..692f68a600ae 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -392,8 +392,9 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, auto iter_map = arith::DetectIterMap( /*indices=*/transformed_block_iters, /*input_iters=*/block_iter_dom, /*predicate=*/Bool(true), - /*require_bijective=*/true, 
&analyzer, /*simplify_trivial_iterators=*/true); - if (iter_map.empty()) { + /*check_level=*/arith::IterMapLevel::Bijective, &analyzer, + /*simplify_trivial_iterators=*/true); + if (iter_map->indices.empty()) { throw NotBijectiveAffineIndexMapError(self->mod, index_map); } @@ -417,7 +418,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters // in the body. - auto inverse_map = arith::InverseAffineIterMap(iter_map, new_block_vars); + auto inverse_map = arith::InverseAffineIterMap(iter_map->indices, new_block_vars); // Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant // zero. for (const auto& iter_var : block_ptr->iter_vars) { diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index dbe6a3bbc0c5..5315b139f0f6 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -115,7 +115,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, /*input_iters=*/loop_var2extent_, /*input_pred=*/op->predicate, - /*require_bijective=*/false); + /*check_level=*/arith::IterMapLevel::Surjective); if (v.same_as(op->iter_values)) { return GetRef(op); } else { diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index fe766b921806..d7bfa1c91947 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -14,9 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from xml import dom import tvm import tvm.testing -from tvm import te from tvm.tir import floormod, floordiv @@ -48,56 +48,69 @@ def convert_iter_expr(expr): return tvm.arith.normalize_iter_map_to_expr(expr) -def assert_iter_sum_pattern(sum_expr, extent, base, scale=1): - """Check the sum expr have the right pattern.""" - assert isinstance(sum_expr, tvm.arith.IterSumExpr) - if extent == 1: - assert len(sum_expr.args) == 0 - else: - assert len(sum_expr.args) == 1 - tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) - tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) - tvm.testing.assert_prim_expr_equal(sum_expr.base, base) +def assert_iter_sum_pattern( + expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True +): + keys = list(expect_dict.keys()) + res = tvm.arith.detect_iter_map( + keys, + dom_map, + predicate=predicate, + check_level=check_level, + simplify_trivial_iterators=simplify_trivial_iterators, + ) + indices = res.indices + assert len(indices) == len(keys), res.errors + print(indices) + for i, input_iter in enumerate(keys): + spec = expect_dict[input_iter] + ( + extent, + base, + ) = spec[0:2] + scale = spec[2] if len(spec) > 2 else 1 + expect_iter = spec[3] if len(spec) > 3 else None + sum_expr = indices[i] + assert isinstance(sum_expr, tvm.arith.IterSumExpr) + if extent == 1: + assert len(sum_expr.args) == 0 + else: + assert len(sum_expr.args) == 1 + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) + tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) + tvm.testing.assert_prim_expr_equal(sum_expr.base, base) + if expect_iter is not None: + if not isinstance(expect_iter, tvm.arith.IterMapExpr): + sum_expr = convert_iter_expr(sum_expr) + tvm.ir.assert_structural_equal(sum_expr, expect_iter) + + +def assert_iter_sum_failure(iters, dom_map, predicate=True, check_level="surjective"): + res = tvm.arith.detect_iter_map( + list(iters), dom_map, 
predicate=predicate, check_level=check_level + ).indices + assert len(res) == 0 def test_trivial(): - x = tvm.tir.Var("x", "int32"), 3 - y = tvm.tir.Var("y", "int32"), 4 - z = tvm.tir.Var("z", "int32"), 1 - - res = tvm.arith.detect_iter_map([x[0], y[0], 3], var_dom([x, y])) + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + dom_map = var_dom([(x, 3), (y, 4), (z, 1)]) - assert len(res) == 3 - assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 4, 0) - assert_iter_sum_pattern(res[2], 1, 3) - - res = tvm.arith.detect_iter_map([x[0], 3], var_dom([x, y])) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 1, 3) + assert_iter_sum_pattern({x: (3, 0), y: (4, 0), 3: (1, 3)}, dom_map) + assert_iter_sum_pattern({x: (3, 0), 3: (1, 3)}, dom_map) # not independent - res = tvm.arith.detect_iter_map([x[0], x[0], 3], var_dom([x, y])) - assert len(res) == 0 + assert_iter_sum_failure([x, x, 3], dom_map) - res = tvm.arith.detect_iter_map( - [x[0], y[0]], var_dom([x, y, z]), require_bijective=True, simplify_trivial_iterators=True + assert_iter_sum_pattern( + {x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=True ) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 4, 0) - - res = tvm.arith.detect_iter_map( - [x[0], y[0]], var_dom([x, y, z]), require_bijective=True, simplify_trivial_iterators=False + assert_iter_sum_pattern( + {x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=False ) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 4, 0) - - # not bijective - res = tvm.arith.detect_iter_map([x[0], z[0]], var_dom([x, y, z]), require_bijective=True) - assert len(res) == 0 + assert_iter_sum_failure([x, z], dom_map, check_level="bijective") def test_fuse(): @@ -106,42 +119,27 @@ def test_fuse(): c = 
tvm.tir.SizeVar("c", "int32") c0 = tvm.tir.SizeVar("c0", "int32") - res = tvm.arith.detect_iter_map([y * 3 + 1 + c + x], var_dom([(x, 3), (y, 4)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 12, 1 + c) + assert_iter_sum_pattern({y * 3 + 1 + c + x: (12, 1 + c)}, var_dom([(x, 3), (y, 4)])) - res = tvm.arith.detect_iter_map([ifuse([(x, 3), (y, 4)])[0]], var_dom([(x, 3), (y, 4)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 12, 0) + assert_iter_sum_pattern({ifuse([(x, 3), (y, 4)])[0]: (12, 0)}, var_dom([(x, 3), (y, 4)])) # fuse with symbolic factor - res = tvm.arith.detect_iter_map([(y + 1) * c + x], var_dom([(x, c), (y, 4)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 4 * c, c) + assert_iter_sum_pattern({(y + 1) * c + x: (4 * c, c)}, var_dom([(x, c), (y, 4)])) # duplication - res = tvm.arith.detect_iter_map([y * 3 + x, y], var_dom([(x, 3), (y, 4)])) - assert len(res) == 0 - - # duplication 2 - res = tvm.arith.detect_iter_map([y, x + 1, y], var_dom([(x, 3), (y, 4)])) - assert len(res) == 0 + assert_iter_sum_failure([y * 3 + x, y], var_dom([(x, 3), (y, 4)])) + assert_iter_sum_failure([y, x + 1, y], var_dom([(x, 3), (y, 4)])) # factor mismatch - res = tvm.arith.detect_iter_map([y * 4 + x], var_dom([(x, 3), (y, 4)])) - assert len(res) == 0 + assert_iter_sum_failure([y * 4 + x], var_dom([(x, 3), (y, 4)])) # simple stride pattern - res = tvm.arith.detect_iter_map([x * 4 + y * 2], var_dom([(x, 3), (y, 2)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 6, 0, scale=2) - tvm.ir.assert_structural_equal(convert_iter_expr(res[0]), (x * 2 + y) * 2) + assert_iter_sum_pattern({x * 4 + y * 2: (6, 0, 2, (x * 2 + y) * 2)}, var_dom([(x, 3), (y, 2)])) # simple stride pattern with symbolic - res = tvm.arith.detect_iter_map([x * 2 * c0 + y * 2], var_dom([(x, 3), (y, c0)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 3 * c0, 0, scale=2) - tvm.ir.assert_structural_equal(convert_iter_expr(res[0]), (x * c0 + y) * 2) + 
assert_iter_sum_pattern( + {x * 2 * c0 + y * 2: (3 * c0, 0, 2, (x * c0 + y) * 2)}, var_dom([(x, 3), (y, c0)]) + ) def test_split(): @@ -152,171 +150,138 @@ def test_split(): fld = tvm.tir.floordiv flm = tvm.tir.floormod - res = tvm.arith.detect_iter_map([fld(x, 3), flm(x, 3) * 2 + c1], var_dom([(x, 24)])) + assert_iter_sum_pattern({fld(x, 3): (8, 0), flm(x, 3) * 2 + c1: (3, c1, 2)}, var_dom([(x, 24)])) - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 8, 0) - assert_iter_sum_pattern(res[1], 3, c1, 2) - - res = tvm.arith.detect_iter_map([fld(x, 6), fld(flm(x, 6), 2), flm(x, 2)], var_dom([(x, 24)])) - - assert len(res) == 3 - assert_iter_sum_pattern(res[0], 4, 0) - assert_iter_sum_pattern(res[1], 3, 0) - assert_iter_sum_pattern(res[2], 2, 0) + assert_iter_sum_pattern( + {fld(x, 6): (4, 0), fld(flm(x, 6), 2): (3, 0), flm(x, 2): (2, 0)}, var_dom([(x, 24)]) + ) # simple symbolic bound # TODO(tvm-team) improve symbolic divisible check to enable # more complicated symbolic bound - res = tvm.arith.detect_iter_map([fld(x, c0), flm(x, c0)], var_dom([(x, c1 * c0)])) - - assert len(res) == 2 - assert_iter_sum_pattern(res[0], c1, 0) - assert_iter_sum_pattern(res[1], c0, 0) - - res = tvm.arith.detect_iter_map([fld(x * 2, 4), flm(x * 2, 4)], var_dom([(x, 8)])) - - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 4, 0, scale=1) - assert_iter_sum_pattern(res[1], 2, 0, scale=2) + assert_iter_sum_pattern({fld(x, c0): (c1, 0), flm(x, c0): (c0, 0)}, var_dom([(x, c1 * c0)])) - res = tvm.arith.detect_iter_map([fld(x * 2, 4) * 4 + flm(x * 2, 4)], var_dom([(x, 8)])) + assert_iter_sum_pattern({fld(x * 2, 4): (4, 0, 1), flm(x * 2, 4): (2, 0, 2)}, var_dom([(x, 8)])) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 8, 0, scale=2) + assert_iter_sum_pattern( + { + fld(x * 2, 4) * 4 + flm(x * 2, 4): (8, 0, 2), + }, + var_dom([(x, 8)]), + ) - res = tvm.arith.detect_iter_map([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)])) - assert len(res) == 0 + 
assert_iter_sum_failure([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)])) def test_compound(): - x = tvm.tir.Var("x", "int32"), 10 - y = tvm.tir.Var("y", "int32"), 9 + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") - xo, xi = isplit(x, 5) - yo, yi = isplit(y, 3) + xo, xi = isplit((x, 10), 5) + yo, yi = isplit((y, 9), 3) z = ifuse([yo, xo, yi]) - res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y])) - - assert len(res) == 2 - assert_iter_sum_pattern(res[0], 18, 0) - assert_iter_sum_pattern(res[1], 5, 0) # reconstruct the pattern manually - mx = tvm.arith.IterMark(x[0], 10) - my = tvm.arith.IterMark(y[0], 9) - + mx = tvm.arith.IterMark(x, 10) + my = tvm.arith.IterMark(y, 9) xoscale = 3 - xiscale = 1 yoscale = 6 yiscale = 1 mxo = tvm.arith.IterSplitExpr(mx, 5, 2, xoscale) - mxi = tvm.arith.IterSplitExpr(mx, 1, 5, xiscale) myo = tvm.arith.IterSplitExpr(my, 3, 3, yoscale) myi = tvm.arith.IterSplitExpr(my, 1, 3, yiscale) - mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 18) sz = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(mz, 1, 18, 1)], 0) - tvm.ir.assert_structural_equal(sz, res[0]) + assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)])) def test_predicate(): - x = tvm.tir.Var("x", "int32"), 13 - y = tvm.tir.Var("y", "int32"), 10 + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") # available contraints # upper bound only - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 128) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] <= 127) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) + assert_iter_sum_pattern( + {x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 128 + ) + + assert_iter_sum_pattern( + {x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y <= 127 + ) # 
lower bound only - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] > 5) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 124, 6) - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] >= 6) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 124, 6) + assert_iter_sum_pattern( + {x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y > 5 + ) + + assert_iter_sum_pattern( + {x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y >= 6 + ) # lower bound + upper bound - res = tvm.arith.detect_iter_map( - [x[0] * 10 + y[0]], - var_dom([x, y]), - tvm.tir.And(x[0] * 10 + y[0] > 5, x[0] * 10 + y[0] < 128), + assert_iter_sum_pattern( + {x * 10 + y: (122, 6)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.And(x * 10 + y > 5, x * 10 + y < 128), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 122, 6) - res = tvm.arith.detect_iter_map( - [x[0] * 10 + y[0]], - var_dom([x, y]), - tvm.tir.And(x[0] * 10 + y[0] >= 6, x[0] * 10 + y[0] <= 127), + + assert_iter_sum_pattern( + {x * 10 + y: (122, 6)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.And(x * 10 + y >= 6, x * 10 + y <= 127), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 122, 6) # constraint on one fused iter i = tvm.tir.Var("i", "int32") j = tvm.tir.Var("j", "int32") k = tvm.tir.Var("k", "int32") - res = tvm.arith.detect_iter_map( - [i * 8 + j * 2 + k], + assert_iter_sum_pattern( + {i * 8 + j * 2 + k: (88, 1)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9), + predicate=tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9), ) - assert_iter_sum_pattern(res[0], 88, 1) # constraint on single var - res = tvm.arith.detect_iter_map([i], var_dom([(i, 48)]), tvm.tir.all(i < 10)) - assert_iter_sum_pattern(res[0], 10, 0) + assert_iter_sum_pattern({i: (10, 0)}, var_dom([(i, 48)]), predicate=i < 10) - # iterations are subparts of constraint, invalid, case 1 - res = 
tvm.arith.detect_iter_map( + # iterations are subparts of constraint, invalid case 1 + assert_iter_sum_failure( [i, j, k], var_dom([(i, 128), (j, 128), (k, 128)]), - tvm.tir.all(i * 16384 + j * 128 + k < 100), + predicate=tvm.tir.all(i * 16384 + j * 128 + k < 100), ) - assert len(res) == 0 - # iterations are subparts of constraint, invalid, case 2 - res = tvm.arith.detect_iter_map( + # iterations are subparts of constraint, invalid case 2 + assert_iter_sum_failure( [i * 128 + j, k], var_dom([(i, 128), (j, 128), (k, 128)]), - tvm.tir.all(i * 16384 + j * 128 + k < 100), + predicate=i * 16384 + j * 128 + k < 100, ) - assert len(res) == 0 # irrelavant predicate - res = tvm.arith.detect_iter_map( - [i + j], - var_dom([(i, 1)]), - j <= 24, - ) - assert_iter_sum_pattern(res[0], 1, j) + assert_iter_sum_pattern({i + j: (1, j)}, var_dom([(i, 1)]), predicate=j <= 24) # constraint on nested fused iters - res = tvm.arith.detect_iter_map( - [i * 8 + j * 2 + k], + assert_iter_sum_pattern( + {i * 8 + j * 2 + k: (22, 3)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25), + predicate=tvm.tir.all( + 1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25 + ), ) - assert_iter_sum_pattern(res[0], 22, 3) # duplicate constraint on one fused iter - res = tvm.arith.detect_iter_map( - [i * 6 + j * 2 + k], + assert_iter_sum_pattern( + {i * 6 + j * 2 + k: (66, 2)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9), + predicate=tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9), ) - assert_iter_sum_pattern(res[0], 66, 2) # duplicate constraint on nested fused iters - res = tvm.arith.detect_iter_map( - [i * 6 + j * 2 + k], + assert_iter_sum_pattern( + {i * 6 + j * 2 + k: (15, 3)}, var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all( + predicate=tvm.tir.all( 1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, @@ 
-327,15 +292,13 @@ def test_predicate(): i * 6 + j * 2 + k < 18, ), ) - assert_iter_sum_pattern(res[0], 15, 3) # constraint on non-disjoint fused iters should fail - res = tvm.arith.detect_iter_map( + assert_iter_sum_failure( [i * 8 + j * 2 + k], var_dom([(i, 11), (j, 5), (k, 2)]), - tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j), + predicate=tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j), ) - assert len(res) == 0 # constraint on many disjoint fused iters, case 1 # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2) @@ -347,147 +310,135 @@ def test_predicate(): i3 = tvm.tir.Var("i3", "int32") i4 = tvm.tir.Var("i4", "int32") i5 = tvm.tir.Var("i5", "int32") - res = tvm.arith.detect_iter_map( - [i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5], + assert_iter_sum_pattern( + {i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5: (540, 93)}, var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]), - tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5), + predicate=tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5), ) - assert_iter_sum_pattern(res[0], 540, 93) # constraint on many disjoint fused iters, case 2 - res = tvm.arith.detect_iter_map( - [i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4], + assert_iter_sum_pattern( + {i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4: (135, 28)}, var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]), - tvm.tir.all(3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10), + predicate=tvm.tir.all( + 3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10 + ), ) - assert_iter_sum_pattern(res[0], 135, 28) # constraint on split iters - res = tvm.arith.detect_iter_map( - [i % 16, i // 16], + assert_iter_sum_pattern( + {i % 16: (7, 3), i // 16: (8, 4)}, var_dom([(i, 1024)]), - tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12), - require_bijective=True, + predicate=tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12), + check_level="bijective", ) - assert_iter_sum_pattern(res[0], 7, 3) 
- assert_iter_sum_pattern(res[1], 8, 4) # constraint on split iters, nested case 1 - res = tvm.arith.detect_iter_map( - [(i * 32 + j) % 16], + assert_iter_sum_pattern( + {(i * 32 + j) % 16: (7, 3)}, var_dom([(i, 5), (j, 32)]), - tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10), + predicate=tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10), ) - assert_iter_sum_pattern(res[0], 7, 3) # constraint on split iters, nested case 2 - res = tvm.arith.detect_iter_map( - [(i * 32 + j) % 16], + assert_iter_sum_failure( + [ + (i * 32 + j) % 16, + ], var_dom([(i, 5), (j, 32)]), - tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + check_level="bijective", ) - assert len(res) == 0 - res = tvm.arith.detect_iter_map( - [(i * 32 + j - 1) % 16, (i * 32 + j - 1) // 16], + assert_iter_sum_pattern( + {(i * 32 + j) % 16: (16, 0)}, var_dom([(i, 5), (j, 32)]), - tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64), + predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + ) + assert_iter_sum_pattern( + {(i * 32 + j - 1) % 16: (16, 0), (i * 32 + j - 1) // 16: (4, 0)}, + var_dom([(i, 5), (j, 32)]), + predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64), ) - assert_iter_sum_pattern(res[0], 16, 0) - assert_iter_sum_pattern(res[1], 4, 0) # non-standard form of predicate - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0]) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) + assert_iter_sum_pattern( + {x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 < 128 - y + ) # duplicate constraint - res = tvm.arith.detect_iter_map( - [x[0] * 10 + y[0]], - var_dom([x, y]), - tvm.tir.all(x[0] * 10 + y[0] < 128, x[0] * 10 + y[0] < 64), + assert_iter_sum_pattern( + {x * 10 + y: (64, 0)}, + var_dom([(x, 13), (y, 10)]), + predicate=tvm.tir.all(x * 10 + y < 128, x * 10 + y < 64), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 64, 0) - # useless 
constraint - res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 140) - - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 130, 0) + assert_iter_sum_pattern( + {x * 10 + y: (130, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 140 + ) - i1 = tvm.tir.Var("i1", "int32"), 7 - i2 = tvm.tir.Var("i2", "int32"), 2 - i3 = tvm.tir.Var("i3", "int32"), 4 - i4 = tvm.tir.Var("i4", "int32"), 3 - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + i1 = tvm.tir.Var("i1", "int32") + i2 = tvm.tir.Var("i2", "int32") + i3 = tvm.tir.Var("i3", "int32") + i4 = tvm.tir.Var("i4", "int32") + assert_iter_sum_pattern( + {i1 * 20 + i2 * 10 + i3 * 3 + i4: (128, 0)}, + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i3[0] * 3 + i4[0] < 10, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 10, ) ), ) - assert len(res) == 1 - assert_iter_sum_pattern(res[0], 128, 0) - - i1 = tvm.tir.Var("i1", "int32"), 7 - i2 = tvm.tir.Var("i2", "int32"), 2 - i3 = tvm.tir.Var("i3", "int32"), 4 - i4 = tvm.tir.Var("i4", "int32"), 3 # wrong constraint - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i3[0] * 3 + i4[0] < 7, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 7, ) ), ) - assert len(res) == 0 # incompatible constraint - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + 
predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i3[0] * 3 + i4[0] < 10, - i1[0] * 4 + i3[0] < 20, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i3 * 3 + i4 < 10, + i1 * 4 + i3 < 20, ) ), ) - assert len(res) == 0 - - res = tvm.arith.detect_iter_map( - [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], - var_dom([i1, i2, i3, i4]), - ( + assert_iter_sum_failure( + [i1 * 20 + i2 * 10 + i3 * 3 + i4], + var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), + predicate=( tvm.tir.all( - i1[0] * 2 + i2[0] < 13, - i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, - i1[0] * 4 + i3[0] < 20, + i1 * 2 + i2 < 13, + i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, + i1 * 4 + i3 < 20, ) ), ) - assert len(res) == 0 # zero iter - xo = tvm.tir.Var("xo", "int32"), 1 - xi = tvm.tir.Var("xi", "int32"), 129 - y = tvm.tir.Var("y", "int32"), 128 - - res = tvm.arith.detect_iter_map( - [xo[0] * 129 + xi[0], y[0]], var_dom([xo, xi, y]), xo[0] * 129 + xi[0] < 128 + xo = tvm.tir.Var("xo", "int32") + xi = tvm.tir.Var("xi", "int32") + y = tvm.tir.Var("y", "int32") + assert_iter_sum_pattern( + {xo * 129 + xi: (128, 0), y: (128, 0)}, + var_dom([(xo, 1), (xi, 129), (y, 128)]), + predicate=xo * 129 + xi < 128, ) @@ -554,9 +505,10 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])) + assert_iter_sum_pattern + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices assert len(res2) == 2 # compound 1.2 @@ -568,9 +520,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[1][0], 0) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) - res1 = 
tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices assert len(res2) == 2 # compound 1.3 @@ -589,9 +541,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7) tvm.ir.assert_structural_equal(res[2][1], True) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices assert len(res2) == 2 # compound 1.5 @@ -607,9 +559,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], True) tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices assert len(res2) == 2 # compound 1.6 @@ -644,9 +596,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices assert len(res1) == 3 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])) + res2 = tvm.arith.detect_iter_map([res[0][0], 
res[1][0], res[2][0]], var_dom([j0, l0])).indices assert len(res2) == 3 # compound 2.2 @@ -662,9 +614,11 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3])) + res1 = tvm.arith.detect_iter_map( + [res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3]) + ).indices assert len(res1) == 3 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0])).indices assert len(res2) == 3 # compound 2.3 @@ -692,9 +646,9 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7) tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])) + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices assert len(res1) == 3 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])) + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices assert len(res2) == 3 # compound 2.5 @@ -730,13 +684,6 @@ def test_complex(): i0 = ifuse([j0, j1], 200) i1 = ifuse([j2, j3], 50) - res = tvm.arith.detect_iter_map( - [i0[0], i1[0]], - var_dom([l0, l1, n0, n1, m1, l3]), - tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15), - ) - assert len(res) == 2 - n0_mark = tvm.arith.IterMark(n0[0], n0[1]) n1_mark = tvm.arith.IterMark(n1[0], n1[1]) l0_mark = tvm.arith.IterMark(l0[0], l0[1]) @@ -784,16 +731,20 @@ def test_complex(): i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0) i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0) - 
tvm.ir.assert_structural_equal(i0_final, res[0]) - tvm.ir.assert_structural_equal(i1_final, res[1]) + assert_iter_sum_pattern( + {i0[0]: (200, 0, 1, i0_final), i1[0]: (50, 0, 1, i1_final)}, + var_dom([l0, l1, n0, n1, m1, l3]), + predicate=tvm.tir.all( + i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15 + ), + ) # wrong constraint - res = tvm.arith.detect_iter_map( + assert_iter_sum_failure( [i0[0], i1[0]], var_dom([l0, l1, n0, n1, m1, l3]), tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14), ) - assert len(res) == 0 # subspace_division res = tvm.arith.subspace_divide( @@ -822,34 +773,33 @@ def test_complex(): ), ) - res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([n0, n1, m1, l3]), res[2][1]) - assert len(res1) == 2 - res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([l0, l1])) - assert len(res2) == 2 + assert_iter_sum_pattern( + {res[0][1]: (32, 0), res[1][1]: (15, 0)}, var_dom([n0, n1, m1, l3]), res[2][1] + ) + assert_iter_sum_pattern({res[0][0]: (8, 0), res[1][0]: (4, 0)}, var_dom([l0, l1])) def test_normalize_iter_map_to_expr(): fld = tvm.tir.floordiv flm = tvm.tir.floormod - x = tvm.tir.Var("x", "int32"), 10 - y = tvm.tir.Var("y", "int32"), 9 + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") - xo, xi = isplit(x, 5) - yo, yi = isplit(y, 3) + xo, xi = isplit((x, 10), 5) + yo, yi = isplit((y, 9), 3) z = ifuse([yo, xo, yi]) - - res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y])) + res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([(x, 10), (y, 9)])) tvm.ir.assert_structural_equal( - tvm.arith.normalize_iter_map_to_expr(res[0]), - fld(y[0], 3) * 6 + fld(x[0], 5) * 3 + flm(y[0], 3), + tvm.arith.normalize_iter_map_to_expr(res.indices[0]), + fld(y, 3) * 6 + fld(x, 5) * 3 + flm(y, 3), ) - tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5)) + 
tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res.indices[1]), flm(x, 5)) # iter mark wrap a complex expr - split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x[0] * y[0] + 1, 1024), 1, 1024, 1) - tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x[0] * y[0] + 1) + split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x * y + 1, 1024), 1, 1024, 1) + tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x * y + 1) def test_inverse_affine_iter_map(): @@ -863,7 +813,9 @@ def test_inverse_affine_iter_map(): l1_0, l1_1 = isplit(l1, 4) l0_1_l1_1_fused = ifuse([l0_1, l1_1]) - iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1])) + iter_map = tvm.arith.detect_iter_map( + [l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1]) + ).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 2 @@ -882,7 +834,7 @@ def test_inverse_affine_iter_map(): iter_map = tvm.arith.detect_iter_map( [l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2]) - ) + ).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 3 @@ -902,7 +854,7 @@ def test_inverse_affine_iter_map(): l1_0, l1_1 = isplit(l1, 8) l2 = ifuse([l1_1, l1_0]) - iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])) + iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 1 @@ -918,12 +870,11 @@ def test_free_variables(): z = tvm.tir.Var("z", "int32") # illegal iter if z is within dom - res = tvm.arith.detect_iter_map([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)])) - assert 
len(res) == 0 + assert_iter_sum_failure([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)])) # iter is valid if z is free, even there are linear forms of z - res = tvm.arith.detect_iter_map( - [z * 19 + y * 3 + x], + assert_iter_sum_pattern( + {z * 19 + y * 3 + x: (9, z * 19)}, var_dom( [ (x, 3), @@ -931,9 +882,8 @@ def test_free_variables(): ] ), ) - assert_iter_sum_pattern(res[0], 9, z * 19) - res = tvm.arith.detect_iter_map( - [z * z + y * 3 + x], + assert_iter_sum_pattern( + {z * z + y * 3 + x: (9, z * z)}, var_dom( [ (x, 3), @@ -941,7 +891,105 @@ def test_free_variables(): ] ), ) - assert_iter_sum_pattern(res[0], 9, z * z) + + +def test_padding(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + fld = tvm.tir.floordiv + flm = tvm.tir.floormod + + # left padding only, offset divisible + sum = 64 + y + dom_map = var_dom([(y, 192)]) + assert_iter_sum_pattern( + {fld(sum, 32): (6, 2, 1), flm(sum, 32): (32, 0, 1)}, + dom_map, + check_level="bijective", + ) + + # left padding only, offset non-divisible + sum = 80 + y + dom_map = var_dom([(y, 176)]) + assert_iter_sum_pattern( + {fld(sum, 32): (6, 2, 1)}, + dom_map, + ) + assert_iter_sum_pattern( + {flm(fld(sum, 2), 16): (16, 0, 1), flm(sum, 2): (2, 0, 1)}, + dom_map, + ) + assert_iter_sum_failure({fld(sum, 32), flm(sum, 32)}, dom_map) + assert_iter_sum_failure({fld(sum, 32), fld(sum, 4)}, dom_map) + + # right padding only, offset divisible + sum = x * 32 + y * 8 + dom_map = var_dom([(x, 5), (y, 4)]) + assert_iter_sum_pattern( + {fld(sum, 16): (10, 0, 1), flm(sum, 16): (2, 0, 8)}, + dom_map, + ) + assert_iter_sum_failure({fld(sum, 5)}, dom_map) + + # right padding only, offset non-divisible + dom_map = var_dom([(x, 26)]) + assert_iter_sum_pattern( + {fld(x, 15): (2, 0, 1)}, + dom_map, + ) + assert_iter_sum_pattern( + {flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)}, + dom_map, + ) + + # padding constants on both side + sum = x + 71 + dom_map = var_dom([(x, 45)]) + 
assert_iter_sum_pattern({fld(sum, 32): (2, 2, 1)}, dom_map) + assert_iter_sum_pattern( + {flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)}, + dom_map, + ) + + # padding for free iteration part + sum = x * 360 + y + dom_map = var_dom([(y, 360)]) + assert_iter_sum_pattern({fld(sum, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}, dom_map) + assert_iter_sum_pattern({flm(x * 360 + y, 16): (16, 0, 1)}, dom_map) + + # multiple split with same mark offset, could + # be surjective on missing (padded // LCM) + assert_iter_sum_pattern( + { + flm(x + 10, 3): (3, 0), + flm(fld(x + 10, 3), 4): (4, 0), + flm(fld(fld(x + 10, 3), 4), 5): (5, 0), + }, + var_dom([(x, 240)]), + ) + assert_iter_sum_failure( + { + flm(x + 10, 3), + flm(fld(x + 10, 3), 4), + flm(fld(fld(x + 10, 3), 4), 5), + fld(fld(fld(x + 10, 3), 4), 5), + }, + var_dom([(x, 240)]), + ) + + # different offsets on splits + assert_iter_sum_pattern( + { + flm(x + 1, 3): (3, 0), + flm(fld(x + 10, 3) + 2, 4): (4, 0), + flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0), + }, + var_dom([(x, 240)]), + ) + + # original extent is smaller than the divident + # it is not surjective wrt to the region [0, 16) + assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)])) if __name__ == "__main__": diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 8d26710f40db..82e1372f991e 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -459,11 +459,13 @@ def test_div_index_simplify(): def test_floordiv_index_simplify(): # short name for floordiv fld = tvm.te.floordiv + flm = tvm.te.floormod ck = RewriteChecker() x, y, z = te.var("x"), te.var("y"), te.var("z") ck.verify(fld(fld(x, 2), 3), fld(x, 6)) ck.verify(fld(fld(x, 2) + 1, 3), fld(x + 2, 6)) + ck.verify(fld(x - flm(x, 21), 21), fld(x, 21)) ck.verify(fld(x * 2, 4), fld(x, 2)) ck.verify(fld(x * 4, 2), x * 2) @@ -472,11 +474,17 @@ def 
test_floordiv_index_simplify(): ck.verify(fld(x * 8 - 1, 16), fld(x * 8 + -1, 16)) ck.verify(fld(x * 8 - 9, 16), fld(x, 2) + -1) + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1), override=True) + ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 7), override=True) + ck.verify(fld(x * 360 + y, 16), x * 22) + ck.verify(fld(x * 360 + y, 25), x * 14) + ck.verify(fld(x * 360 - 8, 25), fld(x * 360 + -8, 25)) + ck.verify(fld(x * 4 + y, 2), x * 2 + fld(y, 2)) ck.verify(fld(tvm.te.min(x * 6, y), 2), tvm.te.min(x * 3, fld(y, 2))) ck.verify(fld(tvm.te.max(x * 6, y), 2), tvm.te.max(x * 3, fld(y, 2))) - ck.verify(fld(y + x * 4, 2), fld(y, 2) + x * 2) + ck.verify(fld(y + x * 4, 2), x * 2 + fld(y, 2)) ck.verify(fld(tvm.te.min(y, x * 6), 2), tvm.te.min(fld(y, 2), x * 3)) ck.verify(fld(tvm.te.max(y, x * 6), 2), tvm.te.max(fld(y, 2), x * 3)) @@ -549,15 +557,17 @@ def test_mod_index_simplify(): def test_floormod_index_simplify(): # short name for floordiv flm = tvm.te.floormod - ck = RewriteChecker() x, y, z = te.var("x"), te.var("y"), te.var("z") ck = RewriteChecker() x, y, nx, ny, z = te.var("x"), te.var("y"), te.var("nx"), te.var("ny"), te.var("z") ck.verify(flm(x * 10, 2), 0) + ck.verify(flm(x * 9600, 6400), flm(x * 3200, 6400)) ck.verify(flm(x * 10 + y, 2), flm(y, 2)) + ck.verify(flm(x * 360 + y, 16), flm(x * 8 + y, 16)) ck.verify(flm(x + 10, 2), flm(x, 2)) ck.verify(flm(x + y * 10, 2), flm(x, 2)) + ck.verify(flm(x + y * 360, 16), flm(x + y * 8, 16)) ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1) ck.verify(flm(x * (-10), 2), 0) ck.verify(flm(x * (-10) + y, 2), flm(y, 2)) diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 337f9cbc0722..10e827978cc0 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -137,6 +137,7 @@ def assert_simplified_equal(index_simplified, index_direct): idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod + # Test Case1 index_simplified = A_stride.offset_of( 
(idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1) @@ -174,7 +175,7 @@ def assert_simplified_equal(index_simplified, index_direct): j = te.size_var("j") k = te.size_var("k") - index_simplified = B.offset_of( + index_simplified1 = B.offset_of( ( idxd(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), idxm(idxd(idxd((i * 50176 + j * 28672 + k), 1024), 14), 14), @@ -182,8 +183,17 @@ def assert_simplified_equal(index_simplified, index_direct): idxm((i * 50176 + j * 28672 + k), 1024), ) ) + index_simplified2 = B.offset_of( + ( + idxd(idxd(i * 49 + j * 28 + idxd(k, 1024), 14), 14), + idxm(idxd(i * 49 + j * 28 + idxd(k, 1024), 14), 14), + idxm(i * 7 + idxd(k, 1024), 14), + idxm(k, 1024), + ) + ) index_direct = B.offset_of((0, 0, 0, (i * 50176 + j * 28672 + k))) - assert_simplified_equal(index_simplified, index_direct) + assert_simplified_equal(index_simplified1, index_direct) + assert_simplified_equal(index_simplified2, index_direct) @tvm.testing.requires_llvm diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index b06dcebe1d1c..f477367adfad 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -1249,6 +1249,44 @@ def test_compute_at_simplify_static_bound(): verify_trace_roundtrip(sch=sch, mod=static_bound) +def test_compute_at_non_perfect_channel_group(): + @T.prim_func + def grouped_channel_bias( + X: T.Buffer[(720, 8, 8), "float32"], Y: T.Buffer[(720, 8, 8), "float32"] + ): + B = T.alloc_buffer([45], dtype="float32", scope="") + for i in T.grid(45): + with T.block("init"): + vi = T.axis.remap("S", [i]) + B[vi] = vi + for c_o, h, w, c_i in T.grid(2, 8, 8, 360): + with T.block("compute"): + hh, ww = T.axis.remap("SS", [h, w]) + cc = T.axis.spatial(720, c_o * 360 + c_i) + Y[cc, hh, ww] = X[cc, hh, ww] + B[cc // 16] + + @T.prim_func + def grouped_channel_bias_non_perfect_tiled( + X: T.Buffer[(720, 8, 8), 
"float32"], Y: T.Buffer[(720, 8, 8), "float32"] + ): + B = T.alloc_buffer([45], dtype="float32") + for c_o in range(2): + for ax0 in range(23): + with T.block("init"): + vi = T.axis.spatial(45, c_o * 22 + ax0) + B[vi] = vi + for h, w, c_i in T.grid(8, 8, 360): + with T.block("compute"): + hh, ww = T.axis.remap("SS", [h, w]) + cc = T.axis.spatial(720, c_o * 360 + c_i) + Y[cc, hh, ww] = X[cc, hh, ww] + B[cc // 16] + + sch = tir.Schedule(grouped_channel_bias, debug_mask="all") + loop = sch.get_loops(sch.get_block("compute"))[0] + sch.compute_at(sch.get_block("init"), loop) + tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled) + + def test_fail_subtree_complete_block(): sch = tir.Schedule(fail_subtree_compact_dataflow, debug_mask="all") block = sch.get_block("B_0")