Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] Analyzer CanonicalSimplifier #2891

Merged
merged 9 commits into from
Mar 31, 2019
Merged

[ARITH] Analyzer CanonicalSimplifier #2891

merged 9 commits into from
Mar 31, 2019

Conversation

tqchen
Copy link
Member

@tqchen tqchen commented Mar 25, 2019

This PR contains one step of #2588

  • CanonicalSimplifier Infra
  • Support "split normal form" to handle simplification of div mod expressions
  • Move the old Canonical simplification to the new one
  • Move the reduction simplification to the new infra.

The main highlight of this PR is the introduction of the "split normal form", so we can simplify the following expression.

x/6*6 + (((x/3) % 2)*3) + (x % 3) => x

It is quite fun to implement the split normalization. Currently, we only support constant div and mod co-efficient for simplicity, we can consider adding symbolic support later.

@tqchen
Copy link
Member Author

tqchen commented Mar 25, 2019

@tqchen
Copy link
Member Author

tqchen commented Mar 25, 2019

Also as a side note, this PR helps to demonstrate how we can consolidate some of the simplification infra around the Analayzer, which could be helpful to improve some open PR that @sgrechanik-h is working on

@tqchen tqchen force-pushed the canonical branch 2 times, most recently from 337f7d1 to b433c00 Compare March 25, 2019 05:12
@tqchen
Copy link
Member Author

tqchen commented Mar 26, 2019

CI is now green, would be great if we can get some inputs into the PR

@tqchen
Copy link
Member Author

tqchen commented Mar 27, 2019

@Hzfengsy can you also help review this PR?

Copy link
Contributor

@sgrechanik-h sgrechanik-h left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I have little time currently. I'll try to look again today or tomorrow.

this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr));
this->rewrite_simplify.Update(var, this->rewrite_simplify(expr));
Expr new_expr = expr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why do you copy expr to new_expr here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make new expr mutable

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But none of the subsequent lines changes it, or am I wrong?.

/*!
* \brief Internal "Split normal form" of expression.
*
* This is a special expression that represent
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: represents

Expr NormalizeWithScale(int64_t sscale) const {
Expr res = this->index;
Type dtype = this->type;
CHECK_EQ(this->type, dtype);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check seems redundant.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just like an asset, to check runtime consistency

Copy link
Contributor

@sgrechanik-h sgrechanik-h Mar 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But dtype gets initialized with this->type in the previous line, so this check is obviously true.

* args are divided into segments with the same index.
* within each segment, the SplitExpr is ordered in descending order of lower_factor.
*
* \note Can be mutated by TryMergeSplitExpr, which is idempotent
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is TryMergeSplitExpr? (Didn't find it in the code)

//
// ((x / (c * s)) * s + (x % (c * s)) / c
// => ((x / c) / s) * s + ((x / c) % s)
// => (x / c)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly speaking, I can't understand this algorithm. Probably I have to return to it in a better
state of mind. Expanding the explanation may be helpful too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The simplification rule and proof are correct. It's based on two basic rules:

Rule 1:  (x % (c * s)) / c  =  (x / c) % s
Proof:
   x can always be decomposed into p * c * s + q * c + r  where  0 <= q * c + r < c * s  and  0 <= r  <  c.
   Then, lhs = ((p * c * s + q * c + r) % (c * s)) / c = (q * c + r) / c = q
         rhs = ((p * c * s + q * c + r) / c) % s = (p * s + q) % s = q
   Thus, lhs = rhs

Rule 2:  (x / s) * s + x % s = x

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Although it not obvious to me if the rules are still correct for the C/C++ division used in tvm.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rule works for both trunc div and floor version. Mainly because that the first rule only involves mul div and mod. And you can simply take abs of all operands and then take addd the final sign. The second rule is an invariant for both types of div

}
// sort by the entry
auto fcompare = [](const SplitExpr& lhs, const SplitExpr& rhs) {
// order by scale first
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it be ordered by index first? Or at least if the indices are different, the elements
should be incomparable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point, however, it is can be quite costly to deep compare indices. So instead we just order by the scale and factor so that it is mostly in a consistent form(ignoring the indices)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the algorithm assumes that the vector contains contiguous segments of same-index elements, we have to check in the comparison function if lhs and rhs have the same index, otherwise this assumption may be destroyed by sorting.

Also I would still suggest sorting by index because otherwise we may often get into the situation when something like f(x + y) - f(y + x) don't get simplified. And if it really leads to performance problems, we should think about optimizing deep comparison somehow (probably we can cache the size or some other measure of an expression).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now the code is fine because the result is only directly used by normalize, which is the intended usecase. It does break if we call it in the middle.

The only reason why I am not sure about comparing index is that var comparison can depend on runtime, which makes its behavior indeterministic. I want to think a bit more about this before we come back to revisit it

void DivideBy(int64_t scale) {
this->base /= scale;
for (size_t i = 0; i < this->args.size(); ++i) {
args[i].CopyOnWrite()->scale /= scale;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we raise an error if some of the scales are not divisible by the argument?

return;
}
}
// Insert other in the end.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably we should also sort by index.

* \param other The expression to be added.
* \param scale The additional scale on value.
*/
void AddToSelf(SplitExpr other, int64_t scale) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be a better way to use const SplitExpr &other instead of SplitExpr other

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to use CopyOnWrite inside the arguments so use SplitExpr directly. If the item is newly constructed, other will be directly passed in as a unique copy, and CopyOnWrite will reuse that data from there.

