Functionalization integration #4158

Merged · 130 commits · Mar 2, 2023
Conversation

@bdhirsh (Collaborator) commented on Nov 4, 2022

Summary:
This is the consolidated branch of all the functionalization changes. For individual changes, please visit the corresponding PRs for details.

Here is a brief summary of the highlights: Functionalization is a dispatcher pass introduced upstream that removes views and mutations from a PyTorch program to produce functional graphs, which are easier for backend compilers to optimize. We have in-house infrastructure that does similar tricks, since the HLOs we generate are always functional. The benefit of adopting the upstream Functionalization pass is that we can get rid of the in-house view infrastructure we have struggled to maintain for years and let upstream do the heavy lifting for us.
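
As a rough illustration of what the pass does, here is a minimal sketch that uses the upstream torch.func.functionalize and make_fx APIs on a plain CPU tensor; this is not the XLA integration itself, just the rewrite it builds on.

import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(a):
    b = a + 1
    c = b.view(-1)  # create a view of an intermediate
    c.add_(1)       # mutate the intermediate through the view
    return b

# The functionalized trace contains no view or in-place ops; the update is
# re-expressed with purely out-of-place ops.
traced = make_fx(functionalize(f))(torch.zeros(2, 2))
print(traced.code)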

Implementation details:

  1. To enable Functionalization, we just need to wrap each newly created at::Tensor that holds an XLATensor in a FunctionalWrapper and return the wrapper to Python. Any subsequent ops then go through the Functionalization pass before reaching us. Correspondingly, we have to unwrap the FunctionalWrapper before getting at the XLATensor. In short, a thin FunctionalWrapper layer is added to the aten-xla bridge: FunctionalWrapper <=> at::Tensor <=> XLATensor.
  2. To support the new view ops, we have to implement at most two variants for each view op: view_copy, which performs the view but returns a copy, and view_scatter, for cases where extra logic is needed to reapply an updated value to the base tensor (see the sketch after this list).
  3. For in-place ops, we add a new op called _propagate_xla_data to keep the in-place-op optimizations we had before.
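
To make the view_copy / view_scatter pattern from item 2 concrete, here is a minimal sketch, again using the upstream torch.func.functionalize API rather than the XLA bridge: with remove="mutations_and_views", a mutation made through a view is rewritten into a *_copy view plus a scatter-style op that writes the update back into the base tensor.

import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(a):
    base = a.clone()
    row = base[0]   # select() view into base
    row.add_(1)     # in-place update through the view
    return base

# The trace should contain select_copy (the view as a copy) and
# select_scatter (reapplying the updated row to the base) instead of
# aliasing view ops.
traced = make_fx(functionalize(f, remove="mutations_and_views"))(torch.zeros(2, 3))
print(traced.code)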

Design Doc

Test Plan:
CI

@JackCaoG (Collaborator) commented on Nov 4, 2022

@bdhirsh build is failing, is that expected?

@bdhirsh (Collaborator, Author) commented on Nov 4, 2022

whoops - I think this is now fixed, but I'll confirm

@wonjoolee95 marked this pull request as a draft on November 8, 2022
@wonjoolee95 (Collaborator) commented on Nov 9, 2022

Okay, fixed some build failures I had after rebasing the upstream PyTorch PR and this PR. The CI failure shown here is not up to date; it appears because the pinned upstream PR is out of date.

Now I can see a bunch of PyTorch's index_reduce tests failing. A specific example:

======================================================================
ERROR: test_index_reduce_reduce_mean_xla_bfloat16 (__main__.TestTorchDeviceTypeXLA)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 378, in instantiated_test
    result = test(self, **param_kwargs)
  File "test/test_torch.py", line 3011, in test_index_reduce
    dest.index_reduce_(dim, idx, src, reduce, include_self=include_self)
IndexError: select(): index 0 out of range for tensor of size [3, 4, 0] at dimension 2

----------------------------------------------------------------------
Ran 1 test in 13.533s

This test failure is triggered at https://github.com/pytorch/pytorch/blob/master/test/test_torch.py#L3007. After adding some debugging statements, I can see that the parameters (dim, idx, src) that cause the failure are:

dim=2
idx=tensor([], device='xla:0', dtype=torch.int64)
src=tensor([], device='xla:0', size=(3, 4, 0), dtype=torch.bfloat16)

The line dest.index_reduce_(dim, idx, src, reduce, include_self=include_self) doesn't seem to trigger XLA's select ops either. @bdhirsh, any quick ideas why this line might give such an error?
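
For reference, a standalone sketch of the failing call with those parameters; dest's exact shape is an assumption here (only dim, idx, and src were printed above), and the snippet just mirrors what the upstream test does on an XLA device.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
dest = torch.ones(3, 4, 5, device=device, dtype=torch.bfloat16)  # shape assumed
idx = torch.tensor([], device=device, dtype=torch.int64)
src = torch.empty(3, 4, 0, device=device, dtype=torch.bfloat16)
# With this PR, this call fails with the IndexError shown above.
dest.index_reduce_(2, idx, src, "mean", include_self=True)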

@wonjoolee95 (Collaborator) commented:

Opened a separate PR on the PyTorch side, since I didn't have access to update/rebase the original one -- pytorch/pytorch#88787.

With the latest rebase, the new sym int copy ops (79cbdb6) caused some failures, which were addressed by updating at::native::narrow_copy_dense_symint to at::functionalization::functionalize_aten_op_symint<ATEN_OP(narrow_copy)>. Need to double-check whether the dense variant can be safely dropped.

@wonjoolee95 force-pushed the functionalization branch 2 times, most recently from 07d072c to 022a67f on November 10, 2022

@wonjoolee95 force-pushed the functionalization branch 3 times, most recently from 4e3c508 to 4bcbbad on November 11, 2022
@wonjoolee95 (Collaborator) commented:

I was able to fix some of the errors seen before, but now I'm also seeing an odd failure for a very simple test, test_bool_tensor_value_change, at https://github.com/pytorch/pytorch/blob/master/test/test_torch.py#L2760:

test/test_torch.py:2764: UserWarning: 0The operator aten::as_strided appears to be a view operator, but it has no implementation for the backend "xla:0". View operators don't support falling back to run on the CPU, since the tensor's storage cannot be shared across devices. (Triggered internally at /workspace/pytorch/aten/src/ATen/native/CPUFallback.cpp:188.)
  x[0] = False
E
======================================================================
ERROR: test_bool_tensor_value_change_xla (__main__.TestTorchDeviceTypeXLA)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 391, in instantiated_test
    raise rte
  File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 378, in instantiated_test
    result = test(self, **param_kwargs)
  File "test/test_torch.py", line 2764, in test_bool_tensor_value_change
    x[0] = False
RuntimeError: self__storage_saved.value().is_alias_of(result.storage()) INTERNAL ASSERT FAILED at "/workspace/pytorch/torch/csrc/autograd/generated/VariableType_3.cpp":11489, please report a bug to PyTorch. 

----------------------------------------------------------------------
Ran 1 test in 0.081s

FAILED (errors=1)

According to the error message, it seems the as_strided op doesn't get correctly dispatched to XLA, so it falls back to CPU and fails with the error above. If I test without the changes in this PR, I can confirm that this op gets dispatched to XLA. I'm syncing offline with @bdhirsh to fix this.

Meanwhile, I'll get started with a draft PR to remove view/aliasing infrastructure in XLA.

@wonjoolee95 force-pushed the functionalization branch 4 times, most recently from 61a6426 to 9e5f141 on December 6, 2022
@alanwaketan (Collaborator) commented on Dec 16, 2022

@wonjoolee95 Can you use TF_VLOG instead of std::cout for your own logging, so that we can optionally turn it on or off via environment variables?

@wonjoolee95 (Collaborator) commented:

Seems like the upstream PR might need a rebase. @alanwaketan, do we need the commit [Revert "Use C10_AS_INTARRAYREF_SLOW]?

@alanwaketan (Collaborator) commented:

My bad, yea, the upstream PR will need a rebase as well.

alanwaketan and others added 13 commits March 1, 2023 23:43
This reverts commit 2ded7ca.
Summary:
test_exponential should be fixed by upstream: pytorch/pytorch#93053.

Test Plan:
CI.

* implement has_hint

* fix linter

* reenable DS tests.
…k_and_unscale_ op tests to python (#4687)

* Move nan_to_num_ and _amp_foreach_non_finite_check_and_unscale_ op tests to python

* Update test call to runAtenTest
…ionalize pass (#4681)

Summary:
For any CompositeExplicitAutograd ops, we are supposed to explicitly re-enable functionalization such that any decomposed ops within those ops get functionalized as well.

However, if we call directly into at::functionalization::functionalize_aten_op, convolution_backward will somehow omit convolution_backward_overridable, which is our own kernel for calculating the convolution. Thus, no grads are produced.

To work around the issue, we manually redispatch convolution_backward to the functionalize pass.

Test Plan:
PJRT_DEVICE=TPU python test/test_operations.py -v -k test_conv2d_backward
Summary:
This pull request enables FSDP by replacing .set_ with our own _replace_xla_tensor API. The reason is that the Functionalization pass reapplies the new value to all of the tensor's aliases, since .set_ is an in-place op. However, that reapplication assumes the source and the destination share the same number of elements (view_copy), and .set_ doesn't follow this rule (see the illustration below).

P.S. It also removes two .data tests that are no longer applicable.

Test Plan:
CI.
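
To illustrate the mismatch with a plain CPU sketch (not the XLA code path): an ordinary in-place op keeps the element count fixed, while .set_ can swap in storage of a different size, which is what breaks the alias-reapplication assumption.

import torch

a = torch.zeros(3)
b = torch.ones(5)

a.add_(torch.ones(3))  # ordinary in-place op: source and destination sizes match
a.set_(b)              # .set_: `a` now shares b's storage and holds 5 elements
print(a.shape)         # torch.Size([5])
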
@alanwaketan marked this pull request as ready for review on March 1, 2023
@alanwaketan (Collaborator) commented:

I'm finally preparing to land the branch. @yeounoh, please review the SPMD part and make sure it's correct. @vanbasten23, please review the DS part and make sure it's correct.

@alanwaketan changed the title from "POC of functionalization integration" to "Functionalization integration" on Mar 2, 2023
@JackCaoG (Collaborator) left a comment.

@alanwaketan (Collaborator) commented:

Thanks Jack and Xiongfei for reviewing.

@alanwaketan merged commit e49df83 into master on Mar 2, 2023
mateuszlewko pushed a commit that referenced this pull request Mar 15, 2023