-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pass that removes reshapes post LowerTE #12215
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @ashutosh-arm. looking good!
I think we need unit tests for the pass as well.
(E.g. https://github.com/apache/tvm/blob/main/tests/python/relay/test_pass_partition_graph.py)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work @ashutosh-arm! Just some small things I picked up on..
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't have tons of context here, but left a couple suggestions
return WithFields(GetRef<Let>(let), var, value, body); | ||
} | ||
|
||
/*! * \brief Returns preceding CallLowered when call is a CallLowered(Reshape) */ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'm probably missing some context here, but what about just returning the args to reshape()?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Graph contains let nodes in between the call_lowered(). I've included the following piece as part of the Rewrite_() as well.
/*
%1 = call_lowered(@tvmgen_default_non_reshape_function, %input, ...);
let %x: = on_device(%1, ...);
%2 = (%x,);
%3 = call_lowered(@tvmgen_default_fused_reshape, %2, ...,
"relay_attrs"=__dict__="relay.reshape_only"=1, ...);
*/
Change-Id: Iaf5a5f44776080b0b842af4b563d596134508de1
Change-Id: I1f45ee3b15fbe290fdce69832a850d7d85ea1681
Change-Id: I81462a552f467d88cf1288acef2f9cbacc3ff532
Change-Id: I8502bc74eb0914cfcaa86cb809d7c4a9c6e86c70
ca579b2
to
389cadb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Thanks @ashutosh-arm @manupa-arm @areusch! |
Introduces a Pass for removing intermediate reshapes post LowerTE() in AOT compiler. This commit adds pass specific tests and updates usmp generated workspace pools due to reduction in number of allocations post reshape removals. Note: this pass at present does not support first reshape appearing in the graph. If seen as a useful case, it can be added in the future.
Introduces a Pass for removing intermediate reshapes post
LowerTE() in AOT compiler. This commit adds pass specific
tests and updates usmp generated workspace pools due to
reduction in number of allocations post reshape removals.
Note: this pass at present does not support first reshape
appearing in the graph. If seen as a useful case, it can be
added in the future.
cc: @manupa-arm @grant-arm