Skip to content

Commit

Permalink
[TIR][CompactBufferAllocation] Improve upperbound estimation of buffe…
Browse files Browse the repository at this point in the history
…r compaction (#12527)

Hi, this change wants to add some minor updation to region estimator used by buffer compaction:
- Add and clearify among `EstimateRegionStrictBound`, `EstimateRegionLowerBound` and `EstimateRegionUpperBound`
   
  Originally we have `EstimateRegionLowerBound`, actually it implements strict bound estimation IMO. Now add `upper` and `strict` version for where we actually want them.

- When estimating upperbounds (eg. in buffer compaction), try estimate each dimension independently when they are dependent accesses where `EstimateRegionLowerBound` is expected to fail. 

  Eg, `A[i, i], 3 < i < 16`  fails via `EstimateRegionLowerBound` who check indices be independent. But we can still try best to invoke strict bound analysis on each dimension individually.

- If range->extent == 1 for `EvalSet(range, dom)`, invoke `EvalSet(range->min, dom)` instead.
  
  Eg, `EvalSet([k*k, k*k+1), dom_k)` results to [-inf, +inf] due to current algorithm limitation but  `EvalSet(k*k, dom_k)` results to a range which makes more sense.
  • Loading branch information
wrongtest-intellif authored Aug 24, 2022
1 parent 989e5a1 commit 1ec2c36
Show file tree
Hide file tree
Showing 10 changed files with 496 additions and 220 deletions.
39 changes: 38 additions & 1 deletion include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,29 @@ Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets);
IntSet Intersect(const Array<IntSet>& sets);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate
* \brief Converts the Ranges to IntSets
* \param var_dom The ranges of variables
* \return The integer sets of the variables
*/
Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate.
* The result should be strict, i.e. no region is discarded or relaxed.
* \param region The region to be analyzed
* \param var_dom The ranges of the variables
* \param predicate The predicate for the affine map
* \param analyzer The analyzer used
* \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis
*/
TVM_DLL Optional<Array<IntSet>> EstimateRegionStrictBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate,
arith::Analyzer* analyzer);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate.
* Some subregion may be discarded during the lower-bound analysis.
* \param region The region to be analyzed
* \param var_dom The ranges of the variables
* \param predicate The predicate for the affine map
Expand All @@ -273,6 +295,21 @@ TVM_DLL Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& reg
const PrimExpr& predicate,
arith::Analyzer* analyzer);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate
* Relaxation of the region may be used in upper-bound analysis, i.e. some extra region may be added
* to the result.
* \param region The region to be analyzed
* \param var_dom The ranges of the variables
* \param predicate The predicate for the affine map
* \param analyzer The analyzer used
* \return an array of arith::IntSet as the result of analysis
*/
TVM_DLL Array<IntSet> EstimateRegionUpperBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate,
arith::Analyzer* analyzer);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SET_H_
8 changes: 7 additions & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
# under the License.
"""Integer bound analysis, simplification and pattern detection."""

from .int_set import IntSet, IntervalSet, estimate_region_lower_bound
from .int_set import (
IntSet,
IntervalSet,
estimate_region_lower_bound,
estimate_region_strict_bound,
estimate_region_upper_bound,
)
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
Expand Down
48 changes: 48 additions & 0 deletions python/tvm/arith/int_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(self, min_value, max_value):

def estimate_region_lower_bound(region, var_dom, predicate):
"""Analyze the region with affine map, given the domain of variables and their predicate
Some subregion may be discarded during the lower-bound analysis.
Parameters
----------
Expand All @@ -103,6 +104,53 @@ def estimate_region_lower_bound(region, var_dom, predicate):
return _ffi_api.EstimateRegionLowerBound(region, var_dom, predicate)


def estimate_region_strict_bound(region, var_dom, predicate):
"""Analyze the region with affine map, given the domain of variables and their predicate
The result should be strict, i.e. no region is discarded or relaxed.
Parameters
----------
region : List[Range]
The region to be analyzed.
var_dom : Dict[Var, Range]
The ranges of the variables
predicate : PrimExpr
The predicate for the affine map
Returns
----------
region_int_set : Optional[List[IntSet]]
None if the detection fails, or an array of IntSets as the result of analysis
"""
return _ffi_api.EstimateRegionStrictBound(region, var_dom, predicate)


def estimate_region_upper_bound(region, var_dom, predicate):
"""Analyze the region with affine map, given the domain of variables and their predicate
Relaxation of the region may be used in upper-bound analysis,
i.e. some extra region may be added to the result.
Parameters
----------
region : List[Range]
The region to be analyzed.
var_dom : Dict[Var, Range]
The ranges of the variables
predicate : PrimExpr
The predicate for the affine map
Returns
----------
region_int_set : List[IntSet]
an array of IntSets as the result of analysis
"""
return _ffi_api.EstimateRegionUpperBound(region, var_dom, predicate)


def pos_inf():
"""Returns the symbolic positive infinity
Expand Down
131 changes: 107 additions & 24 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,9 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map<const VarNode*, IntSet>& dom

