Skip to content

Commit

Permalink
[Arith] Updated BufferDomainTouched to use IRVisitorWithAnalyzer (#1…
Browse files Browse the repository at this point in the history
…1970)

* [Arith] Allow binding of Var in IntSetAnalyzer

The other four subanalyzers in `arith::Analyzer` can each be provided
with variable bindings/constraints that are remembered internally.
This adds the same capability to `IntSetAnalyzer`, rather than
requiring users to independently track and maintain a `Map<Var,
IntSet>` containing the domain of each variable, and applies
bindings/constraints alongside the other subanalyzers.

* [Arith] Updated IRVisitorWithAnalyzer to mimic IRMutatorWithAnalyzer

Previously, `IRVisitorWithAnalyzer` did not allow subclassing, and
could only be used to collect bounds of variables along an entire
statement, and could not be used to perform scope-dependent analysis.
This commit removes `final` from `IRVisitorWithAnalyzer` and provides
the same scope-based constraints/bindings during iteration as are
provided by `IRMutatorWithAnalyzer`.

* [Arith] Moved IRVisitorWithAnalyzer to tvm::arith namespace

Changing for consistency, since `IRVisitorWithAnalyzer` it is part of
the `src/arith` directory and the analogous `IRMutatorWithAnalyzer` is
already part of the `arith` namespace.

* [Arith] Updated BufferDomainTouched to use IRVisitorWithAnalyzer

This used the earlier changes to allow subclasses of
`IRVisitorWithAnalyzer`, and to expose binding/constraints to
`IntSetAnalyzer`.

* Avoid accidental Bind with dynamic Range

* [Arith] Do not visit SelectNode in IRVisitorWithAnalyzer

Because both sides of a `Select` node are visited regardless of the
condition, the `SelectNode::condition` should not be treated as a
known value.

* [Arith][IntSet] Track global and scope-dependent bounds separately

Resolves a bug that was found in CI, where an earlier scope-dependent
constraint was treated as a conflict by a later global bound.

* [Arith] Recovery function for each subanalyzer

This way, if a subanalyzer throws an exception during
`EnterConstraint`, the other subanalyzers are still appropriately
backed out of the constraint.

* [Arith][IntSet] Use CanProve instead of CanProveGreaterEqual

The `min_value - max_value` in the `CanProveGreaterEqual` argument can
result in an exception being thrown for unsigned integers where
subtraction would wrap.

* [Arith] Allow vector expressions in IntSet::operator(PrimExpr)

Since these are tracked when lowering expressions, should allow
post-vectorization expressions.

To maintain previous behavior, this only applies when using the
automatically tracked `Map<Var, IntSet> dom_map_`.  If an explicit
domain map is passed, the previous behavior of raising an error for
vectorized expressions still occurs.

* Avoid comparisons between integer and handle datatypes

* [Arith] IntSet, Combine() extension

Previously, the Combine() method didn't handle values without a known
lower bound, for boolean operators.

* Added docstring

* Naming consistency of `IntSetAnalyzer` methods.

To be consistent with other subanalyzers, using "Update" when
providing the analyzer with the same data structure as is used
internally, and "Bind" used when providing it with something that must
be converted to the internal data structure.
  • Loading branch information
Lunderberg authored Jul 13, 2022
1 parent 7d9a07c commit 4b5dd13
Show file tree
Hide file tree
Showing 8 changed files with 409 additions and 90 deletions.
46 changes: 41 additions & 5 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class ConstIntBoundAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
* \param allow_override Whether do we allow override of existing information.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
/*!
Expand Down Expand Up @@ -224,7 +224,7 @@ class ModularSetAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
* \param allow_override Whether do we allow override of existing information.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);

Expand Down Expand Up @@ -263,10 +263,16 @@ class RewriteSimplifier {
*
* \param var The variable of interest.
* \param new_expr
* \param allow_override Whether do we allow override of existing information.
* \param allow_override Whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);

/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
Expand Down Expand Up @@ -297,7 +303,7 @@ class CanonicalSimplifier {
*
* \param var The variable of interest.
* \param new_expr
* \param allow_override Whether do we allow override of existing information.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);

Expand Down Expand Up @@ -347,7 +353,7 @@ class ConstraintContext {
/*! \brief The constraint */
PrimExpr constraint_;
/*! \brief function to be called in recovery */
std::function<void()> exit_;
std::vector<std::function<void()>> recovery_functions_;
};

