Skip to content

Commit

Permalink
Fixes for codegen (#18)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
  • Loading branch information
spectrometerHBH and junrushao authored Jan 22, 2022
1 parent 3e4a30e commit dbe279f
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 61 deletions.
112 changes: 67 additions & 45 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,9 @@ class IterMapRewriter : public ExprMutator {
return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
}

IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min,
const PrimExpr& predicate_induced_max) {
IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
const Optional<PrimExpr>& predicate_induced_min,
const Optional<PrimExpr>& predicate_induced_max) {
return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min,
predicate_induced_max);
}
Expand Down Expand Up @@ -494,14 +495,16 @@ class IterMapRewriter : public ExprMutator {
* \param predicate_induced_max Open upper bound from iter constraint, maybe undefined.
* \return The Normalized expression.
*/
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min,
PrimExpr predicate_induced_max) {
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional<PrimExpr> predicate_induced_min,
Optional<PrimExpr> predicate_induced_max) {
// normalize to zero base
PrimExpr base = expr->base;
if (!is_zero(base)) {
expr.CopyOnWrite()->base = 0;
if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base;
if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base;
if (predicate_induced_min.defined())
predicate_induced_min = predicate_induced_min.value() - base;
if (predicate_induced_max.defined())
predicate_induced_max = predicate_induced_max.value() - base;
}
Optional<IterSumExpr> opt = TryFuseIters(expr);
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
Expand All @@ -521,27 +524,28 @@ class IterMapRewriter : public ExprMutator {
PrimExpr iter_min = mark_offset;
PrimExpr iter_max = iter_min + mark->extent;
if (predicate_induced_min.defined()) {
iter_min = max(predicate_induced_min, iter_min);
iter_min = max(predicate_induced_min.value(), iter_min);
}
if (predicate_induced_max.defined()) {
iter_max = min(predicate_induced_max, iter_max);
iter_max = min(predicate_induced_max.value(), iter_max);
}
if (!is_zero(iter_min)) {
// structured form's offset should be updated
flattened_map_.erase(structured_form);
structured_form.CopyOnWrite()->base = -iter_min;
mark.CopyOnWrite()->source = structured_form;
flattened_map_[structured_form] = flattened_form;
if (analyzer_->CanProve(iter_min <= iter_max)) {
if (!is_zero(iter_min)) {
// structured form's offset should be updated
flattened_map_.erase(structured_form);
structured_form.CopyOnWrite()->base = -iter_min;
mark.CopyOnWrite()->source = structured_form;
flattened_map_[structured_form] = flattened_form;
}
mark.CopyOnWrite()->extent = iter_max - iter_min;
sum_fuse_map_[flattened_form] = {mark, iter_min};
// we need to note down the flattened form of constrained iterators
// to check the validity of constraints, see also CheckConstraints()
constrained_iters_flattened_.push_back(flattened_form);
expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
expr.CopyOnWrite()->base = base + iter_min;
return expr;
}
mark.CopyOnWrite()->extent = iter_max - iter_min;
sum_fuse_map_[flattened_form] = {mark, iter_min};

// we need to note down the flattened form of constrained iterators
// to check the validity of constraints, see also CheckConstraints()
constrained_iters_flattened_.push_back(flattened_form);
expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
expr.CopyOnWrite()->base = base + iter_min;
return expr;
}
Fail(Diagnostic::Error(expr->span)
<< "Fail to normalize " << expr << " with predicate bound [" << predicate_induced_min
Expand Down Expand Up @@ -608,7 +612,7 @@ class IterMapRewriter : public ExprMutator {
}
}
}
if (!base_scale) {
if (!base_scale || base_scale.value()->value < 0) {
diag_ctx_.Emit(Diagnostic::Error(expr->span)
<< "Fuse iters failed, can not find a valid base scale");
return NullOpt;
Expand Down Expand Up @@ -770,14 +774,15 @@ class IterMapRewriter : public ExprMutator {
struct IterConstraint {
// The expr of the iter
PrimExpr iter;
// The expr of the lower_bound
PrimExpr lower_bound;
// The expr of the upper_bound
PrimExpr upper_bound;
// The expr of the lower_bound, maybe undefined
Optional<PrimExpr> lower_bound;
// The expr of the upper_bound, maybe undefined
Optional<PrimExpr> upper_bound;
// The size of the iter, which is the number of nodes
size_t expr_size = 0;

IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size)
IterConstraint(PrimExpr iter, Optional<PrimExpr> lower_bound, Optional<PrimExpr> upper_bound,
size_t size)
: iter(std::move(iter)),
lower_bound(std::move(lower_bound)),
upper_bound(std::move(upper_bound)),
Expand All @@ -787,11 +792,11 @@ struct IterConstraint {
/*!
* \brief Split the predicate into `(a < b) && (c < d) && ...`
* \param pred The predicate to be split.
* \param result The result of predicate split.
* \return A list of IterConstraint, empty if the split failed.
*/
std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
const Map<Var, Range>& input_iters) {
std::vector<IterConstraint> result;
bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>& input_iters,
std::vector<IterConstraint>& result) {
arith::PVar<PrimExpr> lhs, rhs, rest;
for (;;) {
// try extract comparisions
Expand Down Expand Up @@ -820,14 +825,14 @@ std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
is_equal = true;
is_finish = true;
} else {
return std::vector<IterConstraint>();
return false;
}
PrimExpr lhs_expr = lhs.Eval();
PrimExpr rhs_expr = rhs.Eval();
// we only accept predicate of integers
if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) &&
(rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) {
return std::vector<IterConstraint>();
return false;
}
// determine iter and bound, if we can not distinguish them simply,
// try divide (lhs - rhs) into itervar aware and itervar free parts
Expand Down Expand Up @@ -863,35 +868,49 @@ std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
lhs_expr = analyzer.Simplify(lhs_expr);
rhs_expr = analyzer.Simplify(rhs_expr);
}
PrimExpr lower_bound, upper_bound, iter;
Optional<PrimExpr> lower_bound = NullOpt, upper_bound = NullOpt;
PrimExpr iter;
if (is_greater) {
if (bound_at_left) {
// bound > iter
// bound > iter / bound >= iter
upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
iter = rhs_expr;
} else {
// iter > bound
// iter > bound / iter >= bound
lower_bound = is_equal ? rhs_expr : rhs_expr + 1;
iter = lhs_expr;
}
} else {
if (bound_at_left) {
// bound < iter
// bound < iter / bound <= iter
lower_bound = is_equal ? lhs_expr : lhs_expr + 1;
iter = rhs_expr;
} else {
// iter < bound
// iter < bound / iter <= bound
upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
iter = lhs_expr;
}
}
result.emplace_back(iter, lower_bound, upper_bound, 0);
// If it is a predicate for input iters
if (const auto* var_ptr = iter.as<VarNode>()) {
auto it = input_iters.find(GetRef<Var>(var_ptr));
if (it == input_iters.end()) {
return false;
}
PrimExpr iter_min = (*it).second->min;
PrimExpr iter_max = (*it).second->min + (*it).second->extent;
if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value());
if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value());
input_iters.Set(GetRef<Var>(var_ptr), Range(iter_min, iter_max));
} else {
result.emplace_back(iter, lower_bound, upper_bound, 0);
}
if (is_finish) {
break;
}
pred = rest.Eval();
}
return result;
return true;
}

bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
Expand All @@ -911,8 +930,10 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
// - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
// - Step1: IterIndependenceChecker checks if the iterator are independent.
if (!IterRangeSanityCheck(input_iters)) return Array<IterSumExpr>();
std::vector<IterConstraint> constraints = MatchBoundConstraints(predicate, input_iters);
if (!is_one(predicate) && constraints.empty()) {
Map<Var, Range> constrained_input_iters = input_iters;
std::vector<IterConstraint> constraints;
if (!is_one(predicate) &&
!MatchBoundConstraints(predicate, constrained_input_iters, constraints)) {
diag_ctx.Emit(Diagnostic::Error(predicate->span)
<< "Fail to collect constraints from iteration predicate: " << predicate);
return Array<IterSumExpr>();
Expand All @@ -929,10 +950,11 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
constraints.begin(), constraints.end(),
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });

IterMapRewriter rewriter(analyzer, input_iters, diag_ctx);
IterMapRewriter rewriter(analyzer, constrained_input_iters, diag_ctx);
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
for (const IterConstraint& constraint : constraints) {
rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, constraint.upper_bound);
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
constraint.upper_bound);
if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
}
if (!rewriter.CheckConstraints()) {
Expand Down
3 changes: 2 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,15 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InjectVirtualThread());
pass_list.push_back(tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::StorageRewrite());
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::UnrollLoop());