IntSet EvalSet(Range r, const Map<Var, IntSet>& dom_map) {
Analyzer ana;
if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana.CanProveEqual(r->extent, 1)) {
return EvalSet(r->min, dom_map);
}
IntervalSetEvaluator m(&ana, dom_map);
// Simplifying first can give tighter bounds if r->min and r->extent share variables
PrimExpr sum = r->min + r->extent - 1;
Expand Down Expand Up @@ -1035,15 +1038,57 @@ IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map) {
return EvalSet(r, ConvertDomMap(dom_map));
}

Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate, Analyzer* analyzer) {
Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) {
Map<Var, arith::IntSet> result;
for (auto kv : var_dom) {
const Var& var = kv.first;
const Range& range = kv.second;
result.Set(var, arith::IntSet::FromRange(range));
}
return result;
}

/*! \brief Helper function to convert IterSumExpr to the actual touched range. */
static Optional<IntSet> EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent,
Analyzer* analyzer) {
if (iter_min->args.empty()) {
return IntSet::FromMinExtent(iter_min->base, extent);
}
ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr";
const IterSplitExpr& split = iter_min->args[0];
if (!analyzer->CanProve(extent >= split->scale)) {
return NullOpt;
}

const PrimExpr& base = iter_min->base;
// IterSplitExpr: (source // lower_factor) % extent * scale
// where `(source // lower_factor) % extent` is within [0, extent - 1]
if (analyzer->CanProve(split->scale < 0)) {
// If scale is negative, the var dom is [(extent - 1) * scale, 0]
// The total base is `base + (extent - 1) * scale`,
// while total extent is `dom_extent + (extent - 1) * (-scale)`
const PrimExpr& var_extent = (split->extent - 1) * split->scale;
return IntSet::FromMinExtent(base + var_extent, extent - var_extent);
} else {
// If scale is positive, the var dom is [0, (extent - 1) * scale]
// The total dom is [base, dom_extent + (extent - 1) * scale]
return IntSet::FromMinExtent(base, extent + (split->extent - 1) * split->scale);
}
}