/*!
Expand All @@ -365,6 +371,36 @@ class IntSetAnalyzer {
*/
TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);

/*!
* \brief Find a symbolic integer set that contains all possible
* values of expr given the domain of each variables, using
* the domain map defined by bound variables.
*
* \param expr The expression of interest.
* \return the result of the analysis.
*/
TVM_DLL IntSet operator()(const PrimExpr& expr);

/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_interval_set The set of allowed values for this var.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false);

/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_range The range of allowed values for this var.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);

std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
friend class Analyzer;
explicit IntSetAnalyzer(Analyzer* parent);
Expand Down
26 changes: 14 additions & 12 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
this->modular_set.Update(var, this->modular_set(new_expr), allow_override);
this->rewrite_simplify.Update(var, new_expr, allow_override);
this->canonical_simplify.Update(var, new_expr, allow_override);
this->int_set.Update(var, this->int_set(new_expr), allow_override);
}

void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
Expand All @@ -52,6 +53,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
this->Bind(var, range->min, allow_override);
} else {
this->const_int_bound.Bind(var, range, allow_override);
this->int_set.Bind(var, range, allow_override);
}
// skip modular_set
// skip rewrite simplify
Expand All @@ -64,22 +66,22 @@ void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
}

void ConstraintContext::EnterWithScope() {
ICHECK(exit_ == nullptr);
ICHECK(recovery_functions_.size() == 0);
// entering the scope.
auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_);
// recovery function.
exit_ = [f0, f1, f2]() {
if (f2 != nullptr) f2();
if (f1 != nullptr) f1();
if (f0 != nullptr) f0();
};
recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_));
}

void ConstraintContext::ExitWithScope() {
ICHECK(exit_ != nullptr);
exit_();
while (recovery_functions_.size()) {
auto& func = recovery_functions_.back();
if (func) {
func();
}
recovery_functions_.pop_back();
}
}

bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
Expand Down
43 changes: 11 additions & 32 deletions src/arith/domain_touched.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <unordered_map>
#include <unordered_set>

#include "ir_visitor_with_analyzer.h"

namespace tvm {
namespace arith {

Expand All @@ -56,7 +58,7 @@ using BufferDomainAccess = std::tuple<LoadAccess, StoreAccess, CombinedAccess>;
} // namespace

// Find Read region of the tensor in the stmt.
class BufferTouchedDomain final : public StmtExprVisitor {
class BufferTouchedDomain final : public IRVisitorWithAnalyzer {
public:
BufferTouchedDomain(const Stmt& stmt) { operator()(stmt); }

Expand Down Expand Up @@ -90,65 +92,42 @@ class BufferTouchedDomain final : public StmtExprVisitor {
return ret;
}

void VisitStmt_(const ForNode* op) final {
const VarNode* var = op->loop_var.get();
dom_map_[var] = IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
}

void VisitStmt_(const LetStmtNode* op) final {
dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_);
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(op->var.get());
}

/* TODO: Thread extent unitest not generated.*/
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
ICHECK(thread_axis);
const VarNode* var = thread_axis->var.get();
dom_map_[var] = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
private:
using Parent = IRVisitorWithAnalyzer;
using Parent::VisitExpr_;
using Parent::VisitStmt_;

void VisitExpr_(const BufferLoadNode* op) final {
// Record load-exclusive buffer access
Touch(&std::get<LoadAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
// Record load-store inclusive buffer access
Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
StmtExprVisitor::VisitExpr_(op);
Parent::VisitExpr_(op);
}

void VisitStmt_(const BufferStoreNode* op) final {
// Record store-exclusive buffer access
Touch(&std::get<StoreAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
// Record load-store inclusive buffer access
Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
StmtExprVisitor::VisitStmt_(op);
Parent::VisitStmt_(op);
}

private:
void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) const {
void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) {
if (args.size() > bounds->size()) {
bounds->resize(args.size());
}
for (size_t i = 0; i < args.size(); ++i) {
if (args[i].as<RampNode>()) {
(*bounds)[i].emplace_back(IntSet::Vector(args[i]));
} else {
(*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
(*bounds)[i].emplace_back(analyzer_.int_set(args[i]));
}
}
}

std::unordered_map<const BufferNode*, BufferDomainAccess> buffer_access_map_;
std::unordered_map<const VarNode*, IntSet> dom_map_;
};

Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
Expand Down
Loading

0 comments on commit 4b5dd13

Please sign in to comment.