Skip to content

Commit

Permalink
- check incompatible left paddings
Browse files Browse the repository at this point in the history
- determine case like x % 16, x in [0, 5) to be non-surjective, since usages may treat the region extent as 16 by mistake.
- skip second round of rewrite when there is no padding
- fix some typo in comments
  • Loading branch information
wrongtest-intellif committed May 28, 2022
1 parent 28daf94 commit f24db1d
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 65 deletions.
7 changes: 5 additions & 2 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,11 @@ class IterMapResultNode : public Object {
*/
class IterMapResult : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(IterMapResult, ObjectRef, IterMapResultNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMapResultNode);
// constructor
IterMapResult() { data_ = make_object<IterMapResultNode>(); }

/*! \return mutable pointers to the node. */
IterMapResultNode* operator->() const { return static_cast<IterMapResultNode*>(get_mutable()); }
};

/*!
Expand Down
152 changes: 92 additions & 60 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,17 @@ class IterMapRewriter : public ExprMutator {
}

PrimExpr padding_predicate() const { return padding_predicate_; }
PrimExpr requires_padding() const { return !analyzer_->CanProveEqual(padding_predicate_, 0); }
bool requires_padding() const { return requires_padding_; }

IterSumExpr Rewrite(const PrimExpr& expr) {
return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
}

void UpdatePadding(const PrimExpr& expr) {
IterSumExpr RewriteAndUpdatePadding(const PrimExpr& expr) {
update_iterator_padding_ = true;
DirectMutate(expr);
auto res = Rewrite(expr);
update_iterator_padding_ = false;
return res;
}

IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
Expand Down Expand Up @@ -425,27 +426,30 @@ class IterMapRewriter : public ExprMutator {
// input iter marks
std::vector<IterMark> input_marks_;

// Map from a split iter to the padded iterator information for
// Map from an iter mark to the padded iterator information for
// it. This is necessary for introducing the same padding in all
// usage of an input iterator. (e.g. (i-1) occurring in the
// expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be
// left-padded by 31 for each occurrence.)
std::unordered_map<IterMark, IterPaddingInfo, StructuralHash, StructuralEqual> padded_iter_map_;

/* If allow_padding_ is true, allow the extents of the IterMap to be
// Map from padded iter mark to it's origin mark
std::unordered_map<IterMark, IterMark, StructuralHash, StructuralEqual> padded_origin_map_;

/* If update_iterator_padding_ is true, allow the extents of the IterMap to be
* padded beyond the original iterators.
*
* For example, if allow_padding_ is true, the expressions i//4 and
* For example, if update_iterator_padding_ is true, the expressions i//4 and
* i%4, where i is on the range [0,18), would be represented as
* IterSplit(i, lower_factor=4, extent=5) and IterSplit(i, extent=4).
* This representation would be forbidden if allow_padding_ is false,
* This representation would be forbidden if update_iterator_padding_ is false,
* because lower_factor=4 does not evenly divide the original extent of
* 18.
*/
bool update_iterator_padding_{false};

/* A boolean expression that is true for any padding that has been
* introduced, and false otherwise. If allow_padding_ is false,
* introduced, and false otherwise. If update_iterator_padding_ is false,
* padding_predicate_ will always be false.
*
* Example: [i//4, i%4], i in range [0,16)
Expand All @@ -459,6 +463,11 @@ class IterMapRewriter : public ExprMutator {
*/
PrimExpr padding_predicate_;

/* A boolean flag denotes there are padding iterations detected
* in the first round of indices rewriting.
*/
bool requires_padding_{false};

