[Torch] Add initial control flow support #4964
Conversation
@masahi Thank you very much for the nice work! I am familiar with TF control flow, and am reading up on PyTorch's control-flow constructs. I will do a careful review by tomorrow.
Some small comments, also reading up on control flow in Torch.
Can you not call it parse? Parsing is for converting strings to an AST; converting one AST to another AST is a converter. Nice work otherwise.
Tests passed!!
No problem, will do. @alexwong @zhiics A good way to get familiar with TorchScript is to try and tweak simple examples. For example:

```python
import torch

class SimpleLoop(torch.nn.Module):
    def forward(self, inp):
        a = inp
        for i in range(10):
            a += i
        return a

class SimpleWhileLoop(torch.nn.Module):
    def forward(self, inp):
        a = inp
        i = 0
        while i < 10:
            a += i
            i += 1
        return a

print(torch.jit.script(SimpleLoop()).graph)
print(torch.jit.script(SimpleWhileLoop()).graph)
```

would print the TorchScript graphs. Hopefully it is much simpler than TensorFlow control flow.
One thing to keep in mind is that TorchScript has mutation, continue, while... So the design should be prepared to generate code in A-Normal Form for correctness. Right now it is not needed as the converted programs are all purely functional, but the architecture should make such a change as simple as possible (e.g. not requiring a rewrite of all the code).
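To make the A-Normal Form point concrete, here is a minimal sketch in plain Python (not TVM code) of what a functional translation of the mutating `SimpleWhileLoop` above would look like: every intermediate value gets its own name, and loop state is threaded through explicitly instead of being mutated in place.

```python
# Hypothetical illustration: the TorchScript loop
#   a = inp; i = 0
#   while i < 10: a += i; i += 1
# mutates `a` and `i`. An ANF-style functional translation instead
# threads the state through freshly named variables.

def while_loop_functional(inp):
    def loop(a, i):
        if i < 10:
            a_next = a + i   # every intermediate is bound to its own name
            i_next = i + 1   # nothing is mutated in place
            return loop(a_next, i_next)
        return a
    return loop(inp, 0)

print(while_loop_functional(5))  # 5 + sum(range(10)) = 50
```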
@MarisaKirisame I have something relevant to share. I was told by one of the TorchScript devs on their forum that he is working on a "functionalization" pass pytorch/pytorch#33020, and that could help my use case. I'm not sure what exactly it does, but from looking at their code it seems it tries to find a "functional" subset of nodes, i.e. nodes with no impure operation, and extract it as a subgraph. I guess it helps apply some of their optimizations more aggressively. He also has other related PRs ongoing that aim at making the graph "more pure": one on removing in-place ops pytorch/pytorch#33186, another on removing list append pytorch/pytorch#33199. I literally encountered these impure ops when I was trying to convert more realistic LSTM models than the one in this PR, so these new features in Torch could be useful to us later.
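As a rough illustration of what "removing in-place ops" means, here is a hypothetical sketch in plain Python (not the real TorchScript IR or pass): each mutating op is rewritten into its pure counterpart, and later uses of the mutated value are renamed to the fresh result.

```python
# Sketch of the idea behind an in-place-op removal pass. Nodes are modeled
# as (op, dst, args) tuples; this toy representation is an assumption made
# for illustration, not PyTorch's actual data structure.
def functionalize(nodes):
    renamed = {}  # maps a mutated variable to its latest pure replacement
    out = []
    for op, dst, args in nodes:
        args = tuple(renamed.get(a, a) for a in args)
        if op.endswith("_"):  # Torch marks in-place ops with a trailing "_"
            fresh = dst + ".v" + str(len(out))
            out.append((op[:-1], fresh, args))  # pure variant, fresh result
            renamed[dst] = fresh                # later uses see the new value
        else:
            out.append((op, dst, args))
    return out

# x.add_(one) followed by y = x * two becomes a pure add plus a renamed use.
pure = functionalize([("add_", "x", ("x", "one")), ("mul", "y", ("x", "two"))])
print(pure)
```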
... WTF
@zhiics @MarisaKirisame @alexwong Comments have been addressed and I have no plans to update this PR further. Can we merge this?
LGTM
@MarisaKirisame could you take a look and approve explicitly if it looks good to you as well?
Thanks @masahi @MarisaKirisame @alexwong
* Add support for prim::If and prim::Loop with test cases
* rebase and fix tests
* add some comments
* simplifying, fix float cast
* parse -> convert
* recursively retrieve ops in get_all_op_names
* use multiple return values from block correctly, simplify loop convert
* choose dtype properly for zeros and ones
* simplifying, replace convert_inputs with _get_relay_input_vars
* fix for while loop with non input dependent init cond
* add assert on loop var update
* move the condition around
* better testing for seg models
* rebase fix, disable inception v3 in quant test as it is too slow to load with torch-1.4 + torchvision 0.5
* simplify and add more comparison op converter
This adds support for converting the TorchScript `prim::If` and `prim::Loop` nodes. It is also the first attempt at translating the output of `torch.jit.script(...)`. See the test cases for the currently supported Python constructs.

`prim::If` can be easily translated by recursively converting the true and false branches. The spec: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#if

`prim::Loop` requires more work, but using a Relay conditional and tail recursion it is not too hard: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops

The related discussion (with an example IR dump): https://discuss.tvm.ai/t/discuss-adding-a-pytorch-frontend/5026/24
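The "conditional plus tail recursion" approach mentioned above can be sketched in plain Python (this is an illustration of the translation scheme, not actual Relay API calls): a `prim::Loop` carrying loop variables becomes a recursive function whose body is a conditional, since Relay has no built-in loop primitive.

```python
# Hypothetical sketch: `cond` decides whether to keep iterating and `body`
# maps the current loop variables to their updated values, mirroring the
# condition block and body block of a Torch prim::Loop node.
def convert_loop(cond, body, init_vars):
    def loop_fn(*loop_vars):
        if cond(*loop_vars):
            return loop_fn(*body(*loop_vars))  # tail call = next iteration
        return loop_vars                       # loop outputs
    return loop_fn(*init_vars)

# Mirrors SimpleWhileLoop earlier in the thread: a = inp; i = 0;
# while i < 10: a += i; i += 1
result = convert_loop(
    cond=lambda a, i: i < 10,
    body=lambda a, i: (a + i, i + 1),
    init_vars=(3, 0),
)
print(result[0])  # 3 + sum(range(10)) = 48
```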
The CI is blocked by an unrelated Sphinx issue, but the PR is ready for review.
cc @zhiics @icemelon9 @wweic @jroesch @MarisaKirisame @alexwong @tqchen @junrushao1994 @ajtulloch @yinghai