Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Add additional termination condition to For node to enable While loop like feature #7385

Closed
wants to merge 16 commits into from

Conversation

masahi
Copy link
Member

@masahi masahi commented Feb 1, 2021

This is my proposed solution to add While loop like feature to TIR, in the simplest, the least invasive way. It generalizes the For node termination condition from

loop_var < extent

to

loop_var < extent && test

Using this, we can write binary search as follows (see the complete test case, which implements numpy searchsorted function, here).

lo[0] = 0
hi[0] = n
v = Bptr[i]
num_loop = int(np.log2(n)) + 1

with ib.for_range(0, num_loop, test=(lo[0] < hi[0])) as _:
    mid = lo[0] + tvm.tir.floordiv(hi[0] - lo[0], 2).astype("int32")
    with ib.if_scope(Aptr[mid] < v):
        lo[0] = mid + 1
    with ib.else_scope():
        hi[0] = mid

Cptr[i] = lo[0]

My motivation was to improve GPU NMS performance using while loop, and it indeed did:

NMS workload from PyTorch MaskRCNN:

Without while loop (current main): 4.11 milli sec
With while loop (my branch): 3.66 milli sec

And a crazy 120000 box + 100 max_out_size NMS workload from TF MaskRCNN. The difference is huge because the # of iterations changed from 120000 to 100 (roughly)

Without while loop (current main): 51.31 milli sec
With while loop (my branch): 17.63 milli sec

please review @tqchen @mbrookhart @kevinthesun @zhiics @Laurawly @anijain2305 @trevor-m

Copy link
Contributor

@anijain2305 anijain2305 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! This is awesome stuff.

Copy link
Contributor

@mbrookhart mbrookhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Longer term we probably want to add this to hybrid script and/or tvm script

@tqchen tqchen added the status: need RFC need RFC discussion label Feb 1, 2021
@@ -802,6 +802,9 @@ class ForNode : public StmtNode {
ForKind kind;
/*! \brief The body of the for loop. */
Stmt body;
/*! \brief The additional termination condition of the for loop. */
Optional<PrimExpr> test;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to have a RFC discussion, since different strategies changes to the IR can have different implications

@tqchen
Copy link
Member

tqchen commented Feb 1, 2021

Thanks @masahi . I think it is great to enable support for some form of While loop.

It would be great to have an RFC thread discussing the alternatives in the IR node design. Since the IR node design can impact the general ability to do analysis and will impact how would we engineer future transformation passes.

For example, I can see two possible variants:

  • V0: Put the condition into the for loop as it is (the current approach)
  • V1: Introduce a separate While node for while loops

V0 means the for loop is somewhat overloaded for both while and For. On one hand it brings the benefit of richer semantics and the minimum set of changes to enable such feature.

This does mean that the visitors to For would need to handle the semantics of break. Given that the for loop is used for regular interval analysis, it could be beneficial to distinguish between "regular structured loop" vs "un-structured loop", unless the condition testing also offers some analysis benefits.

I think this PoC is a great start and would be great to have a design discussion over the RFC

@masahi
Copy link
Member Author

masahi commented Feb 1, 2021

@tqchen Yes, what you said totally makes sense to me. As you said, this is a minimal-change solution, but I think ideally we want a separate While node. Since while loop only makes sense for sequential loop (I think), I think it is better to decouple a simpler While node that doesn't need any analysis from heavy-duty For node.

I'll send a RFC, sure.

@masahi
Copy link
Member Author

masahi commented Feb 2, 2021

@masahi masahi mentioned this pull request Feb 9, 2021
@masahi
Copy link
Member Author

masahi commented Feb 9, 2021

TIR While node added in #7425

@masahi masahi closed this Feb 9, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
status: need RFC need RFC discussion
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants