Skip to content

Commit

Permalink
[ARITH] Enhance CanProve to handle symbolic bound
Browse files Browse the repository at this point in the history
This PR enhances CanProve to handle symbolic bound.
Such analysis is essential to eliminate predicates in
dynamic shape workloads.

We also the int set analysis singlepoint check to avoid recursion
and improve the overall analysis speed.

Added CanProveSinglePoint to serve previous stronger checks.

The new CanProve comes with additinal strength argument
that can only be used in top-level setting with stronger analysis.

Added comment for future implementation efficiency.

Testcases are added to cover the cases.
  • Loading branch information
tqchen committed Apr 8, 2023
1 parent af39b34 commit 57fdeb3
Show file tree
Hide file tree
Showing 13 changed files with 178 additions and 19 deletions.
23 changes: 22 additions & 1 deletion include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ enum DivMode {
kFloorDiv
};

/*!
* \brief The strength used in top-level condition proves
* \note The higher, the more time consuming it can be.
*
* Do not use level beyond kDefault in internal recursive rewriting in arith
* analysis and only use it at top-level simplification to avoid speed issues.
*/
enum class ProofStrength : int {
/*! \brief default strength, can be used in. */
kDefault = 0,
/*!
* \brief Prove using symbolic bound analysis
*/
kSymbolicBound = 1
};

/*!
* \brief Constant integer up and lower bound(inclusive).
* Useful for value bound analysis.
Expand Down Expand Up @@ -656,11 +672,16 @@ class TVM_DLL Analyzer {
* \brief Whether can we prove condition.
*
* \param cond The expression to be proved.
* \param strength the strength of the prove.
*
* \return The result.
*
* \note Analyzer will call into sub-analyzers to get the result.
* Do not use strength beyond default in sub-analyzers and
* only use it in top-level predicate analysis.
*/
bool CanProve(const PrimExpr& cond);
bool CanProve(const PrimExpr& cond, ProofStrength strength = ProofStrength::kDefault);

/*!
* \brief Simplify expr.
*
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,22 @@ class IntSet : public ObjectRef {
bool IsEverything() const;
/*! \return Whether the set is a single point */
bool IsSinglePoint() const;
/*!
* \brief Check if we can prove it is a single point.
*
* Unlike IsSinglePoint, which only checks ptr equality
* this function will invoke analyzer to do stonger proofs
* but also takes longer time.
*
* Use this function in some of the primitives but do not
* use it in the inner loop of simplification.
*
* \param ana Analyzer used in the proof.
* \return Whether we can prove it is a single point
*/
bool CanProveSinglePoint(Analyzer* ana) const;
// TODO(tvm-team): update all CanProve to explicitly take
// analyzer to encourage more analyzer reuse
/*! \return Whether the set is proved to be bigger than 0 */
bool CanProvePositive() const;
/*! \return Whether the set is proved to be smaller than 0 */
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
estimate_region_strict_bound,
estimate_region_upper_bound,
)
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr
from .int_solver import solve_linear_equations, solve_linear_inequalities
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@
# specific language governing permissions and limitations
# under the License.
"""Arithmetic data structure and utility"""
from enum import IntEnum
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api


class ProofStrength(IntEnum):
"""Proof strength of the analysis"""

DEFAULT = 0
SYMBOLIC_BOUND = 1


@tvm._ffi.register_object("arith.ModularSet")
class ModularSet(Object):
"""Represent range of (coeff * x + base) for x in Z"""
Expand Down Expand Up @@ -91,6 +99,7 @@ def __init__(self):
self._int_set = _mod("int_set")
self._enter_constraint_context = _mod("enter_constraint_context")
self._can_prove_equal = _mod("can_prove_equal")
self._can_prove = _mod("can_prove")

def const_int_bound(self, expr):
"""Find constant integer bound for expr.
Expand Down Expand Up @@ -190,6 +199,24 @@ def int_set(self, expr, dom_map):
"""
return self._int_set(expr, dom_map)

def can_prove(self, expr, strength=ProofStrength.DEFAULT):
"""Check whether we can prove expr to be true.
Parameters
----------
expr : PrimExpr
The expression.
strength: ProofStrength
The proof strength
Returns
-------
result : Expr
The result.
"""
return self._can_prove(expr, strength)

