Functionalization integration #4158
Conversation
@bdhirsh the build is failing, is that expected?
Force-pushed from 348ba86 to 6b694e0
Whoops, I think this is now fixed, but I'll confirm.
Force-pushed from 6b694e0 to 28a6a79
Force-pushed from 68ba20b to 8c460ab
Okay, I fixed some build failures I had after rebasing the upstream PyTorch PR and this PR. The CI failure shown here is not up to date; it is being shown because the pinned upstream PR is out of date. Now I can see that a bunch of PyTorch's …
This test failure is caused at line https://github.com/pytorch/pytorch/blob/master/test/test_torch.py#L3007. After adding some debugging statements, I can see that the parameters (…)
The line (…)
Force-pushed from 39583b3 to 179ac8b
Opened a separate PR on the PyTorch side, since I didn't have access to update/rebase the original one: pytorch/pytorch#88787. With the latest rebase, the new sym int copy ops (79cbdb6) caused some failures that were addressed by updating …
Force-pushed from 07d072c to 022a67f
Force-pushed from 4e3c508 to 4bcbbad
Force-pushed from 4bcbbad to fa96188
Was able to fix some of the errors seen before, but now I'm also seeing an odd failure for a very simple test.
According to the error message, it seems like the op (…). Meanwhile, I'll get started on a draft PR to remove the view/aliasing infrastructure in XLA.
Force-pushed from 61a6426 to 9e5f141
Force-pushed from a5684ce to dc3ed92
@wonjoolee95 Can you use …
Force-pushed from ae482b4 to 5009774
Seems like the upstream PR might need a rebase. @alanwaketan, do we need the commit …
My bad, yeah, the upstream PR will need a rebase as well.
This reverts commit 2ded7ca.
Summary: test_exponential should be fixed by upstream: pytorch/pytorch#93053.
Test Plan: CI.
…k_and_unscale_ op tests to python (#4687)
* Move nan_to_num_ and _amp_foreach_non_finite_check_and_unscale_ op tests to python
* Update test call to runAtenTest
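For context, a python-side check of this kind boils down to comparing the in-place op's result on CPU and on XLA. The sketch below is only an illustration under that assumption; it calls stock torch/torch_xla APIs directly rather than the repo's runAtenTest helper.

```python
# Hedged sketch: compare nan_to_num_ on CPU vs. XLA (not the repo's actual test).
import torch
import torch_xla.core.xla_model as xm

cpu = torch.tensor([float("nan"), float("inf"), -float("inf"), 1.0])
xla = cpu.to(xm.xla_device())

cpu.nan_to_num_()
xla.nan_to_num_()

# The XLA result, pulled back to CPU, should match the eager CPU result.
assert torch.allclose(cpu, xla.cpu())
```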
…ionalize pass (#4681)
Summary: For any CompositeExplicitAutograd ops, we are supposed to explicitly re-enable functionalization so that any decomposed ops within them get functionalized as well. However, if we directly call into at::functionalization::functionalize_aten_op, convolution_backward somehow omits convolution_backward_overridable, which is our own kernel for calculating the convolution. Thus, no grads are produced. To work around the issue, we manually redispatch convolution_backward to the functionalize pass.
Test Plan: PJRT_DEVICE=TPU python test/test_operations.py -v -k test_conv2d_backward
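A quick way to see the symptom the summary describes is that conv2d gradients simply never materialize on XLA. The snippet below is a minimal, hedged reproduction along the lines of the test plan above, assuming a standard torch_xla install; it is not the actual test_conv2d_backward from the repo.

```python
# Hedged sketch: verify conv2d backward produces gradients on an XLA device.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
conv = nn.Conv2d(3, 8, kernel_size=3).to(device)
x = torch.randn(2, 3, 16, 16, device=device, requires_grad=True)

out = conv(x)
out.sum().backward()
xm.mark_step()  # force execution of the lazily traced graph

# If convolution_backward were dropped on the functionalization path,
# these gradients would never be populated.
assert x.grad is not None
assert conv.weight.grad is not None
```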
Summary: This pull request enables FSDP by replacing .set_ with our own _replace_xla_tensor API. The reason is that the functionalization pass reapplies the new value to all of the tensor's aliases, since .set_ is an in-place op. However, that reapplication assumes the source and the destination share the same number of elements (view_copy), and .set_ doesn't follow this rule. P.S. It also removes two .data tests that are no longer applicable.
Test Plan: CI.
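To illustrate why .set_ is special among in-place ops, here is a small CPU-only sketch using stock PyTorch: .set_ rebinds the destination's storage and can change its element count, which breaks the same-size assumption that the functionalization replay relies on. The _replace_xla_tensor name comes from this PR's summary; nothing else here is taken from the repo.

```python
# Hedged sketch: .set_ can change the destination's element count,
# unlike an ordinary in-place update such as add_ or copy_.
import torch

dst = torch.zeros(4)
src = torch.ones(8)   # deliberately a different number of elements

dst.set_(src)         # rebinds dst's storage, size, and strides to src's
assert dst.numel() == 8

# An ordinary in-place op would instead require matching/broadcastable shapes:
# torch.zeros(4).copy_(torch.ones(8))  # -> RuntimeError
```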
Force-pushed from cf44d6f to f82f718
I'm preparing to land the branch, finally. @yeounoh Please review the SPMD part and make sure it's correct. @vanbasten23 Please review the DS part and make sure it's correct.
Thanks and great work @alanwaketan @wonjoolee95 @vanbasten23 @bdhirsh !
Thanks Jack and Xiongfei for reviewing.
Summary:
This is the consolidated branch of all the functionalization changes. For individual changes, please visit the corresponding PRs for details.
Here is a brief summary of the highlights: functionalization is a dispatcher pass introduced upstream to remove views and mutations from a PyTorch program and produce functional graphs, which are better for backend compilers to optimize. We have an in-house infrastructure that does similar tricks, since the HLOs we generate are always functional. The benefit of adopting the upstream functionalization pass is that we can get rid of the in-house view infrastructure we have struggled to maintain for years, and let upstream do all of that heavy lifting for us.
Implementation details:
1. To enable functionalization, we just need to wrap each newly created at::Tensor that holds an XLATensor in a FunctionalWrapper and return the wrapper to Python. Any subsequent ops then go through the functionalization pass before reaching us, and correspondingly we have to unwrap the FunctionalWrapper before getting at the XLATensor. In short, a thin layer called FunctionalWrapper is added to the aten-xla bridge: FunctionalWrapper <=> at::Tensor <=> XLATensor.
2. To support the new view ops, for each view op we have to implement at most two variants: view_copy, the view op that returns a copy instead; and view_scatter, for the cases where we need extra logic to reapply the updated value to the base view.
3. For in-place ops, we have a new op called _propagate_xla_data to keep any in-place op optimizations we had before active.
Design Doc
Test Plan:
CI
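As a rough illustration of the user-visible contract described above (a sketch assuming a standard torch_xla install, not code from this PR): views and in-place mutations keep their eager semantics, while functionalization rewrites them into copy/scatter-style ops in the traced graph.

```python
# Hedged sketch: views and in-place mutations still behave as in eager PyTorch;
# functionalization turns them into functional (view_copy / scatter) ops
# in the graph that is handed to the XLA backend.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
base = torch.zeros(4, 4, device=device)

row = base[0]     # a view op; lowered as a *_copy op in the functional graph
row.add_(1)       # an in-place mutation; propagated back to `base` by the pass

xm.mark_step()    # materialize the traced graph

# The mutation through the view must still be visible on the base tensor.
assert bool((base[0] == 1).all())
```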