// Add user-defined phase-2 passes
pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());

// PHASE 3
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::RenormalizeSplitPattern());
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::RemoveNoOp());
pass_list.push_back(tir::transform::RewriteUnsafeSelect());
pass_list.push_back(tir::transform::HoistIfThenElse());
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
for (int i = 0; i < n; i++) {
const PrimExpr& factor = factors[i];
Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i));
substitute_value = substitute_value * factor + var;
if (!is_one(factor)) substitute_value = substitute_value * factor + var;
analyzer.Bind(var, Range::FromMinExtent(0, factor));
new_loop_vars.emplace_back(std::move(var));
}
Expand Down
32 changes: 29 additions & 3 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
return result;
}

NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
std::unordered_map<const VarNode*, arith::IntSet>& dom_map,
arith::Analyzer* analyzer) {
std::unordered_map<Var, Range, ObjectPtrHash, ObjectEqual> var_dom;
for (const auto& it : dom_map) {
var_dom[GetRef<Var>(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0));
}
Optional<Array<arith::IntSet>> eval_res =
arith::EstimateRegionLowerBound(region, var_dom, predicate, analyzer);
if (eval_res.defined()) {
NDIntSet res(0);
for (const auto& it : eval_res.value()) res.push_back(it);
return res;
}
return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map);
}

/*!
* \brief Collect the access region of each buffer.
* \note The param buffer regions will not be collected.
Expand Down Expand Up @@ -149,7 +166,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
}
return;
}
return StmtExprVisitor::VisitExpr_(op);
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BlockNode* op) final {
Expand Down Expand Up @@ -198,6 +215,13 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
}
}

void VisitStmt_(const BlockRealizeNode* op) final {
PrimExpr cur_predicate = predicate_in_scope;
predicate_in_scope = op->predicate;
StmtExprVisitor::VisitStmt_(op);
predicate_in_scope = cur_predicate;
}

/**************** Helper functions ****************/

void VisitBufferAccess(const BufferRegion& buffer_region) {
Expand All @@ -206,7 +230,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
if (it != buffer_var_in_scope_.end()) {
const Buffer& buffer = it->second.first;
size_t n_ancestor_loops = it->second.second;
NDIntSet nd_int_set = support::NDIntSetFromRegion(buffer_region->region);
// Step 1. Stop ancestor loop vars out of the allocation block from
// being relaxed unless NeedRelaxThread() is true.
std::vector<arith::IntSet> non_relaxed(n_ancestor_loops);
Expand All @@ -222,7 +245,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
dom_map_.erase(dom_it);
}
// Step 2. Relax the access region
nd_int_set = support::NDIntSetEval(nd_int_set, dom_map_);
NDIntSet nd_int_set =
NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_, &dom_analyzer_);
// Step 3. Restore the non-relaxed ancestor loops domain
for (size_t i = 0; i < n_ancestor_loops; ++i) {
const VarNode* v = ancestor_loops_[i]->loop_var.get();
Expand Down Expand Up @@ -279,6 +303,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
*/
std::unordered_map<Var, std::pair<Buffer, size_t>, ObjectPtrHash, ObjectPtrEqual>
buffer_var_in_scope_;
/*! \brief The block predicate of current scope */
PrimExpr predicate_in_scope{true};

/*! \brief The map from loop vars to their iter range. */
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
Expand Down
Loading

0 comments on commit dbe279f

Please sign in to comment.