if (!IsIndexType(op->type)) {
return Rewriter::Mutate_(op, self);
}
// normalize
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably we may build a function to reduce the duplicate code in above three functions

// note: x = z, c = 3, s = 2
//
// ((z % 12) / 6) * 6 + ((z % 6) / 3) * 3
// => (((z % 12) / 6) * 2 + ((z % 12) % 6) / 3) * 3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't there be a condition that lhs->upper_factor % rhs->upper_factor == 0 so that we can perform the transformation z % rhs->upper_factor => (z % lhs->upper_factor) % rhs->upper_factor?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an invariant condition that lhs->upper_factor % lhs->lower_factor == 0

@tqchen
Copy link
Member Author

tqchen commented Mar 29, 2019

@Hzfengsy @sxjscience @sgrechanik-h thanks for the reviews, I have updated the comment blocks to add more explanations about the proof

@xqdan
Copy link
Contributor

xqdan commented Mar 29, 2019

we have a case like this, can this PR handle it?

for (ee10, 0, 16) {
  for (mo_11, 0, 5) {
    for (mi_12, 0, 16) {
      for (ee13, 0, 16) {
        max_1[(((((((ee10*14) + (mo_11 + ((mi_12 + (mo_11*16))/14)))*14) + ((mi_12 + (mo_11*16)) % 14))*16) + ee13) + 2688)] = max_1_local_UB[((((((ee10*5) + mo_11)*16) + mi_12)*16) + ee13)]
      }
    }
  }
}

@tqchen
Copy link
Member Author

tqchen commented Mar 29, 2019

@xqdan you should try it out. In theory this canonical simplifier is able to handle all kinds of div mode mul pattern that comes out from split and re-fuse

@xqdan
Copy link
Contributor

xqdan commented Mar 30, 2019

@xqdan you should try it out. In theory this canonical simplifier is able to handle all kinds of div mode mul pattern that comes out from split and re-fuse

nice, I will try after this is merged.

// note also the invariance lhs->upper_factor % lhs->lower_factor == 0
//
SplitExprNode* merged = rhs.CopyOnWrite();
merged->upper_factor = lhs->upper_factor;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct only when lhs->uppper_factor == kPosInf? For example, ((x % 5) / (3 * 2)) * 2 + (x % (3 * 2)) / 3 is simplified to x % 5 / 3, but this is not correct when x == 5.

Copy link
Member Author

@tqchen tqchen Mar 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above on invariance

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, sorry for my confusion.

// - s = lhs->scale / rhs->scale
// - c = rhs->lower_factor
//
// ((x / (c * s)) * s + (x % (c * s)) / c
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant (.

(x / (c * s)) * s + (x % (c * s)) / c

// note also the invariance lhs->upper_factor % lhs->lower_factor == 0
//
SplitExprNode* merged = rhs.CopyOnWrite();
merged->upper_factor = lhs->upper_factor;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, sorry for my confusion.

if (cval % lhs->scale == 0) {
int64_t scaled_cval = cval / lhs->scale;
lhs.CopyOnWrite()->scale = 1;
lhs.CopyOnWrite()->lower_factor *= scaled_cval;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this guarantee the invariance lhs->upper_factor % lhs->lower_factor == 0? It looks not obvious and I wonder if we should call lhs->Verify() here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch :)

@tqchen
Copy link
Member Author

tqchen commented Mar 30, 2019

@Hzfengsy @sxjscience @sgrechanik-h @kazum thanks for the reviews, please take another look

Copy link
Contributor

@sgrechanik-h sgrechanik-h left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I skimmed through the rest, seems ok. My main concern is sorting by the index field, but it can be done in subsequent PRs.

* \param coeff The co-efficient.
* \param out_divisible The result divisible component.
* \param out_non_divisible The non-divisible component.
* \return Whetjer detection is successful.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo whether

ConstraintContext ctx(&analyzer_, Mutate(Not::make(condition)));
else_case = this->Mutate(op->else_case);
}
if (is_one(condition)) return op->then_case;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we return then_case instead of op->then_case here? (And same thing with the else_case)

return lhs;
} else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) {
// (x % c1) / c2 => 0 when c2 >= c1
return ToSplitExpr(make_zero(lhs.type()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also return zero when cval % lhs->scale != 0? I mean the below looks correct in any cases.

if (lhs->upper_factor <= (lhs->lower_factor * cval / lhs->scale))
   return ToSplitExpr(make_zero(lhs.type()));

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the mul and division are not necessarily exchangeable, in the case of cval % lhs->scale != 0, we will need to consider the consequence more carefully, because it is rare to have such case, we just skip the optimization

@tqchen tqchen merged commit 7afbca5 into apache:master Mar 31, 2019
@tqchen
Copy link
Member Author

tqchen commented Mar 31, 2019

Thanks, @sgrechanik-h @sxjscience @kazum @xqdan @Hzfengsy , this is now merged

@tqchen tqchen deleted the canonical branch March 31, 2019 22:26
wweic pushed a commit to wweic/tvm that referenced this pull request Apr 7, 2019
wweic pushed a commit to wweic/tvm that referenced this pull request Apr 7, 2019
wweic pushed a commit to wweic/tvm that referenced this pull request Apr 8, 2019
wweic pushed a commit to wweic/tvm that referenced this pull request Apr 10, 2019
wweic pushed a commit to neo-ai/tvm that referenced this pull request Apr 11, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants