[Relay] Enforce static dim for non-concat axis if one or more tensors have static dim #7487
Conversation
(force-pushed from 56d53b2 to 047a0b0)
Nice addition. I think we might have similar scenarios in TF models as well. @trevor-m
LGTM
@mbrookhart Can you take a look as well?
My main concern with this, and please correct me if I'm wrong, is concatenating something like shape [1, 3] and shape [1, ?] on axis 0. In that case, I believe this will return [2, 3], and we don't do any runtime checks on the second argument. If the model is wrong and, say, the input comes in as [1, 2] instead of [1, 3], we'll end up with a runtime segfault in the generated code, which is very difficult to debug.
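A minimal sketch of this scenario (our own illustration, not code from the PR), using the `tvm.relay` Python API:

```python
import tvm
from tvm import relay

# Shapes from the concern above: (1, 3) is fully static, while (1, ?) has a
# dynamic dim on the non-concat axis (axis 1); we concatenate on axis 0.
x = relay.var("x", shape=(1, 3), dtype="float32")
y = relay.var("y", shape=(1, relay.Any()), dtype="float32")
z = relay.concatenate([x, y], axis=0)

mod = tvm.IRModule.from_expr(relay.Function([x, y], z))
mod = relay.transform.InferType()(mod)

# With this PR, the static dim pins the non-concat axis, so the inferred
# output type is Tensor[(2, 3), float32]. Nothing re-checks that y's second
# dim is really 3 at runtime; a (1, 2) input would corrupt the generated
# code's indexing, which is the segfault described above.
print(mod["main"].ret_type)
```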
@mbrookhart I'm not sure about this, but doesn't the concat shape function do the runtime check to prevent a segfault? I mean, for the shape func, the case where all dims are dynamic vs. mixed static/dynamic dims shouldn't make any difference in what the shape func does?
The shape function only runs if the output shape of the op is dynamic at compile time; see tvm/src/relay/transforms/memory_alloc.cc at 88a4fdd, lines 185 to 187, line 356, and lines 287 to 350.

If you have dynamically shaped inputs and a statically shaped output, the shape func won't run, so you won't be able to update/check your assumption on that dynamic input at runtime.
Thanks, that makes sense. I wish there were a way to decouple the output shape calculation and the input shape check from the shape func, so that we could add an input shape check for cases like this (dynamic input shape + static output shape).
(force-pushed from ae3dce3 to 2d99cd6)
(force-pushed from 2d99cd6 to 5d66703)
LGTM
Thanks @anijain2305 @mbrookhart
[Relay] Enforce static dim for non-concat axis if one or more tensors have static dim (apache#7487)
* enforce static dim for non-concat axis
* assign any when all dims are dyn
* add missing case
* simplify
* add test
* only enforce static dim constraint if concat output is dynamic
* more update to concat type rel
* update tests
* fixed compile warning
Currently, the concat type relation assigns output shape `Any()` to non-concat axes if there is even one `Any` in the corresponding input tensor shapes. But it is clear that for non-concat axes, if one or more input tensors have a static dim, the input static dims and the output dim must all be the same static value. For example, after this PR, the typing for the concat op with concat axis == 0 changes as sketched below:
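The original snippet is not preserved here; the following is a hedged reconstruction of the before/after typing, using hypothetical input shapes:

```
concat((Tensor[(1, 3)], Tensor[(1, ?)]), axis=0)
  before this PR: Tensor[(2, ?)]   # one Any on a non-concat axis made the output dim Any
  after this PR:  Tensor[(2, 3)]   # the static dim pins the non-concat axis
```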
This, together with #7479, removes all unnecessary `any_dim` from PyTorch MaskRCNN and significantly simplifies dynamic injective kernels. One dynamic injective op had become super complicated and slow due to too many `any_dim`, and it was the bottleneck in PyTorch MaskRCNN; after this PR it becomes reasonable and is no longer a bottleneck. Thanks to this optimization, PyTorch MaskRCNN runs 30 milliseconds (!!) faster now.
please review @kevinthesun @anijain2305 @mbrookhart @jwfromm