Functionalization increases tracing time #6294

Closed
wonjoolee95 opened this issue Jan 11, 2024 · 7 comments · Fixed by #6339

wonjoolee95 (Collaborator) commented Jan 11, 2024

Moving https://dev-discuss.pytorch.org/t/decomposition-slows-down-the-lazy-tensor-tracing/1788 to a GitHub issue.

Original author: @anw90

Original question

The decomposition slows down the lazy tensor tracing when running with TorchXLA. Here is the timeline with stack info:
[image]

Removing the decomposition-related code resolves the issue:
[image]

This is the timeline after removing the decomposition-related code:
[image]

In contrast, the native PyTorch job has no such issue:
[image]

Summary

We saw that disabling the functionalization pass with XLA_DISABLE_FUNCTIONALIZATION=1 brings the tracing time back to normal. As a result, the functionalization pass in PyTorch/XLA may be introducing an unnecessarily large tracing overhead.

Reproducible code

Apply the patch below to reproduce the slow tracing issue. Here is the running command:
XLA_DISABLE_FUNCTIONALIZATION=0 PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc_per_node 1 test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py
index 43c4c96..7af3bdc 100644
--- a/test/test_train_mp_imagenet.py
+++ b/test/test_train_mp_imagenet.py
@@ -266,10 +266,9 @@ def train_imagenet():
   writer = None
   if xm.is_master_ordinal():
     writer = test_utils.get_summary_writer(FLAGS.logdir)
-  optimizer = optim.SGD(
+  optimizer = optim.AdamW(
       model.parameters(),
       lr=FLAGS.lr,
-      momentum=FLAGS.momentum,
       weight_decay=1e-4)
   num_training_steps_per_epoch = train_dataset_len // (
       FLAGS.batch_size * xm.xrt_world_size())
@@ -289,6 +288,11 @@ def train_imagenet():
   def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()
     model.train()
+    prof = torch.profiler.profile(
+          activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU],
+          schedule=torch.profiler.schedule(wait=2, warmup=2, active=3),
+          with_stack=True,
+          on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile"))
     for step, (data, target) in enumerate(loader):
       with xp.StepTrace('train_imagenet'):
         with xp.Trace('build_graph'):
@@ -306,6 +310,8 @@ def train_imagenet():
         if step % FLAGS.log_steps == 0:
           xm.add_step_closure(
               _train_update, args=(device, step, loss, tracker, epoch, writer))
+      xm.mark_step()
+      prof.step()

   def test_loop_fn(loader, epoch):
     total_samples, correct = 0, 0
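
One note on the patch above: the profiler object is constructed, but the diff never shows it being started. A minimal sketch of the usual torch.profiler pattern (an assumption about intent, not part of the original repro) is to enter the profiler as a context manager so that prof.step() drives the wait/warmup/active schedule:

import torch
import torch_xla.core.xla_model as xm

def train_loop_fn(loader, epoch):
    # Hypothetical wrapper mirroring the patched loop; the training-step body
    # is elided and stands in for the code shown in the diff above.
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=2, warmup=2, active=3),
        with_stack=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile"),
    ) as prof:
        for step, (data, target) in enumerate(loader):
            ...  # training step as in the patch
            xm.mark_step()
            prof.step()  # advances the profiler schedule once per step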

Version info

torch_xla.version.__xla_gitrev__: 'efa6fcfdac5368330a0770e9019649eba08b5f56'
torch_xla.version.__torch_gitrev__: 'f6dfbffb3bb46ada6fe66b5da4f989f9d4d69b3c'

cc @alanwaketan

wonjoolee95 self-assigned this Jan 11, 2024
wonjoolee95 (Collaborator, Author):

As noted by @anw90 at https://dev-discuss.pytorch.org/t/decomposition-slows-down-the-lazy-tensor-tracing/1788/5, this issue may be related to the MaybeWrapTensorToFunctional function:

at::Tensor MaybeWrapTensorToFunctional(const at::Tensor& tensor) {

After the tensor is wrapped into a functional tensor, torch dispatches the op back to the Python level in torch::impl::dispatch::PythonKernelHolder, which may cause slower tracing.
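
For context, here is a hedged, Python-level illustration of what functionalization does in general. This uses torch.func.functionalize rather than the C++ functionalization pass that PyTorch/XLA actually goes through, so treat it as an analogy only:

import torch
from torch.func import functionalize

def f(x):
    # In-place mutation; functionalization rewrites it as an out-of-place add
    # on a wrapped (functional) tensor, adding an extra dispatch layer per op.
    y = x.clone()
    y.add_(1)
    return y

fx = functionalize(f)
print(fx(torch.ones(3)))  # tensor([2., 2., 2.])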

wonjoolee95 (Collaborator, Author):

Looking more into this, the issue here seems to be due to decompositions in aten::lerp. A quick time measurement of aten::lerp using a simple script (https://gist.github.com/wonjoolee95/a1baa9e5908e85bd80c2c066fe9a0917) gives:

  • PyTorch lerp: 0.011274337768554688s
  • PyTorch/XLA lerp (with func): 0.1137242317199707s
  • PyTorch/XLA lerp (without func): 0.07381987571716309s

PyTorch/XLA with functionalization is ~10x slower than PyTorch. PyTorch/XLA with functionalization disabled is a bit better but still ~7x slower than PyTorch.
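
The gist itself is not reproduced here; a minimal sketch of this kind of tracing-time measurement might look like the following (illustrative only; sizes and iteration counts are assumptions, not the exact script from the gist). Running it once with XLA_DISABLE_FUNCTIONALIZATION=0 and once with XLA_DISABLE_FUNCTIONALIZATION=1 gives the comparison above:

import time
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)
xm.mark_step()  # flush the initialization graph so it is not timed below

start = time.time()
for _ in range(100):
    a = torch.lerp(a, b, 0.5)  # lazy tensors: this only traces aten::lerp.Scalar
print(f"traced 100 lerp calls in {time.time() - start:.4f}s")
xm.mark_step()  # compile and execute the traced graph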

The captured profile in the original issue shows the dispatches. I've attached the dispatch trace below:

Looking at the profile and the dispatch trace, there do not seem to be specific decompositions that take extraordinarily long. Rather, the op is decomposed fairly evenly into many smaller ops that each slow down PyTorch/XLA's aten::lerp a little. Given that PyTorch/XLA's implementation is still ~7x slower than native PyTorch even without functionalization, the problem is likely not functionalization itself. However, @bdhirsh, would you have any insight into why this decomposition is happening?

bdhirsh (Collaborator) commented Jan 18, 2024

Hmm @wonjoolee95 - looking at the traces, they actually look almost identical between "XLA w/ functionalization" and "XLA w/o functionalization".

You can see from the snippet below that functionalization itself is not decomposing lerp: we go all the way to XLA's lowering for at::lerp, which looks like it decomposes into several CPU ops:

 [call] op=[aten::lerp.Scalar], key=[AutogradXLA]
  [redispatch] op=[aten::lerp.Scalar], key=[Functionalize]
   [callBoxed] op=[aten::lerp.Scalar], key=[XLA]
    [call] op=[aten::scalar_tensor], key=[BackendSelect]
     [redispatch] op=[aten::scalar_tensor], key=[CPU]
    [call] op=[aten::to.dtype_layout], key=[BackendSelect]
     [redispatch] op=[aten::to.dtype_layout], key=[CPU]
      [call] op=[aten::_to_copy], key=[BackendSelect]
       [redispatch] op=[aten::_to_copy], key=[CPU]
        [call] op=[aten::empty_strided], key=[BackendSelect]
         [redispatch] op=[aten::empty_strided], key=[CPU]
        [call] op=[aten::copy_], key=[CPU]
    [call] op=[aten::to.dtype_layout], key=[BackendSelect]
     [redispatch] op=[aten::to.dtype_layout], key=[CPU]
      [call] op=[aten::_to_copy], key=[BackendSelect]
       [redispatch] op=[aten::_to_copy], key=[CPU]
        [call] op=[aten::empty.memory_format], key=[BackendSelect]
         [redispatch] op=[aten::empty.memory_format], key=[CPU]
        [call] op=[aten::copy_], key=[CPU]

Compare this to the dispatcher operations for "plain" PyTorch eager mode for lerp (no extra dispatcher calls; there is a direct CPU implementation for the entire lerp.Scalar op):

 [call] op=[aten::lerp.Scalar], key=[AutogradCPU]
  [redispatch] op=[aten::lerp.Scalar], key=[CPU]

If you want to limit the number of dispatcher hops for lerp on XLA, maybe you can avoid calling into more aten ops in the XLA lowering? Although that doesn't explain why XLA is any slower with functionalization.
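
For reference, a minimal snippet that exercises this eager path (a sketch; how the dispatch traces above were captured is not stated in the thread, though the format matches PyTorch's dispatch-trace debug output, e.g. TORCH_SHOW_DISPATCH_TRACE=1 on builds that support it):

import torch

a = torch.randn(8)
b = torch.randn(8)
# With a Python scalar weight, torch.lerp dispatches aten::lerp.Scalar,
# which goes AutogradCPU -> CPU with no further decomposition in eager mode.
out = torch.lerp(a, b, 0.5)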

bdhirsh (Collaborator) commented Jan 18, 2024

Hmm. The profile trace in the original post (https://dev-discuss.pytorch.org/t/decomposition-slows-down-the-lazy-tensor-tracing/1788) seems to indicate that the lerp decomposition is being run in XLA's hot path? Although I don't see that decomposition into ATen ops in the "dispatch trace" linked above, even though it would include ops like aten.abs and aten.where.

wonjoolee95 (Collaborator, Author):

Thanks for the pointers, @bdhirsh! The profile trace in the original post comes from different code than the dispatch trace -- the dispatch trace was captured with a simplified script that calls a single aten::lerp, whereas the profile trace in the original post came from the ResNet unit test. One question: when you mention "being run in XLA's hot path", does this differ from the original decomposition shown in the dispatch trace?

wonjoolee95 (Collaborator, Author):

Looking back at PyTorch/XLA's lowering of lerp (https://github.com/pytorch/xla/blob/master/torch_xla/csrc/ops/ops.cpp#L768), the issue may be due to how this lowering is implemented at the IR level. We probably want a proper lowering to avoid these unnecessary decompositions. Working on a fix now.

wonjoolee95 (Collaborator, Author):

Having synced with @anw90 offline, we saw that fixing the aten::lerp lowering by itself was not enough to entirely eliminate the tracing-time slowdown. Looking into our op lowerings, I saw that many of our existing lowerings are done as IR-level lowerings. I've opened #6589 to track this issue.