// The map for sum that maps flattened form to IterMark with normal form and extent (and possibly
// an extra offset)
// Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
Expand Down Expand Up @@ -488,25 +497,6 @@ class IterMapRewriter : public ExprMutator {
// The flattened forms of constrained iters
std::vector<IterSumExpr> constrained_iters_flattened_;

/*!
* \brief Extract original iteration mark's extent before padding, return NullOpt is
* there is no extra padding.
*/
Optional<PrimExpr> ExtractExtentBeforePadding(const IterMark& mark, Analyzer* analyzer) {
const IterSumExprNode* sum = mark->source.as<IterSumExprNode>();
if (!sum || sum->args.size() != 1) {
return NullOpt;
}
IterSplitExpr split = sum->args[0];
if (!analyzer->CanProveEqual(split->extent, mark->extent) &&
analyzer->CanProveEqual(split->scale, 1) &&
analyzer->CanProveEqual(split->lower_factor, 1) &&
analyzer->CanProveEqual(split->source->extent, split->extent)) {
return sum->args[0]->extent;
}
return NullOpt;
}

/*!
* \brief Look for a split in splits that is not used such that its lower_factor is smallest.
* Note that here we use division to compare lower_factor.
Expand Down Expand Up @@ -580,9 +570,9 @@ class IterMapRewriter : public ExprMutator {
expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
}

// Extract padding info of the iteration mark, extent before padding
// is only defined when padding exists.
Optional<PrimExpr> extent_before_padding = ExtractExtentBeforePadding(mark, analyzer_);
// Extract iteration mark info before padding
auto pad_mark_it = padded_origin_map_.find(mark);
bool has_padding = pad_mark_it != padded_origin_map_.end();

bool match_full_iter = analyzer_->CanProveEqual(expected_lower_factor, mark->extent);
bool match_iter_divisor =
Expand All @@ -598,33 +588,46 @@ class IterMapRewriter : public ExprMutator {
//
// Case 3. bijective is not required and there exists padding. We check either
// (3.1) The extent we calculate is consistent with the extent of the padded mark and it is
// the single split for the iter mark.
// the single split for the iter mark.
// For example, padded iter p in [0, 24), [(p / 12)] is valid because it is surjective
// according to how we pad the original iteration mark.
// (3.2) The extent we calculate is a factor of the extent of the padded mark, and the extent
// before padding is greater or equal than the extent we calculate.
// For example, padded iter p in [0, 24), the original extent is 14, [(p % 12)] is
// valid.
// before padding is greater or equal than the extent we calculate.
// For example, the original extent is 14, [(p % 12)] is valid, with p padded to 24.
//
if (check_level == IterMapLevel::Bijective) {
if (extent_before_padding.defined() || !match_full_iter) {
return Array<IterSplitExpr>();
if (has_padding) {
ErrorLogger(this) << "Bijectvie mapping should not take iter paddings";
return {};
} else if (!match_full_iter) {
ErrorLogger(this) << "The iterations do not traverse full iter space";
return {};
}
} else if (!extent_before_padding.defined()) {
} else if (!has_padding) {
if (!match_iter_divisor) {
return Array<IterSplitExpr>();
ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent";
return {};
}
} else if (check_level == IterMapLevel::Surjective) {
PrimExpr extent_before_padding = pad_mark_it->second->extent;
if (match_full_iter) {
if (splits.size() != 1) {
ErrorLogger(this) << "Dependent iterations on padding iter space";
return Array<IterSplitExpr>();
} else if (analyzer_->CanProveEqual(splits[0]->extent, expected_lower_factor) &&
!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) {
ErrorLogger(this) << "Split on padding iteration is not surjective "
<< "if the split extent equals to the full iter space extent";
return Array<IterSplitExpr>();
}
} else if (match_iter_divisor) {
if (!analyzer_->CanProve(extent_before_padding.value() >= expected_lower_factor)) {
if (!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) {
ErrorLogger(this) << "The extent before padding is less than lower factor";
return Array<IterSplitExpr>();
}
} else {
return Array<IterSplitExpr>();
ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent";
return {};
}
}
return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
Expand Down Expand Up @@ -1056,22 +1059,21 @@ bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, IterMapLevel check_level,
arith::Analyzer* analyzer, bool simplify_trivial_iterators) {
IterMapResult result_obj = IterMapResult(make_object<IterMapResultNode>());
auto result = result_obj.CopyOnWrite();
IterMapResult result;

// Overall detection algorithm is divided into two steps:
// - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
// - Step1: IterIndependenceChecker checks if the iterator are independent.
if (!IterRangeSanityCheck(input_iters)) {
result->errors.push_back("Invalid iterators. Iterators may not be expressions of each other.");
return result_obj;
return result;
}
Map<Var, Range> constrained_input_iters = input_iters;
std::vector<IterConstraint> constraints;
if (!is_one(predicate) &&
!MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) {
result->errors.push_back("Could not parse predicate as constraints on the input iterators.");
return result_obj;
return result;
}
// We have to make sure when we visit an iterator, all the constraints related with its successors
// in the iter var graph has been visited, where the expression of this iterator will contain the
Expand All @@ -1090,32 +1092,39 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
for (const IterConstraint& constraint : constraints) {
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
constraint.upper_bound);
if (result->errors.size()) {
return result_obj;
if (result->errors.size() > 0) {
return result;
}
}
if (!rewriter.CheckConstraints()) {
result->errors.push_back("Invalid constraints.");
return result_obj;
return result;
}

// Step0.1: Check each index to determine required padding
// Step0.1: Rewrite indicies and determine required padding,
// if there is no padding, it should be the final result.
Array<IterSumExpr> rewrite_indices;
rewrite_indices.reserve(indices.size());
bool allow_padding = check_level != IterMapLevel::Bijective;
if (allow_padding) {
for (PrimExpr value : indices) {
rewriter.UpdatePadding(value);
rewrite_indices.push_back(rewriter.RewriteAndUpdatePadding(value));
if (result->errors.size() > 0) {
return result;
}
}
}

// Step0.2: rewrite indices
Array<IterSumExpr> rewrite_indices;
for (PrimExpr value : indices) {
rewrite_indices.push_back(rewriter.Rewrite(value));
if (result->errors.size()) {
return result_obj;
// Step0.2: Rewrite indices in the second round.
if (!allow_padding || rewriter.requires_padding()) {
rewrite_indices.clear();
for (PrimExpr value : indices) {
rewrite_indices.push_back(rewriter.Rewrite(value));
if (result->errors.size() > 0) {
return result;
}
}
}

result->padding_predicate = rewriter.padding_predicate();

// Step1: IterIndependenceChecker checks if the iterator are independent.
Expand All @@ -1125,10 +1134,10 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
} else {
result->errors.push_back("Mapped indices are not independent.");
}
return result_obj;
return result;
}
result->indices = rewrite_indices;
return result_obj;
return result;
}

TVM_REGISTER_GLOBAL("arith.DetectIterMap")
Expand Down Expand Up @@ -1304,6 +1313,10 @@ PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyze
auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second));
if (analyzer->CanProveEqual(p1.first, p2.first)) {
return p1.first * const_lcm;
} else if (analyzer->CanProveEqual(floormod(p1.first, p2.first), 0)) {
return p1.first * const_lcm;
} else if (analyzer->CanProveEqual(floormod(p2.first, p1.first), 0)) {
return p2.first * const_lcm;
} else {
return (p1.first * p2.first) * const_lcm;
}
Expand Down Expand Up @@ -1348,6 +1361,9 @@ std::pair<IterSplitExpr, PrimExpr> IterMapRewriter::PadDividendToDivisor(IterSpl
}

// Update padding requirement on the lower side of the source iter mark.
// In the second pass, all splits would check whether the maximum left pading
// on the iter mark is compatible with it's own left padding.
requires_padding_ = true;
PrimExpr mark_left_pad = left_pad * split->lower_factor;
info.left_pad = max(info.left_pad, mark_left_pad);

Expand Down Expand Up @@ -1381,6 +1397,22 @@ std::pair<IterSplitExpr, PrimExpr> IterMapRewriter::PadDividendToDivisor(IterSpl
if (!info.padded.defined()) {
// the first time encounter the iter mark to pad, update the padded mark.
PrimExpr mark_left_pad = info.left_pad;
if (CanProveDivisible(mark_left_pad, split->lower_factor)) {
// correct current split's left padding
// (mark_left_pad + iter) // lower_factor % extent =>
// (left_pad * lower_factor + mark) // lower_factor % extent =>
// (left_pad + mark // lower_factor) % extent =>
// left_pad + (mark // lower_factor % extent) =>
// left_pad + split
// since the extent covers the full padding range.
left_pad = floordiv(mark_left_pad, split->lower_factor);
} else {
ErrorLogger(this) << "Detect incompatible left padding on "
<< tvm::PrettyPrint(NormalizeIterMapToExpr(split))
<< ", the iter mark is left padded with " << mark_left_pad;
return {IterSplitExpr(), PrimExpr()};
}

PrimExpr right_edge = mark->extent + mark_left_pad;
PrimExpr mark_right_pad;
if (CanProveDivisible(right_edge, info.padding_factor)) {
Expand All @@ -1391,6 +1423,7 @@ std::pair<IterSplitExpr, PrimExpr> IterMapRewriter::PadDividendToDivisor(IterSpl
PrimExpr padded_extent = analyzer_->Simplify(right_edge + mark_right_pad);
info.right_pad = mark_right_pad;
info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad), padded_extent);
padded_origin_map_[info.padded] = mark;

auto left_padding_introduced = (mark_left_pad != 0);

Expand Down Expand Up @@ -1557,7 +1590,6 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P

// We handle scale!=1 in above code, hence we only consider floormod(x, rhs) below
// where x=floormod(floordiv(iter, lower_factor), extent) + base

auto pair = PadDividendToDivisor(lhs, base, rhs);
IterSplitExpr padded = pair.first;
if (!padded.defined()) {
Expand Down Expand Up @@ -1676,7 +1708,7 @@ bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs
PrimExpr divisor = normalizer.Convert(rhs);

return analyzer_->CanProveEqual(dividend, divisor) ||
analyzer_->CanProve(analyzer_->Simplify(floormod(dividend, divisor), 8) == 0);
analyzer_->CanProve(floormod(dividend, divisor) == 0);
}

PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) {
Expand Down
Loading

0 comments on commit f24db1d

Please sign in to comment.