def bind(self, var, expr):
"""Bind a variable to the expression.
Expand Down
43 changes: 40 additions & 3 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,47 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) {
return CanProve(lhs - rhs == 0);
}

bool Analyzer::CanProve(const PrimExpr& expr) {
bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
// Avoid potentially expensive simplification unless required.
if (const auto* ptr = expr.as<IntImmNode>()) {
return ptr->value != 0;
}

PrimExpr simplified = Simplify(expr);
const int64_t* as_int = tir::as_const_int(simplified);
return as_int && *as_int;
if (as_int && *as_int) return true;
if (strength >= ProofStrength::kSymbolicBound) {
// NOTE: we intentionally only pattern match common bound predicate i < bound
// and put this implementation at the top-level.
// This is to avoid repeatitive calling of this function
// that causes speed issues.
// This strategy can only be called from top-level and not from sub-analyzers.
Optional<PrimExpr> pos_diff;
int lower_bound = 0;
if (const auto* ptr_lt = expr.as<tir::LTNode>()) {
pos_diff = ptr_lt->b - ptr_lt->a;
lower_bound = 1;
}
if (const auto* ptr_le = expr.as<tir::LENode>()) {
pos_diff = ptr_le->b - ptr_le->a;
lower_bound = 0;
}
if (const auto* ptr_gt = expr.as<tir::GTNode>()) {
pos_diff = ptr_gt->a - ptr_gt->b;
lower_bound = 1;
}
if (const auto* ptr_ge = expr.as<tir::GENode>()) {
pos_diff = ptr_ge->a - ptr_ge->b;
lower_bound = 0;
}
if (pos_diff) {
IntSet iset = this->int_set(this->Simplify(pos_diff.value()));
if (iset.HasLowerBound()) {
ConstIntBound relaxed_lower_bound = this->const_int_bound(this->Simplify(iset.min()));
if (relaxed_lower_bound->min_value >= lower_bound) return true;
}
}
}
return false;
}

PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
Expand Down Expand Up @@ -189,6 +221,11 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
self->Bind(args[0], args[1].operator PrimExpr());
}
});
} else if (name == "can_prove") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
int strength = args[1];
*ret = self->CanProve(args[0], static_cast<ProofStrength>(strength));
});
} else if (name == "enter_constraint_context") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
// can't use make_shared due to noexcept(false) decl in destructor,
Expand Down
24 changes: 21 additions & 3 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,11 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {

IntervalSet VisitExpr_(const CastNode* op) final {
IntervalSet value_set = this->Eval(op->value);
// short cut for the int set.
if (value_set->min_value.same_as(value_set->max_value)) {
if (value_set->IsEmpty()) return value_set;
return IntervalSet::SinglePoint(cast(op->dtype, value_set->min_value));
}
PrimExpr min_value =
value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf();
PrimExpr max_value =
Expand Down Expand Up @@ -723,6 +728,13 @@ bool IntSet::IsSinglePoint() const {
return (s_int && s_int->IsSinglePoint());
}

bool IntSet::CanProveSinglePoint(Analyzer* ana) const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
if (!s_int) return false;
if (s_int->IsSinglePoint()) return true;
return ana->CanProveEqual(s_int->min_value, s_int->max_value);
}

bool IntSet::CanProvePositive() const {
Analyzer analyzer;
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
Expand Down Expand Up @@ -943,9 +955,15 @@ IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map) {
}

IntSet IntSet::Vector(PrimExpr x) {
Analyzer ana;
Map<Var, IntSet> dmap;
return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x);
// short cut: simply get single point
if (x.dtype().lanes() == 1) {
return IntSet::SinglePoint(x);
} else {
// vector case.
Analyzer ana;
Map<Var, IntSet> dmap;
return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x);
}
}

IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map) {
Expand Down
9 changes: 4 additions & 5 deletions src/arith/interval_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,11 @@ class IntervalSetNode : public IntSetNode {
bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); }
/*! \return Whether the interval is a single point. */
bool IsSinglePoint() const {
if (min_value.same_as(max_value)) {
return true;
}
Analyzer analyzer;
return analyzer.CanProveEqual(min_value, max_value);
// NOTE: we are only doing cheap check as this is a frequently called routine,
// do manual prove of min and max for stronger single point check.
return min_value.same_as(max_value);
}

/*! \return whether interval represent nothing */
bool IsEmpty() const {
// during computations, either extreme could occur.
Expand Down
8 changes: 8 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const Pr

// try to prove x equals val
CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val) {
// NOTE on implementation: this function can be called many times and can be a bottleneck,
// As a result, we keep comparison here lightweight.
// We only do constant int bound analysis here.
//
// For stronger comparison proof that is out of the recursive simplifcation
// consider look at analyzer::CanProveStrong
PrimExpr diff = this->VisitExpr(x);
if (const auto* ptr = diff.as<IntImmNode>()) {
if (ptr->value == val) {
Expand All @@ -176,6 +182,8 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val
if (dbound->max_value <= val) {
return CompareResult::kLE;
}

// modular analysis
if (val == 0) {
ModularSet dmod = analyzer_->modular_set(diff);
if (dmod->base != 0) {
Expand Down
1 change: 0 additions & 1 deletion src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {

// maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurDepth = 5;

/*!
* \brief try to compare x against val.
* \param x The expression to be evaluated.
Expand Down
4 changes: 3 additions & 1 deletion src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
Map<Var, Buffer> buffer_var_map_;
/*! \brief The target buffer var mapping to its matching */
std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
/*!\ brief Internal analyzer. */
arith::Analyzer ana_;

/*!
* \brief Update read/write buffers and regions with provided buffer and region
Expand Down Expand Up @@ -318,7 +320,7 @@ Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
ICHECK_EQ(buffers[i]->shape.size(), regions[i].size());
for (size_t j = 0; j < regions[i].size(); j++) {
const tvm::arith::IntSet& range = regions[i][j];
if (range.IsSinglePoint()) {
if (range.CanProveSinglePoint(&ana_)) {
PrimExpr min = range.min();
region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1)));
} else {
Expand Down
7 changes: 4 additions & 3 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,9 @@ void UpdateBlockVarDomainDimwise(
arith::IntSet required = required_region[i];
PrimExpr dim_max = max(buffer->shape[i] - 1, 0);

if (provided.IsSinglePoint() && is_const_int(provided.min())) {
ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min()));
if (provided.CanProveSinglePoint(analyzer) && is_const_int(provided.min())) {
ICHECK(required.CanProveSinglePoint(analyzer) &&
analyzer->CanProveEqual(provided.min(), required.min()));
continue;
}

Expand Down Expand Up @@ -515,7 +516,7 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array<IterVar>&
std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms) {
// we only support single point provided region now, which could cover most cases
for (const auto& intset : provided_region) {
if (!intset.IsSinglePoint()) return false;
if (!intset.CanProveSinglePoint(analyzer)) return false;
}
// calculate forward mapping (block vars -> provided region point)
Map<Var, Range> dom_map;
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
&opaque_block_reuse)(std::move(new_stmt));
// Step 3. Update predicate to guard the loop
PrimExpr predicate = substitute_value < loop->extent;
if (!analyzer.CanProve(predicate)) {
if (!analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound)) {
new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt));
}
// Step 4. Generate nested loops to replace the original loop and simplify the binding
Expand Down
31 changes: 31 additions & 0 deletions tests/python/unittest/test_arith_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,36 @@ def test_simplify_reshape_flattened_index():
)


def test_simplify_symbolic_comparison():
ana = tvm.arith.Analyzer()

i0 = tir.Var("i0", "int64")
i1 = tir.Var("i1", "int64")
n, m = tvm.tir.SizeVar("n", "int64"), tvm.tir.SizeVar("m", "int64")
outer = (n + 31) // 32
ana.bind(i0, tvm.ir.Range(0, outer))
ana.bind(i1, tvm.ir.Range(0, 32))
PS = tvm.arith.ProofStrength

assert not ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.DEFAULT)
assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND)
assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32 + m, PS.SYMBOLIC_BOUND)
assert ana.can_prove(i0 * 32 + i1 + 1 <= (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND)
assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1 + 1, PS.SYMBOLIC_BOUND)
assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1, PS.SYMBOLIC_BOUND)


def test_regression_simplify_inf_recursion():
ana = tvm.arith.Analyzer()
cond = tir.Var("cond", "int32")

res = (tvm.tir.NE(cond, 0).astype("int8") - tvm.tir.NE(cond, 0).astype("int8")).astype(
"int32"
) == 0
# regression in a previous case
# try compare and int set recursive call can cause infinite loop
ana.rewrite_simplify(res)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 57fdeb3

Please sign in to comment.