Functionalization increases tracing time #6294
As noted by @anw90 at https://dev-discuss.pytorch.org/t/decomposition-slows-down-the-lazy-tensor-tracing/1788/5, this issue may be related to xla/torch_xla/csrc/torch_util.cpp, line 67 (at commit 9f1afbd).
After being wrapped into a functional tensor, torch dispatches the op back to the Python level in torch::impl::dispatch::PythonKernelHolder, which may cause slower tracing.
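For context, the mechanism being described can be illustrated with a small, self-contained sketch (the `demo::my_lerp` op and its decomposition below are made up for illustration and are not part of the issue's code): kernels registered from Python via `torch.library` are held by the dispatcher in a `PythonKernelHolder`, so each call to such an op crosses from C++ back into the Python interpreter, and a decomposition into several aten ops multiplies those dispatcher hops.

```python
import torch

# Hypothetical namespace/op for illustration only.
lib = torch.library.Library("demo", "DEF")
lib.define("my_lerp(Tensor start, Tensor end, Tensor weight) -> Tensor")


def my_lerp_impl(start, end, weight):
    # A decomposition into several aten ops: each sub-op is another trip
    # through the dispatcher, which is the kind of overhead discussed above.
    return start + weight * (end - start)


# A kernel registered from Python like this is held by the dispatcher in a
# PythonKernelHolder, so calling the op crosses back into the Python interpreter.
lib.impl("my_lerp", my_lerp_impl, "CompositeExplicitAutograd")

out = torch.ops.demo.my_lerp(torch.rand(4), torch.rand(4), torch.rand(4))
print(out)
```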
Looking more into this, the issue here seems to be due to decompositions in …
PyTorch/XLA with functionalization is ~10x slower than PyTorch. PyTorch/XLA with functionalization disabled is a bit better, but still ~7x slower than PyTorch. The captured profile in the original issue shows the dispatches. The dispatch traces are attached below:
Looking at the profile and the dispatch trace, there do not seem to be specific decompositions that take extraordinarily long. Rather, the ops are decomposed mostly evenly, and each decomposition slows down PyTorch/XLA's tracing a little.
Hmm @wonjoolee95 - looking at the traces, they actually look almost identical between "XLA w/ functionalization" and "XLA w/o functionalization". You can see from the snippet below that functionalization itself is not decomposing lerp: we go all the way to XLA's lowering for lerp.
Compare that to the dispatcher operations for "plain" PyTorch eager mode for lerp: there are no extra dispatcher calls, since there is a direct CPU implementation for the entire lerp op.
If you want to limit the number of dispatcher hops for lerp on XLA, maybe you can avoid calling into more aten ops in the XLA lowering? Although that doesn't explain why XLA is any slower with functionalization.
Hmm. The profile trace in the original post https://dev-discuss.pytorch.org/t/decomposition-slows-down-the-lazy-tensor-tracing/1788 seems to indicate that the lerp decomposition here is being run in XLA's hot path? Although I don't see that decomposition into ATen ops in the "dispatch trace" linked above (since it includes ops like …).
Thanks for the pointers, @bdhirsh! The profile trace in the original post is different from the code that was used to capture the dispatch trace -- the code used to capture the dispatch trace is a simplified version that just calls a single lerp.
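For reference, a simplified dispatch-logging script along these lines (a sketch, not the exact code used to produce the attached traces) can show which aten ops a single lerp call reaches at the Python dispatch layer:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class DispatchLogger(TorchDispatchMode):
    """Print each op call intercepted by this mode."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))


start, end, weight = torch.rand(8), torch.rand(8), torch.rand(8)

with DispatchLogger():
    torch.lerp(start, end, weight)
```

On CPU eager this typically prints just a single aten lerp call, matching the "direct CPU implementation" observation above.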
Looking back at PyTorch/XLA's lowering of lerp -- https://github.com/pytorch/xla/blob/master/torch_xla/csrc/ops/ops.cpp#L768, the issue may be due to how this lowering is implemented at the IR level. We probably want a proper lowering to avoid these unnecessary decompositions. Working on a fix now.
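One way to check whether the lowering still decomposes lerp is to dump the lazy IR for a single call. This is a sketch that assumes a working torch_xla install and uses the internal debug helper `_get_xla_tensors_text` (as used in pytorch/xla's own tests), so treat the exact API as an assumption:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
start = torch.rand(8, device=device)
end = torch.rand(8, device=device)
weight = torch.rand(8, device=device)

out = torch.lerp(start, end, weight)

# Dump the pending lazy IR for `out`: a proper lowering should show a single
# lerp-like node rather than a chain of mul/sub/add nodes.
print(torch_xla._XLAC._get_xla_tensors_text([out]))
```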
Moving https://dev-discuss.pytorch.org/t/decomposition-slows-down-the-lazy-tensor-tracing/1788 to a GitHub issue.
Original author: @anw90
Original question
The decomposition slows down the lazy tensor tracing when running with TorchXLA. Here is the timeline with stack info:

Removing the decomposition-related code resolves the issue:

This is the timeline after removing the decomposition-related code:

In contrast, the native PyTorch job has no such issue:

Summary
We saw that disabling the functionalization pass with
XLA_DISABLE_FUNCTIONALIZATION=1
brings tracing back to normal. As a result, the functionalization pass in PyTorch/XLA may be introducing unnecessarily large tracing time.
Reproducible code
Apply the patch below to reproduce the slow tracing issue. Here is the running command:
XLA_DISABLE_FUNCTIONALIZATION=0 PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc_per_node 1 test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
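As a lighter-weight alternative to the full ImageNet script, a minimal tracing-time sketch like the one below (my own, not from the original report; assumes a working torch_xla/PJRT setup) can be run twice -- once with XLA_DISABLE_FUNCTIONALIZATION=0 and once with XLA_DISABLE_FUNCTIONALIZATION=1 -- to compare tracing overhead:

```python
import time

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
start = torch.rand(1024, device=device)
end = torch.rand(1024, device=device)
weight = torch.rand(1024, device=device)

# Warm-up: build and execute a small graph once so setup cost is excluded.
torch.lerp(start, end, weight)
xm.mark_step()

t0 = time.perf_counter()
for _ in range(1000):
    out = torch.lerp(start, end, weight)  # lazy tensors: this only traces IR
t1 = time.perf_counter()
xm.mark_step()  # materialize outside the timed region

print(f"tracing time for 1000 lerp calls: {t1 - t0:.3f}s")
```

Launch it as, e.g., PJRT_DEVICE=CUDA XLA_DISABLE_FUNCTIONALIZATION=0 python <script>.py, then again with XLA_DISABLE_FUNCTIONALIZATION=1.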
Version info
cc @alanwaketan