
Fixes for PyTorch/XLA functionalization integration #88787

Closed
wants to merge 11 commits

Conversation

wonjoolee95
Collaborator

Picking up #88506

@pytorch-bot

pytorch-bot bot commented Nov 10, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88787

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures

As of commit 1f72367:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@wonjoolee95 force-pushed the functionalization branch 2 times, most recently from 612fc0a to df3bf42 on November 18, 2022 at 19:26
@wonjoolee95 force-pushed the functionalization branch 2 times, most recently from cfc2922 to 86f8b11 on December 6, 2022 at 23:30
@alanwaketan
Collaborator

@wonjoolee95 You probably want to hide your debug hints while running the CI here; they make the log too large to parse in the browser. Also, a rebase would be appreciated.

@alanwaketan
Collaborator

I guess it's better to log that information with proper log-level control, so that you can still use it locally for debugging without dramatically increasing the test log size.
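
A minimal sketch of the kind of log-level gating being suggested, using Python's standard logging module; the logger name and the XLA_DEBUG_LOG_LEVEL environment variable below are hypothetical, not part of this PR:

import logging
import os

# Debug hints stay available locally but are silenced by default in CI.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("torch_xla.functionalization")
logger.setLevel(os.environ.get("XLA_DEBUG_LOG_LEVEL", "WARNING"))

def debug_hint(msg):
    # Emitted only when a developer opts in, e.g. XLA_DEBUG_LOG_LEVEL=DEBUG.
    logger.debug(msg)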

@wonjoolee95 force-pushed the functionalization branch 2 times, most recently from bd0a882 to 2197973 on December 14, 2022 at 09:48
@wonjoolee95
Collaborator Author

Hmm, it seems like some functorch tests are still failing even after applying the diff generated by EXPECTTEST_ACCEPT=1. I'll try it one more time; if they're still failing, I'll fix the rest manually.
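
For reference, a sketch of that regeneration step, assuming the standard expecttest workflow where EXPECTTEST_ACCEPT=1 rewrites the expected outputs in place (equivalent to running EXPECTTEST_ACCEPT=1 python test/functorch/test_aotdispatch.py TestAOTAutograd from a shell):

import os
import subprocess

# Re-run the failing suite with EXPECTTEST_ACCEPT=1 so expecttest rewrites
# the expected outputs in the test file; review the resulting diff afterwards.
env = dict(os.environ, EXPECTTEST_ACCEPT="1")
subprocess.run(
    ["python", "test/functorch/test_aotdispatch.py", "TestAOTAutograd"],
    env=env,
    check=True,
)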

@alanwaketan
Collaborator

It looks like those machines have GPU devices. Do you have a GPU environment? I guess you need to run those tests in a GPU environment; otherwise, they will be skipped.

@wonjoolee95
Collaborator Author

Please correct me if I'm wrong, but it seems like only the first two failing tests need GPU devices? I'll wait for the rest of the CI to complete and then fix all the non-GPU tests first.

@wonjoolee95
Collaborator Author

With the latest commit, all the TestAOTAutograd tests should succeed:

(base) jenkins@26d7adccbc26:/workspace/pytorch$ python test/functorch/test_aotdispatch.py TestAOTAutograd
/opt/conda/lib/python3.7/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libc10_cuda.so: cannot open shared object file: No such file or directory
  warn(f"Failed to load image Python extension: {e}")
ss2022-12-15 00:26:27.721580: W 2153590 tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-12-15 00:26:27.721700: W 2153590 tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
./opt/conda/lib/python3.7/site-packages/torch/_functorch/aot_autograd.py:919: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  "Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. "
...test/functorch/test_aotdispatch.py:241: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at /workspace/pytorch/build/aten/src/ATen/core/TensorBody.h:485.)
  grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]]
...................................s...../opt/conda/lib/python3.7/site-packages/torch/_functorch/aot_autograd.py:919: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  "Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. "
..../opt/conda/lib/python3.7/site-packages/torch/_functorch/aot_autograd.py:919: UserWarning: Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  "Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. "
...
----------------------------------------------------------------------
Ran 54 tests in 3.871s

OK (skipped=3)
(base) jenkins@26d7adccbc26:/workspace/pytorch$ 

I'm still unsure about the previous GPU job, which failed with:

/var/lib/jenkins/multipy/multipy/runtime/../../multipy/runtime/interpreter/builtin_registry.h:31:10: fatal error: gtest/gtest_prod.h: No such file or directory

But I'll let the CI run one more time before spending more time on the GPU test.

@alanwaketan
Collaborator

That failure is very likely unrelated.

@alanwaketan
Collaborator

According to hud, the deploy failure shouldn't be related.

@@ -347,7 +347,6 @@ def emit_view_functionalization_body(
}}
);
auto compute_reference_meta =
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
Collaborator

@bdhirsh Is this guard safe to remove? With it, I crashed on xla symbolic expand:

root@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=CPU python test/test_dynamic_shapes.py -v TestDynamicShapes.test_simple_expand
test_simple_expand (__main__.TestDynamicShapes) ... ERROR

======================================================================
ERROR: test_simple_expand (__main__.TestDynamicShapes)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_dynamic_shapes.py", line 18, in test_simple_expand
    t5.expand(t2.size(0))
RuntimeError: /workspaces/work/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp:2109: SymIntArrayRef expected to contain only concrete integers

----------------------------------------------------------------------
Ran 1 test in 0.022s

FAILED (errors=1)

Collaborator

Okay, it regresses. It's not safe.

Contributor

@alanwaketan so, the idea motivating this bit of code is the following:

  • pytorch/XLA doesn't care about strides, so when comparing a PyTorch program run on CUDA vs. XLA, the user will see different strides on their tensors throughout the program
  • functionalization gives XLA the ability to fix that problem; XLA can keep ignoring stride values in all of its kernels, while functionalization runs the meta function for every ATen op to set the strides properly (see the short sketch right after this list)
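
As a rough illustration of the stride divergence described above (plain eager PyTorch; the XLA-side behavior is paraphrased in the comments, not reproduced):

import torch

x = torch.randn(2, 3)
y = x.t()                   # a view; eager CPU/CUDA reports strides (1, 3)
print(y.shape, y.stride())  # torch.Size([3, 2]) (1, 3)

# A backend that ignores strides would report contiguous strides (2, 1) for
# the same tensor, so the user sees different .stride() values than on CUDA.
# Running the ATen meta function during functionalization lets such a backend
# stamp the reference strides onto its outputs and close that gap.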

One question here is: do you think that's a benefit worth trying to capture for pytorch/XLA (stride correctness, from the user's perspective, for XLA tensors)? I'd be interested in @JackCaoG's opinion.

Our options are either to:
(1) kill that code, and not bother trying to get correct strides
(2) make it more robust so it works on this test

In this test, it looks like you're using dynamic shapes, and the meta function we're calling doesn't play well with dynamic shapes. The way the dynamic-shapes workstream in core has been handling this is that we have Python implementations / decompositions of a bunch of our ops that we want to run when dynamic shapes are enabled. And it looks like, for some reason, we aren't calling that Python impl and are instead calling the C++ one?

There's probably a better way to arrange for this to work with XLA, but one option is to enable the Python dispatcher in your test, which should override a bunch of C++ meta kernels with their Python equivalents:

from torch._dispatch.python import enable_python_dispatcher
with enable_python_dispatcher():
    test()
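
Applied to the failing test above, that might look roughly like the wrapper below; only enable_python_dispatcher itself comes from core, the helper is a sketch:

from torch._dispatch.python import enable_python_dispatcher

def run_with_python_dispatcher(test_fn, *args, **kwargs):
    # Swap eligible C++ meta kernels for their Python decompositions while
    # the test body runs; the normal dispatcher is restored on exit.
    with enable_python_dispatcher():
        return test_fn(*args, **kwargs)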

Collaborator

Here is the follow-up on the XLA side: pytorch/xla#4448.

Collaborator

It looks like enabling the Python dispatcher and implementing the missing sym size ops can work around this. But that raises a bigger question of whether we should enable the Python dispatcher for dynamic shapes or not.

@wonjoolee95
Collaborator Author

Rebased this and the XLA POC PR onto master. The aotdispatch tests here will fail, as the master branch had some changes as well. I'll follow up later to fix those.

@wonjoolee95
Collaborator Author

A bunch of dynamo-related tests are failing now with messages like:

 Error: trace_fork_wait_inline (jit.test_async.TestAsync) ... [2023-01-25 03:37:54,841] torch._dynamo.convert_frame: [ERROR] WON'T CONVERT <graph break in test_trace_fork_wait_inline> /var/lib/jenkins/workspace/test/jit/test_async.py line 416 

The last commit does touch dynamo-related code, but the CI used to be green even with that commit. While I try to reproduce this locally, let me also rebase from master and re-run the CI.

@alanwaketan
Collaborator

A bunch of dynamo-related tests are failing now with messages like:

 Error: trace_fork_wait_inline (jit.test_async.TestAsync) ... [2023-01-25 03:37:54,841] torch._dynamo.convert_frame: [ERROR] WON'T CONVERT <graph break in test_trace_fork_wait_inline> /var/lib/jenkins/workspace/test/jit/test_async.py line 416 

The last commit does touch dynamo-related code, but the CI used to be green even with that commit. While I try to reproduce this locally, let me also rebase from master and re-run the CI.

We can always use hud.pytorch.org to determine whether a test is broken at the tip of the tree.

@wonjoolee95
Collaborator Author

Thanks for the info, Jiewen. It looks like master was also seeing the same issue for a while:
[Screenshot of hud.pytorch.org, taken 2023-01-25 at 1:49 PM, showing the failing jobs on master]
But it seems like a recent commit fixed it, so this PR's CI should be good now. I'll let the CI verify.

@@ -14580,3 +14580,6 @@
  dispatch:
    CUDA: _fused_adamw_kernel_cuda_
  autogen: _fused_adamw, _fused_adamw.out

- func: _propagate_xla_data(Tensor input, Tensor output) -> ()
Collaborator

@bdhirsh I have added the op you suggested in pytorch/xla#4505 (comment). Please review. You can just check the commit called: [Functionalization] Adds _propagate_xla_data.
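
For context, a sketch of how the new op from the yaml diff above could be called once codegen registers it; whether the default non-XLA implementation is a no-op is an assumption here, not something verified in this thread:

import torch

inp = torch.randn(4)
out = inp * 2
# Gives a backend (XLA) a hook to copy backend-specific data, e.g. dynamic
# dimension info, from `inp` to `out` during functionalization; assumed to be
# a no-op for ordinary CPU tensors.
torch.ops.aten._propagate_xla_data(inp, out)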

@wonjoolee95
Collaborator Author

Moving to a new PR with a new branch -- #94537. Marking this one closed.

@wonjoolee95 closed this Feb 9, 2023