-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Arith] Merge surjective/non-surjective iter mapping detections #11287
[Arith] Merge surjective/non-surjective iter mapping detections #11287
Conversation
where (or is it neccesary) to write testcase on |
@wrongtest yes we should cover simplifier's behavior, but the rewrite_simplifier testcase should be sufficient for now |
bf9c28d
to
d4d439d
Compare
The failed compute_at's region cover check possibly could get fixed by #11235 improvement on iteration analysis. |
LGTM, let's have #11235 merged first |
To enable region cover proof on such cases, we need to lift |
A gentle ping for @vinx13 |
d4d439d
to
6795cb0
Compare
6795cb0
to
700b702
Compare
@wrongtest Can you elaborate the usage of also cc @Lunderberg for |
Try merge
The #11235 brings great way to analyze iteration form like I think actually, as an example, though
So from my perspective it would be great if we have a uniform interface and share same padding based analysis. Ideally
No, original usages of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I focused on the DetectIterMap
changes, and especially like the merging and de-duplication. Mostly just some nitpicks here and there.
} | ||
|
||
// Step0.1: Check each index to determine required padding | ||
bool allow_padding = !require_bijective; | ||
bool allow_padding = check_level != IterMapLevel::Bijective; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would enable padding for IterMapLevel::Surjective
, which I don't think is correct. Since padding is any output value for which no input value exists, any introduction of padding wouldn't be surjective.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is the claim~ I try to change padding to iter mark itself.
For example,(x + 7)
x
in [0, 8) => IterMark(IterSplit(IterSum({x}, 7), lower_factor=1, extent=16, scale=1), extent=16
with left_pad=7, right_pad=1
Then (x + 7) // 8
is mapped to range [0, extent//2) == [0, 2), though we have padding into iter mark, the IterSplit's range can be achieved when we only iterate x
in it's original domain: (0 + 7) // 8 = 0, (7 + 7) // 8 = 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, and that does maintain surjectivity for a single index. I'm not entirely sure for the case of two indices, though. For the same x ∈ [0,8)
, the indices [(x+7)//8, (x+7)%8]
would have the same padding left_pad=7
and right_pad=1
. Even though each individual index can take any value in the output ((x+7)//8 ∈[0,2)
and (x+7)%8 ∈ [0,8)
), there are some coordinate pairs that cannot be generated for any value of x
(e.g. [0,0]
and [1,7]
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree! This is where we should be careful. In CheckMapping
with surjective mode when padding exists, we check padded // LCM
and padded % LCM
(or it's sub-splits) must not both exists. The case below depict this check:
sum = 80 + y
dom_map = var_dom([(y, 176)])
# (80 + y) // 32 itself could be surjective
assert_iter_sum_pattern(
{fld(sum, 32): (6, 2, 1)},
dom_map,
)
# (80 + y) % 2, ((80 + y) // 2) % 16) could be surjective,
# since they can be seen as sub-splits of (80 + y) % 32
assert_iter_sum_pattern(
{flm(fld(sum, 2), 16): (16, 0, 1), flm(sum, 2): (2, 0, 1)},
dom_map,
)
# but (80 + y) // 32, (80 + y) % 32 are not surjective
assert_iter_sum_failure({fld(sum, 32), flm(sum, 32)}, dom_map)
Other kinds of negatives like (80 + y) // 32, (80 + y) // 4
would be banned by existing checking rule.
src/arith/iter_affine_map.cc
Outdated
requires_padding_ = requires_padding_ || (left_padding_introduced || right_padding_introduced); | ||
padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate); | ||
} | ||
// ICHECK(CanProveDivisible(info.padded->extent, split->lower_factor)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should these // ICHECK
lines be either uncommented or removed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would like to check the padding factor is divisible by split->lower_factor
, then the commented check can be ensured from context. I found it may fail unfortunetely due to simplifier's ability limitation when the padded extent contain complex flm/fld expressions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. I noticed that there were also some simplification steps that needed to increase the number of iterations performed. Is the failure to prove divisibility related, since CanProveDivisible
only uses the default of 2 steps?
(I'm also wondering if the default for Analyzer::Simplify
should be to iterate until it the simplification converges, rather than using a fixed number of steps.)
Quick note: #11235 is merged |
a1a2086
to
1c15f4d
Compare
f4280f0
to
001ed50
Compare
src/arith/iter_affine_map.cc
Outdated
@@ -1659,7 +1676,7 @@ bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs | |||
PrimExpr divisor = normalizer.Convert(rhs); | |||
|
|||
return analyzer_->CanProveEqual(dividend, divisor) || | |||
analyzer_->CanProve(floormod(dividend, divisor) == 0); | |||
analyzer_->CanProve(analyzer_->Simplify(floormod(dividend, divisor), 8) == 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it would be great to have some explanations here that it need more simplification steps
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, that is something forget to revert. There is some cases the division could not be proved like
floormod(0 + -x * 8, x) == 0
, floormod(8*c1*c2, c1) == 0
, even we increate iteration num. They get work-around here and there, for example,
if (CanProveDivisible(right_edge, divisor)) {
right_pad = 0;
} else {
right_pad = analyzer_->Simplify(floormod(-right_edge, divisor));
}
@Lunderberg suggest Simplify
could be optimized to iterate until reaching fix point. But now it is suffice to work on existing tests.
001ed50
to
f24db1d
Compare
Could you also update this line https://github.com/apache/tvm/blob/main/src/tir/schedule/primitive/layout_transformation.cc#L395? There are some conflict that CI didn't catch because of concurrent merge |
- determine case like x % 16, x in [0, 5) to be non-surjective, since usages may treat the region extent as 16 by mistake. - skip second round of rewrite when there is no padding - fix some typo in comments
f24db1d
to
48a16f1
Compare
48a16f1
to
4d1239a
Compare
One bug from my side is magically fixed by this PR!! |
Update a simplify rule when c2 is nonzero, original rule is covered with constant folding.
floormod(x * c1, c2)
=>floormod(x * (floordiv(c1, c2) * c2 + floormod(c1, c2)), c2)
=>floormod(x * floormod(c1, c2)), c2)
This is useful for certain non-perfect tiling case, where there are dynamic loop ranges which is actually constant wrt outer loop domain.
For example,
floordiv(floormod(x * 360, 16) + 359, 16)
with x in [0, 2) can finally reduce to constant22
, since the rule could eliminate the multiply factor360
to360 % 16
, activating more available rules.Unfortunately the working example on tiling encounter a region_cover related problem again.