[TOPI] TE implementation of LSTM using scan #11531
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR adds a TE implementation of LSTM (with optional modifications, similar to those in https://github.com/apache/tvm/blob/main/python/tvm/relay/frontend/common.py#L774), using the
te.scan
construct (so that the recurrent loop is truly a sequential loop, rather than unrolled statically). This compute should support symbolic sequence length.Missing from this PR:
I'll send a follow-up PR for the Relay op, but scheduling the LSTM might take a while (if anyone is interested, please feel free to take a stab!). The main thing to optimize is the dense operations within the kernel (the initial input-hidden dense, recurrent hidden-hidden dense, and hidden-projection dense). I couldn't figure out a great way to use existing schedules here...
Things I am hoping to try:
Regarding metascheduling: the current
CreatePrimFunc
conversion from TE -> S-TIR doesn't support scan operations. I have a hack that makes this conversion work, but am hitting some snags regarding schedule rules, primitives, and post procs (the outer scan axis seems to break a lot of assumptions). I can try to clean up this conversion if that's valuable, but also am curious if anyone is interested in tackling this by adjusting the constraints on blocks to support outer scan axis.cc @vinx13 @junrushao1994 @tkonolige @michalpiszczek @masahi
Additional thanks to @vinx13 and @zxybazh for helping debug metaschedule issues (I hope this PR helps as a concrete starting point for getting things working), maybe you guys can cc others who may be interested? And thanks @junrushao1994 for the very helpful LSTM example from ~5 (!) years ago https://github.com/apache/tvm/blob/main/apps/topi_recipe/rnn/lstm.py which I used as a starting point.