diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 885c23f49186..e64426aca3db 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -59,6 +59,22 @@ enum DivMode { kFloorDiv }; +/*! + * \brief The strength used in top-level condition proves + * \note The higher, the more time consuming it can be. + * + * Do not use level beyond kDefault in internal recursive rewriting in arith + * analysis and only use it at top-level simplification to avoid speed issues. + */ +enum class ProofStrength : int { + /*! \brief default strength, can be used in. */ + kDefault = 0, + /*! + * \brief Prove using symbolic bound analysis + */ + kSymbolicBound = 1 +}; + /*! * \brief Constant integer up and lower bound(inclusive). * Useful for value bound analysis. @@ -656,11 +672,16 @@ class TVM_DLL Analyzer { * \brief Whether can we prove condition. * * \param cond The expression to be proved. + * \param strength the strength of the prove. + * * \return The result. * * \note Analyzer will call into sub-analyzers to get the result. + * Do not use strength beyond default in sub-analyzers and + * only use it in top-level predicate analysis. */ - bool CanProve(const PrimExpr& cond); + bool CanProve(const PrimExpr& cond, ProofStrength strength = ProofStrength::kDefault); + /*! * \brief Simplify expr. * diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 60d7c53d28e8..f09564d050ca 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -85,6 +85,22 @@ class IntSet : public ObjectRef { bool IsEverything() const; /*! \return Whether the set is a single point */ bool IsSinglePoint() const; + /*! + * \brief Check if we can prove it is a single point. + * + * Unlike IsSinglePoint, which only checks ptr equality + * this function will invoke analyzer to do stonger proofs + * but also takes longer time. + * + * Use this function in some of the primitives but do not + * use it in the inner loop of simplification. + * + * \param ana Analyzer used in the proof. + * \return Whether we can prove it is a single point + */ + bool CanProveSinglePoint(Analyzer* ana) const; + // TODO(tvm-team): update all CanProve to explicitly take + // analyzer to encourage more analyzer reuse /*! \return Whether the set is proved to be bigger than 0 */ bool CanProvePositive() const; /*! \return Whether the set is proved to be smaller than 0 */ diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 423aafe5d69f..401836aa1968 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -23,7 +23,7 @@ estimate_region_strict_bound, estimate_region_upper_bound, ) -from .analyzer import ModularSet, ConstIntBound, Analyzer +from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr from .int_solver import solve_linear_equations, solve_linear_inequalities diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 28adbe9d815f..5ea2dfad9dc6 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -15,11 +15,19 @@ # specific language governing permissions and limitations # under the License. """Arithmetic data structure and utility""" +from enum import IntEnum import tvm._ffi from tvm.runtime import Object from . import _ffi_api +class ProofStrength(IntEnum): + """Proof strength of the analysis""" + + DEFAULT = 0 + SYMBOLIC_BOUND = 1 + + @tvm._ffi.register_object("arith.ModularSet") class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z""" @@ -91,6 +99,7 @@ def __init__(self): self._int_set = _mod("int_set") self._enter_constraint_context = _mod("enter_constraint_context") self._can_prove_equal = _mod("can_prove_equal") + self._can_prove = _mod("can_prove") def const_int_bound(self, expr): """Find constant integer bound for expr. @@ -190,6 +199,24 @@ def int_set(self, expr, dom_map): """ return self._int_set(expr, dom_map) + def can_prove(self, expr, strength=ProofStrength.DEFAULT): + """Check whether we can prove expr to be true. + + Parameters + ---------- + expr : PrimExpr + The expression. + + strength: ProofStrength + The proof strength + + Returns + ------- + result : Expr + The result. + """ + return self._can_prove(expr, strength) + def bind(self, var, expr): """Bind a variable to the expression. diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 4714cf1df59f..89dcb8301a1b 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -115,15 +115,47 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { return CanProve(lhs - rhs == 0); } -bool Analyzer::CanProve(const PrimExpr& expr) { +bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // Avoid potentially expensive simplification unless required. if (const auto* ptr = expr.as()) { return ptr->value != 0; } - PrimExpr simplified = Simplify(expr); const int64_t* as_int = tir::as_const_int(simplified); - return as_int && *as_int; + if (as_int && *as_int) return true; + if (strength >= ProofStrength::kSymbolicBound) { + // NOTE: we intentionally only pattern match common bound predicate i < bound + // and put this implementation at the top-level. + // This is to avoid repeatitive calling of this function + // that causes speed issues. + // This strategy can only be called from top-level and not from sub-analyzers. + Optional pos_diff; + int lower_bound = 0; + if (const auto* ptr_lt = expr.as()) { + pos_diff = ptr_lt->b - ptr_lt->a; + lower_bound = 1; + } + if (const auto* ptr_le = expr.as()) { + pos_diff = ptr_le->b - ptr_le->a; + lower_bound = 0; + } + if (const auto* ptr_gt = expr.as()) { + pos_diff = ptr_gt->a - ptr_gt->b; + lower_bound = 1; + } + if (const auto* ptr_ge = expr.as()) { + pos_diff = ptr_ge->a - ptr_ge->b; + lower_bound = 0; + } + if (pos_diff) { + IntSet iset = this->int_set(this->Simplify(pos_diff.value())); + if (iset.HasLowerBound()) { + ConstIntBound relaxed_lower_bound = this->const_int_bound(this->Simplify(iset.min())); + if (relaxed_lower_bound->min_value >= lower_bound) return true; + } + } + } + return false; } PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { @@ -189,6 +221,11 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu self->Bind(args[0], args[1].operator PrimExpr()); } }); + } else if (name == "can_prove") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + int strength = args[1]; + *ret = self->CanProve(args[0], static_cast(strength)); + }); } else if (name == "enter_constraint_context") { return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { // can't use make_shared due to noexcept(false) decl in destructor, diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index a75d316a7ece..1ad182aa8351 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -492,6 +492,11 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet VisitExpr_(const CastNode* op) final { IntervalSet value_set = this->Eval(op->value); + // short cut for the int set. + if (value_set->min_value.same_as(value_set->max_value)) { + if (value_set->IsEmpty()) return value_set; + return IntervalSet::SinglePoint(cast(op->dtype, value_set->min_value)); + } PrimExpr min_value = value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf(); PrimExpr max_value = @@ -723,6 +728,13 @@ bool IntSet::IsSinglePoint() const { return (s_int && s_int->IsSinglePoint()); } +bool IntSet::CanProveSinglePoint(Analyzer* ana) const { + const IntervalSetNode* s_int = (*this).as(); + if (!s_int) return false; + if (s_int->IsSinglePoint()) return true; + return ana->CanProveEqual(s_int->min_value, s_int->max_value); +} + bool IntSet::CanProvePositive() const { Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); @@ -943,9 +955,15 @@ IntSet EvalSet(PrimExpr e, const Map& dom_map) { } IntSet IntSet::Vector(PrimExpr x) { - Analyzer ana; - Map dmap; - return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); + // short cut: simply get single point + if (x.dtype().lanes() == 1) { + return IntSet::SinglePoint(x); + } else { + // vector case. + Analyzer ana; + Map dmap; + return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); + } } IntSet EvalSet(PrimExpr e, const Map& dom_map) { diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index 98fe5bdc2bc6..dc40fa9d4dee 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -60,12 +60,11 @@ class IntervalSetNode : public IntSetNode { bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); } /*! \return Whether the interval is a single point. */ bool IsSinglePoint() const { - if (min_value.same_as(max_value)) { - return true; - } - Analyzer analyzer; - return analyzer.CanProveEqual(min_value, max_value); + // NOTE: we are only doing cheap check as this is a frequently called routine, + // do manual prove of min and max for stronger single point check. + return min_value.same_as(max_value); } + /*! \return whether interval represent nothing */ bool IsEmpty() const { // during computations, either extreme could occur. diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 40a5977ec54c..c9acc8f751a6 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -150,6 +150,12 @@ CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const Pr // try to prove x equals val CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val) { + // NOTE on implementation: this function can be called many times and can be a bottleneck, + // As a result, we keep comparison here lightweight. + // We only do constant int bound analysis here. + // + // For stronger comparison proof that is out of the recursive simplifcation + // consider look at analyzer::CanProveStrong PrimExpr diff = this->VisitExpr(x); if (const auto* ptr = diff.as()) { if (ptr->value == val) { @@ -176,6 +182,8 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val if (dbound->max_value <= val) { return CompareResult::kLE; } + + // modular analysis if (val == 0) { ModularSet dmod = analyzer_->modular_set(diff); if (dmod->base != 0) { diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index b8e7fcdd9433..22e7a0b74c40 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -104,7 +104,6 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { // maximum number of recursion allowed during a single pass. static const constexpr int kMaxRecurDepth = 5; - /*! * \brief try to compare x against val. * \param x The expression to be evaluated. diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 409356c2b155..057cec475d84 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -76,6 +76,8 @@ class BlockReadWriteDetector : public StmtExprVisitor { Map buffer_var_map_; /*! \brief The target buffer var mapping to its matching */ std::unordered_map match_buffers_; + /*!\ brief Internal analyzer. */ + arith::Analyzer ana_; /*! * \brief Update read/write buffers and regions with provided buffer and region @@ -318,7 +320,7 @@ Array BlockReadWriteDetector::CollectRegions( ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { const tvm::arith::IntSet& range = regions[i][j]; - if (range.IsSinglePoint()) { + if (range.CanProveSinglePoint(&ana_)) { PrimExpr min = range.min(); region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1))); } else { diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 988c73c3f071..75ea308de8a3 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -455,8 +455,9 @@ void UpdateBlockVarDomainDimwise( arith::IntSet required = required_region[i]; PrimExpr dim_max = max(buffer->shape[i] - 1, 0); - if (provided.IsSinglePoint() && is_const_int(provided.min())) { - ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min())); + if (provided.CanProveSinglePoint(analyzer) && is_const_int(provided.min())) { + ICHECK(required.CanProveSinglePoint(analyzer) && + analyzer->CanProveEqual(provided.min(), required.min())); continue; } @@ -515,7 +516,7 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array& std::unordered_map* iter_doms) { // we only support single point provided region now, which could cover most cases for (const auto& intset : provided_region) { - if (!intset.IsSinglePoint()) return false; + if (!intset.CanProveSinglePoint(analyzer)) return false; } // calculate forward mapping (block vars -> provided region point) Map dom_map; diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index d9c58a038103..a26843b7bd05 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -430,7 +430,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array &opaque_block_reuse)(std::move(new_stmt)); // Step 3. Update predicate to guard the loop PrimExpr predicate = substitute_value < loop->extent; - if (!analyzer.CanProve(predicate)) { + if (!analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound)) { new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); } // Step 4. Generate nested loops to replace the original loop and simplify the binding diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py index aa9d5179aa3f..754bf36d7ab2 100644 --- a/tests/python/unittest/test_arith_simplify.py +++ b/tests/python/unittest/test_arith_simplify.py @@ -34,5 +34,36 @@ def test_simplify_reshape_flattened_index(): ) +def test_simplify_symbolic_comparison(): + ana = tvm.arith.Analyzer() + + i0 = tir.Var("i0", "int64") + i1 = tir.Var("i1", "int64") + n, m = tvm.tir.SizeVar("n", "int64"), tvm.tir.SizeVar("m", "int64") + outer = (n + 31) // 32 + ana.bind(i0, tvm.ir.Range(0, outer)) + ana.bind(i1, tvm.ir.Range(0, 32)) + PS = tvm.arith.ProofStrength + + assert not ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.DEFAULT) + assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND) + assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32 + m, PS.SYMBOLIC_BOUND) + assert ana.can_prove(i0 * 32 + i1 + 1 <= (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND) + assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1 + 1, PS.SYMBOLIC_BOUND) + assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1, PS.SYMBOLIC_BOUND) + + +def test_regression_simplify_inf_recursion(): + ana = tvm.arith.Analyzer() + cond = tir.Var("cond", "int32") + + res = (tvm.tir.NE(cond, 0).astype("int8") - tvm.tir.NE(cond, 0).astype("int8")).astype( + "int32" + ) == 0 + # regression in a previous case + # try compare and int set recursive call can cause infinite loop + ana.rewrite_simplify(res) + + if __name__ == "__main__": tvm.testing.main()