Optional<Array<IntSet>> EstimateRegionStrictBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate, Analyzer* analyzer) {
int ndim = region.size();
Array<IterSumExpr> iter_sum_exprs{nullptr};
{
Array<PrimExpr> affine_indices;
affine_indices.reserve(ndim);
for (const Range& range : region) {
if (!is_const_number(range->extent)) {
// dynamic extent is not supported yet.
return NullOpt;
}
affine_indices.push_back(range->min);
}
auto res = DetectIterMap(
Expand All @@ -1060,31 +1105,57 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
for (int i = 0; i < ndim; ++i) {
const IterSumExpr& sum_expr = iter_sum_exprs[i];
const Range& range = region[i];
if (sum_expr->args.empty()) {
result.push_back(IntSet::FromMinExtent(sum_expr->base, range->extent));
continue;
}
ICHECK_EQ(sum_expr->args.size(), 1);
const IterSplitExpr& split = sum_expr->args[0];
if (!analyzer->CanProve(range->extent >= split->scale)) {
Optional<IntSet> int_set = EvalIterSum(sum_expr, range->extent, analyzer);
if (int_set.defined()) {
result.push_back(int_set.value());
} else {
return NullOpt;
}
}
return result;
}

const PrimExpr& base = sum_expr->base;
// IterSplitExpr: (source // lower_factor) % extent * scale
// where `(source // lower_factor) % extent` is within [0, extent - 1]
if (analyzer->CanProve(split->scale < 0)) {
// If scale is negative, the var dom is [(extent - 1) * scale, 0]
// The total base is `base + (extent - 1) * scale`,
// while total extent is `dom_extent + (extent - 1) * (-scale)`
const PrimExpr& var_extent = (split->extent - 1) * split->scale;
result.push_back(IntSet::FromMinExtent(base + var_extent, range->extent - var_extent));
} else {
// If scale is positive, the var dom is [0, (extent - 1) * scale]
// The total dom is [base, dom_extent + (extent - 1) * scale]
result.push_back(
IntSet::FromMinExtent(base, range->extent + (split->extent - 1) * split->scale));
Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate,
arith::Analyzer* analyzer) {
return EstimateRegionStrictBound(region, var_dom, predicate, analyzer);
}

Array<IntSet> EstimateRegionUpperBound(const Array<Range>& region, const Map<Var, Range>& var_dom,
const PrimExpr& predicate, Analyzer* analyzer) {
if (Optional<Array<arith::IntSet>> result = EstimateRegionStrictBound(
/*region=*/region,
/*var_dom=*/var_dom,
/*predicate=*/predicate, /*analyzer=*/analyzer)) {
return result.value();
}
Array<IntSet> result;
result.reserve(region.size());
// try estimate each dimension independently
for (const Range& range : region) {
auto res = DetectIterMap(
/*indices=*/{range->min}, /*input_iters=*/var_dom,
/*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer);
if (!res->indices.empty()) {
ICHECK_EQ(res->indices.size(), 1U);
IterSumExpr sum_expr = res->indices[0];

// dynamic extent is not supported yet.
PrimExpr extent = range->extent;
if (!is_const_number(extent)) {
IntSet relaxed = EvalSet(extent, AsIntSet(var_dom));
ICHECK(relaxed.HasUpperBound());
extent = relaxed.max();
}

if (Optional<IntSet> int_set = EvalIterSum(sum_expr, range->extent, analyzer)) {
result.push_back(int_set.value());
continue;
}
}
// fallback to coarse grained evalset
result.push_back(EvalSet(range, AsIntSet(var_dom)));
}
return result;
}
Expand Down Expand Up @@ -1118,6 +1189,18 @@ TVM_REGISTER_GLOBAL("arith.EstimateRegionLowerBound")
Analyzer analyzer;
return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer);
});
TVM_REGISTER_GLOBAL("arith.EstimateRegionStrictBound")
.set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
PrimExpr predicate) -> Optional<Array<IntSet>> {
Analyzer analyzer;
return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer);
});
TVM_REGISTER_GLOBAL("arith.EstimateRegionUpperBound")
.set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
PrimExpr predicate) -> Optional<Array<IntSet>> {
Analyzer analyzer;
return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer);
});

TVM_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; });
TVM_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; });
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ void RelaxBufferRegions(const Map<Var, PrimExpr>& binding,
runtime::StorageRank rank = scope.rank;
if (rank != previous_rank || !var_dom.defined()) {
previous_rank = rank;
var_dom = AsIntSet(LoopDomainOfSRefTreePath(
var_dom = arith::AsIntSet(LoopDomainOfSRefTreePath(
/*low_inclusive=*/relax_path_low_inclusive,
/*high_exclusive=*/relax_path_high_exclusive,
/*extra_relax_scope=*/scope));
Expand Down
14 changes: 6 additions & 8 deletions src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
#include "./utils.h"
#include <tvm/arith/int_set.h>

#include "./utils.h"
namespace tvm {
namespace tir {

Expand All @@ -44,13 +45,10 @@ Array<arith::IntSet> AnalyzeRegionUpperBound(const BufferRegion& region,
/*low_inclusive=*/dom_low_inclusive,
/*high_exclusive=*/dom_high_exclusive,
/*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope()));
if (Optional<Array<arith::IntSet>> result = EstimateRegionLowerBound(
/*region=*/region->region,
/*var_dom=*/var_dom,
/*predicate=*/predicate, /*analyzer=*/analyzer)) {
return result.value();
}
return arith::EvalSet(region->region, AsIntSet(var_dom));
return EstimateRegionUpperBound(
/*region=*/region->region,
/*var_dom=*/var_dom,
/*predicate=*/predicate, /*analyzer=*/analyzer);
}

/*!
Expand Down
18 changes: 0 additions & 18 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,24 +249,6 @@ inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) {
return thread_scope.rank == 1 && thread_scope.dim_index >= 0;
}

/******** Integer set ********/

/*!
* \brief Converts the Ranges to IntSets
* \param var_dom The ranges of variables
* \return The integer sets of the variables
*/
inline Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) {
std::unordered_map<Var, arith::IntSet, ObjectPtrHash, ObjectPtrEqual> result;
result.reserve(var_dom.size());
for (auto kv : var_dom) {
Var& var = kv.first;
Range& range = kv.second;
result.emplace(std::move(var), arith::IntSet::FromRange(std::move(range)));
}
return {result.begin(), result.end()};
}

/**************** Loop extents ****************/

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
var_dom[GetRef<Var>(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0));
}
Optional<Array<arith::IntSet>> eval_res =
arith::EstimateRegionLowerBound(region, var_dom, predicate, analyzer);
arith::EstimateRegionUpperBound(region, var_dom, predicate, analyzer);
if (eval_res.defined()) {
return NDIntSet(eval_res.value().begin(), eval_res.value().end());
}
Expand Down
Loading

0 comments on commit 1ec2c36

Please sign in to comment.