[Relay][Passes] Iterative A-normal Traversals #7374

Merged (5 commits) on Feb 2, 2021

@mbrookhart mbrookhart commented Jan 29, 2021

On very large models, we can have ASTs that stack overflow because we recursively traverse tens of thousands of chained Let nodes. In A-normal form, we can simply iterate over that chain instead.

These changes find every pass that segfaults with a stack overflow on ONNX SSD-Mobilenet (~20k nodes) and hijack the visitor to traverse chains of lets iteratively. This significantly reduces stack usage and lets me run that model without setting ulimit -s unlimited.
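A minimal sketch of the idea, using a toy AST rather than TVM's Relay classes: `Node`, `BuildLetChain`, `CountLetsRecursive`, and `CountLetsIterative` are hypothetical names introduced only for illustration. The recursive version uses one stack frame per Let, while the iterative version follows the `body` links in a loop and uses constant stack space.

```cpp
#include <cstddef>
#include <deque>
#include <string>

// Toy expression node (illustrative only, not a TVM type).
// A Let binds `name` to `value` and continues with `body`; any other node is a leaf.
struct Node {
  bool is_let;
  std::string name;
  const Node* value;  // bound value (Let only)
  const Node* body;   // rest of the chain (Let only)
};

// Build a chain of `n` nested lets ending in a leaf. Nodes live in `arena`;
// std::deque::push_back keeps references to existing elements valid.
const Node* BuildLetChain(std::deque<Node>& arena, std::size_t n) {
  arena.push_back(Node{false, "result", nullptr, nullptr});
  const Node* body = &arena.back();
  for (std::size_t i = 0; i < n; ++i) {
    arena.push_back(Node{false, "c" + std::to_string(i), nullptr, nullptr});
    const Node* value = &arena.back();
    arena.push_back(Node{true, "x" + std::to_string(i), value, body});
    body = &arena.back();
  }
  return body;
}

// Recursive traversal: stack depth grows with the length of the chain.
std::size_t CountLetsRecursive(const Node* expr) {
  if (expr == nullptr || !expr->is_let) return 0;
  return 1 + CountLetsRecursive(expr->body);
}

// Iterative traversal: walks the same chain in a loop, so stack depth stays constant.
std::size_t CountLetsIterative(const Node* expr) {
  std::size_t count = 0;
  while (expr != nullptr && expr->is_let) {
    ++count;
    expr = expr->body;
  }
  return count;
}
```

In a real pass each Let would also need its value visited and a result rebuilt on the way back; the sketch only shows why the chain itself does not need recursion.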

I also needed to convert FuseOps to mixed mode.

Thanks!

cc @jroesch @tqchen @tristan-arm @altanh @zhiics @masahi @kevinthesun

manupak commented Feb 1, 2021

Hi @mbrookhart ,

This looks great! As part of the code clean-up, would you consider factoring out the let-chain traversal, possibly as a utility? I see a pattern: the current handling of Let nodes has "pre" and "post" processing stages that run before and after the nodes are pushed onto the stack. So I was wondering whether it's possible to expose a generic interface that takes just the pre-processing functionality (going down the let chain) and the post-processing functionality (coming back up the let chain). I'm not sure what the actual implementation should look like, so do let me know what you have in mind :)
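One way such an interface could look, as a hedged sketch and not the code in this PR: a higher-order function that takes a pre-visit callback (called going down the chain) and a post-visit callback (called coming back up), with an explicit stack replacing recursion. `ToyLet` and `WalkLetChain` are hypothetical names, and the node type is a stand-in for Relay's Let node.

```cpp
#include <functional>
#include <stack>
#include <string>

// Toy Let node (illustrative only): `body` points at the next Let in the
// chain, or is nullptr when the chain ends.
struct ToyLet {
  std::string var;
  const ToyLet* body;
};

// Walk a let chain without recursion.
// pre_visit runs on each Let from outermost to innermost;
// post_visit runs on each Let from innermost back to outermost.
void WalkLetChain(const ToyLet* head,
                  const std::function<void(const ToyLet*)>& pre_visit,
                  const std::function<void(const ToyLet*)>& post_visit) {
  std::stack<const ToyLet*> pending;
  // Going down the chain.
  for (const ToyLet* cur = head; cur != nullptr; cur = cur->body) {
    pre_visit(cur);
    pending.push(cur);
  }
  // Coming back up the chain.
  while (!pending.empty()) {
    post_visit(pending.top());
    pending.pop();
  }
}
```

A pass would then supply only the two callbacks, and the traversal order would live in one place instead of being repeated in every pass.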

@mbrookhart

@manupa-arm Thanks for the suggestion. @altanh and I were talking about doing something similar late last week using higher order functions. I'll let you know what I can figure out.

@altanh altanh left a comment

overall LGTM except for a few nits, clean solution!

I stopped repeating my comment, but basically I'd prefer that all implicit uses of the captured this be made explicit; currently it's not consistent across the passes.
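To illustrate the nit with a toy class (the `Counter` type and its methods are hypothetical, not from the PR): both lambdas below capture `this` and behave identically, but the second spells out `this->`, which makes the dependence on the enclosing object visible at the call site.

```cpp
#include <functional>

// Toy class used only to contrast the two lambda styles.
class Counter {
 public:
  // Implicit style: Bump() silently resolves to this->Bump().
  std::function<void()> MakeImplicit() {
    return [this]() { Bump(); };
  }
  // Explicit style: the use of the captured `this` is written out.
  std::function<void()> MakeExplicit() {
    return [this]() { this->Bump(); };
  }
  int count() const { return count_; }

 private:
  void Bump() { ++count_; }
  int count_ = 0;
};
```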

Resolved review threads:
src/relay/transforms/de_duplicate.cc (outdated)
src/relay/transforms/fold_constant.cc (outdated)
src/relay/transforms/fuse_ops.cc
@mbrookhart mbrookhart changed the title [WIP][Relay][Passes] Iterative A-normal Traversals [Relay][Passes] Iterative A-normal Traversals Feb 1, 2021
auto post_visit = [this](const LetNode* op) {
  Expr expr = GetRef<Expr>(op);
  Var var = Downcast<Var>(this->VisitExpr(op->var));
  Expr value = this->VisitExpr(op->value);

Member

do we need to visit var and value again in the post?

Contributor Author

We just need to pull the values out of the cache. Instead of maintaining a cache shared by the two lambdas, I'm using the memoization cache in the Mutator. The second time visit is called, it short-circuits and returns the previously computed value.
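The short-circuit described above can be sketched with a toy memoizing visitor. This is an assumption-laden illustration, not TVM's Mutator: `ToyMutator`, its string-keyed cache, and the placeholder "rewrite" are all hypothetical. The point is only that a second visit of the same expression returns the cached result without redoing the work.

```cpp
#include <string>
#include <unordered_map>

// Toy memoizing visitor (illustrative only). Expressions are represented by
// strings and the "rewritten" result is just an int.
class ToyMutator {
 public:
  int VisitExpr(const std::string& expr) {
    auto it = memo_.find(expr);
    if (it != memo_.end()) {
      return it->second;  // already visited: return the cached result
    }
    ++rewrites_;  // count how often real work happens
    int result = static_cast<int>(expr.size());  // placeholder for a real rewrite
    memo_.emplace(expr, result);
    return result;
  }
  int rewrites() const { return rewrites_; }

 private:
  std::unordered_map<std::string, int> memo_;
  int rewrites_ = 0;
};
```

With this shape, a pre-visit callback can call `VisitExpr` to do the work, and the post-visit callback can call it again on the same expression simply to fetch the result, so the two callbacks need no cache of their own.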

Member

i see. thanks.

@manupak manupak left a comment

@mbrookhart exactly what I had in mind as well :) Thanks!

@tmoreau89 tmoreau89 merged commit 0ab9c95 into apache:main Feb 2, 2021
@tmoreau89

Thanks @mbrookhart @manupa-arm @zhiics @altanh the PR has been merged!

@mbrookhart mbrookhart deleted the iterative_a_normal_lets branch February 2, 2021 17:40
alexwong pushed a commit to alexwong/tvm that referenced this pull request Feb 11, 2021
* [WIP][Relay][Passes] non-recursive a-normal traversals

* fix clang warning

* Refactor ANormal Iterative traversal into a higher order function utility with lambdas

* refactor missed pass

* add explicit use of this to lambdas
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021
Lokiiiiii pushed a commit to Lokiiiiii/tvm that referenced this pull request Mar 2, 2021
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Mar 2, 2021