
[Dynamo] Refine CPU fallback for TD+XLA #5000

Merged · 26 commits · May 30, 2023
Conversation

@wonjoolee95 (Collaborator) commented May 11, 2023

Picks up #4935


Allows unsupported ops to fall back to CPU in PyTorch/XLA + Dynamo by utilizing CapabilityBasedPartitioner.
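
For context, here is a minimal sketch of how such a capability-based partitioner can be wired up. This is an illustrative assumption rather than the PR's actual bridge code; XLA_SUPPORTED_OPS and the helper names are hypothetical.

import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

# Hypothetical allow-list, for illustration only.
XLA_SUPPORTED_OPS = {torch.mul, torch.add}

class XlaOperatorSupport(OperatorSupport):

  def is_node_supported(self, submodules, node):
    # Keep a node on the XLA side only if its target is on the allow-list.
    return node.op in ("call_function", "call_method", "call_module") and \
        node.target in XLA_SUPPORTED_OPS

def partition_for_xla(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
  partitioner = CapabilityBasedPartitioner(
      gm, XlaOperatorSupport(), allows_single_node_partition=True)
  partitions = partitioner.propose_partitions()
  # Supported nodes are fused into submodules that XLA can compile; anything
  # unsupported stays in the outer graph and falls back to CPU execution.
  return partitioner.fuse_partitions(partitions)

In the partitioned module, only the fused submodules would then be handed to the XLA backend.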

@wonjoolee95 (Collaborator Author)

Okay, so simply removing the fallback assertions (on the master branch) does not cause failures, but it may produce wrong results. Take a look at this example:

import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

@dynamo.optimize("torchxla_trace_once")
def fn_fallback_unsupported(t):
  # xla currently does not lower aten::median
  return unsupported(t)

def unsupported(t):
  ret = torch.mul(t, 2) # torch.mul is supported by XLA
  final_ret = torch.median(ret) # torch.median is not supported by XLA
  return final_ret

def print_metrics():
  print('CompileTime:', met.metric_data('CompileTime')[0])
  print('ExecuteTime:', met.metric_data('ExecuteTime')[0])
  print('CounterNames:', met.counter_names())

device = xm.xla_device()

# initial trace
a = torch.tensor([1, 2, 3, 4, 5])
a_cpu = unsupported(a)
print('a_cpu:', a_cpu)
a_xla = a.to(device=device)
a_xla_ret = fn_fallback_unsupported(a_xla)
print(a_xla_ret)
print_metrics()

met.clear_counters()
print('ClearedCounterNames:', met.counter_names())
print('-----')

# second time
a_2 = torch.tensor([1, 2, 3, 4, 5])
a_cpu_2 = unsupported(a_2)
print('a_cpu_2:', a_cpu_2)
a_xla_2 = a_2.to(device=device)
a_xla_ret_2 = fn_fallback_unsupported(a_xla_2)
print(a_xla_ret_2)
print_metrics()

met.clear_counters()
print('ClearedCounterNames:', met.counter_names())
print('-----')

# third time
a_3 = torch.tensor([2, 3, 4, 5, 6])
a_cpu_3 = unsupported(a_3)
print('a_cpu_3:', a_cpu_3)
a_xla_3 = a_3.to(device=device)
a_xla_ret_3 = fn_fallback_unsupported(a_xla_3)
print(a_xla_ret_3)
print_metrics()

On the master branch, this produces a wrong result on the third run:

a_cpu: tensor(6)
tensor(6, device='xla:1')
CompileTime: 4
ExecuteTime: 4
CounterNames: ['CreateXlaTensor', 'DestroyLtcTensor', 'DestroyXlaTensor', 'DeviceDataCacheMiss', 'UncachedCompile', 'xla::_copy_from', 'xla::_propagate_xla_data', 'xla::_to_copy', 'xla::_to_cpu', 'xla::copy', 'xla::empty_symint', 'xla::mul', 'CreateCompileHandles', 'CreateDataHandles', 'DestroyDataHandles', 'ReleaseDataHandles', 'XRTAllocateFromTensor_Empty', 'aten::median']
ClearedCounterNames: []
-----
a_cpu_2: tensor(6)
tensor(6, device='xla:1')
CompileTime: 4
ExecuteTime: 5
CounterNames: ['CreateXlaTensor', 'xla::_copy_from', 'xla::_to_copy', 'xla::empty_symint', 'CreateDataHandles']
ClearedCounterNames: []
-----
a_cpu_3: tensor(8)
tensor(6, device='xla:1')
CompileTime: 4
ExecuteTime: 6
CounterNames: ['CreateXlaTensor', 'xla::_copy_from', 'xla::_to_copy', 'xla::empty_symint', 'CreateDataHandles']

Note the result tensor(6, device='xla:1') on the third run. The correct result should be tensor(8). The root cause seems to be the graph being wrongly hashed.

Now with this PR, the result is correct:

a_cpu: tensor(6)
tensor(6, device='xla:1')
CompileTime: 2
ExecuteTime: 4
CounterNames: ['CachedCompile', 'CreateXlaTensor', 'DestroyLtcTensor', 'DestroyXlaTensor', 'xla::_copy_from', 'xla::_to_copy', 'xla::_to_cpu', 'xla::empty_symint', 'xla::mul', 'CreateDataHandles', 'DestroyDataHandles', 'ReleaseDataHandles', 'XrtCompile_Empty', 'XrtExecuteChained_Empty', 'XrtExecute_Empty', 'XrtMemoryInfo_Empty', 'XrtRead_Empty', 'XrtReleaseAllocationHandle_Empty', 'XrtReleaseCompileHandle_Empty', 'XrtSessionCount', 'XrtSubTuple_Empty', 'aten::median']
ClearedCounterNames: []
-----
a_cpu_2: tensor(6)
tensor(6, device='xla:1')
CompileTime: 2
ExecuteTime: 5
CounterNames: ['CachedCompile', 'CreateXlaTensor', 'DestroyLtcTensor', 'DestroyXlaTensor', 'xla::_copy_from', 'xla::_to_copy', 'xla::_to_cpu', 'xla::empty_symint', 'xla::mul', 'CreateDataHandles', 'ReleaseDataHandles', 'aten::median']
ClearedCounterNames: []
-----
a_cpu_3: tensor(8)
tensor(8, device='xla:1')
CompileTime: 2
ExecuteTime: 6
CounterNames: ['CachedCompile', 'CreateXlaTensor', 'DestroyLtcTensor', 'DestroyXlaTensor', 'DeviceDataCacheMiss', 'xla::_copy_from', 'xla::_to_copy', 'xla::_to_cpu', 'xla::empty_symint', 'xla::mul', 'CreateDataHandles', 'DestroyDataHandles', 'ReleaseDataHandles', 'aten::median']

As for next steps, let me review and update the metrics in the failing unit tests. Per our last discussion, the CompileTime metric should increase because we compile once more on the initial trace to fetch all the unsupported ops.

cc @seanlatias, let me know if this makes sense.

@seanlatias (Collaborator)

@wonjoolee95 Thanks Wonjoo. This is an interesting finding. It also explains why my previous accuracy run was correct: I didn't try different inputs. Please go ahead and add those metrics for testing. I have some new local changes that I'd like to push to further optimize the process. I'll do that once you finish your editing.

@wonjoolee95 (Collaborator Author)

@seanlatias, just updated the metrics for the failing unit tests and added some comments. Please take a look to see if they make sense.

I also just realized that the newly added test DynamoInPlaceTest.test_inplace_update_correctness is failing for a real reason with this PR. I'll look into this. Meanwhile, feel free to push your changes.

@wonjoolee95 (Collaborator Author)

Ah okay, the problem seems to be that for the in-place tests, the execution we run to fetch the fallback ops actually updates the tensors.

@wonjoolee95 (Collaborator Author)

The latest commit should fix the DynamoInPlaceTest.test_inplace_update_correctness test. The fix is a bit ugly because it just duplicates the code in the extract_internal function, but I'll let this be for now (left a TODO with my name to make it cleaner).
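
Roughly, the shape of the fix looks like the following sketch (not the PR's exact code; collect_fallback_ops is a hypothetical stand-in for the probe that fetches the fallback ops): the probe run may mutate in-place arguments, so the tensor arguments are cloned up front and restored afterwards.

import torch

def probe_fallback_ops(xla_model, xla_args):
  # Clone tensor inputs so the probe execution cannot leak in-place updates.
  cloned_xla_args = [
      arg.clone() if isinstance(arg, torch.Tensor) else arg for arg in xla_args
  ]
  fallback_ops = collect_fallback_ops(xla_model, xla_args)  # hypothetical probe
  # Restore the original values in case the probe mutated them in place.
  for xla_arg, cloned_xla_arg in zip(xla_args, cloned_xla_args):
    if isinstance(xla_arg, torch.Tensor):
      xla_arg.copy_(cloned_xla_arg)
  return fallback_ops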

@wonjoolee95 (Collaborator Author) commented May 17, 2023

The CPU CI is green, but the GPU CI fails for some reason due to precision.

wonjoolee95 marked this pull request as ready for review on May 17, 2023 17:26
wonjoolee95 self-assigned this on May 17, 2023
@seanlatias (Collaborator)

@wonjoolee95 I'll push my fix today. I'm facing some issues setting up the environment with the new code. Will let you know once I'm done.

@wonjoolee95 (Collaborator Author)

> @wonjoolee95 I'll push my fix today. I'm facing some issues setting up the environment with the new code. Will let you know once I'm done.

Sounds good, thanks @seanlatias. Just curious, what is your fix about? I wanted to update some fallback unit tests too, so just want to make sure our changes don't conflict.

@seanlatias (Collaborator)

My fix is about adding the metric checks in the unit tests. Also, in the dynamo bridge, we should also check call_method; previously we only checked call_module and call_function.
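
A small sketch of the node check being described (a hypothetical helper, not the actual bridge code):

import torch.fx

def nodes_to_check(gm: torch.fx.GraphModule):
  # call_method nodes (e.g. t.median()) must be inspected for fallback as
  # well, not only call_function and call_module nodes.
  return [
      node for node in gm.graph.nodes
      if node.op in ("call_function", "call_method", "call_module")
  ]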

@seanlatias (Collaborator)

BTW, it seems I don't have access to push to the branch. Am I missing something?

@wonjoolee95 (Collaborator Author)

> BTW, it seems I don't have access to push to the branch. Am I missing something?

Just sent a collaborator invite to your account. You should be able to push directly to this branch/PR after accepting the invitation.

@wonjoolee95 (Collaborator Author)

> For the slowdown, I have some solutions in mind. We can create a separate PR for that. One is that if users are certain the whole model is supported, they can turn off the CPU fallback check. The second is to create a cache to avoid checking repeated FX ops.

Great points. We're mostly okay with this amount of slowdown for now; we can move the perf improvements to future PRs as you said.
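
For the second idea, a minimal sketch under assumptions (the cache and is_op_lowerable are hypothetical placeholders for however the bridge actually probes an op):

# Hypothetical memoization of per-op support checks.
_SUPPORT_CACHE = {}

def is_supported_cached(target):
  if target not in _SUPPORT_CACHE:
    _SUPPORT_CACHE[target] = is_op_lowerable(target)  # hypothetical probe
  return _SUPPORT_CACHE[target]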

@wonjoolee95 (Collaborator Author)

> This is the error I see when trying to import torch_xla.

Traceback (most recent call last):
  File "pytorch/xla/test/dynamo/test_fallback.py", line 2, in <module>
    import torch_xla
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch_xla-2.1.0-py3.8-linux-x86_64.egg/torch_xla/__init__.py", line 134, in <module>
    import _XLAC
ImportError: /opt/conda/envs/py38/lib/python3.8/site-packages/torch_xla-2.1.0-py3.8-linux-x86_64.egg/_XLAC.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZNK5torch8autograd4Node4nameEv

This undefined symbol usually happens when there is a mismatch between the PyTorch and PyTorch/XLA versions. Can you verify that your local PyTorch/XLA is also up to date?

self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count)
# One graph for fetching the fallback ops.
# Another graph for the resnet18 inference.
self.assertEqual(met.metric_data('CompileTime')[0], 2)
Collaborator:

Hmm, I haven't read through the code yet, but we don't need to compile the HLO in order to determine whether there is a fallback.

Collaborator Author:

Reverted this back with the latest commit that adds ClearPendingIrs.

# Another graph for the resnet18 inference.
self.assertEqual(met.metric_data('CompileTime')[0], 2)
# Again, +1 offset in ExecuteTime for fetching the fallback ops.
self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count + 1)
Collaborator:

Ditto, we should not introduce additional execution for non-fallback graphs.

Collaborator Author:

Reverted this back with the latest commit that adds ClearPendingIrs.

xla_mat2 = mat2.to(xm.xla_device())

cpu_res = fn_fallback(M, mat1, mat2)
xla_res = dynamo_fn(xla_M, xla_mat1, xla_mat2)
Collaborator:

Can we check the CompileTime and ExecuteTime counters here?

Collaborator Author:

Handling this with the comment below; working on adding metric checks for all our fallback unit tests.

cpu_res = fn_fallback(M, mat1, mat2, 0.5)
xla_res = dynamo_fn(M, mat1, mat2, 0.5)

self.assertTrue(torch.allclose(cpu_res, xla_res.cpu()))
Collaborator:

Ditto, add counter checks for compilation and execution. I think we also want to check that no aten:: counter is accumulated. Let me keep reading and figure out how the fallback op is being executed.

@JackCaoG (Collaborator), May 18, 2023:

After reading through the PR, it is still unclear to me how the fallback op is being handled. Will it be executed on CPU on the PyTorch end, or will it be executed lazily and go through our fallback-op handling logic? I hope it is the former.

Collaborator:

This is a good point. I think currently it is the latter. We do not move the tensors back and forth between the CPU and XLA devices in our fallback logic, so we still get aten:: counters when executing the partitioned graph. Or do you think we should move the tensors explicitly instead of letting the lazy execution do it?

cc @wonjoolee95

Collaborator:

To verify, we do see the xla::_to_cpu counter.
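
For example, a check along these lines (a sketch using the metrics helpers already used in this thread) would confirm the fallback path is exercised:

import torch_xla.debug.metrics as met

# The lazy CPU fallback copies tensors to the CPU, so both counters should
# be present after running a graph that contains a fallback op.
assert 'xla::_to_cpu' in met.counter_names()
assert 'aten::median' in met.counter_names()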

Collaborator:

I think that also explains why the CompileTime metric increases by one when seeing a fallback op.

@seanlatias (Collaborator), May 18, 2023:

To clarify:

  • CompileTime increases by one per unsupported op when testing whether an op goes through the CPU fallback or not
  • ExecuteTime increases
    • Whenever an unsupported op is executed through lazy CPU fallback (e.g., in FallBackNodeCollector and InputCollector)
    • Whenever a compiled subgraph is executed

But with these, the final metric numbers still do not match. I'm still looking into the problem.

@wonjoolee95 (Collaborator Author), May 18, 2023:

Yeah, we should see the aten:: counters when we execute the unsupported ops.

@seanlatias, which final metric numbers do not match, the metric numbers in the fallback unit tests?

Collaborator:

For test_operator_fallback, I see CompileTime is 2, which makes sense. However, ExecuteTime is 5, which I still can't explain.

Collaborator Author:

I was playing a bit with the unit test and simplifying it. For what it's worth, with a simpler unit test like the one below:

  def test_operator_fallback(self):

    def fn_fallback(t):
      # As of 05/18/2023, torch.median is not lowered by PyTorch/XLA
      return torch.median(t)

    torch._dynamo.reset()
    met.clear_counters()
    device = xm.xla_device()

    dynamo_fn = torch.compile(fn_fallback, backend="torchxla_trace_once")
    t = torch.randn(5)
    t_xla = t.to(device)

    cpu_res = fn_fallback(t)
    xla_res = dynamo_fn(t_xla)

    print('CompileTime:', met.metric_data('CompileTime')[0])
    print('ExecuteTime:', met.metric_data('ExecuteTime')[0])

    self.assertTrue(torch.allclose(cpu_res, xla_res.cpu()))

I was able to see:

CompileTime: 2
ExecuteTime: 2

Let me look into the existing test_operator_fallback with the cummin op.

for xla_arg, cloned_xla_arg in zip(xla_args, cloned_xla_args):
  if isinstance(xla_arg, torch.Tensor):
    xla_arg.copy_(cloned_xla_arg)

Collaborator:

I think if you call ClearPendingIr here, you will avoid the unnecessary CompileTime and ExecuteTime increases.

Collaborator Author:

Updated the code to call ClearPendingIr here, and also moved the xm.mark_step to the beginning of this function. Now the metrics in the unit tests are left unchanged, as expected.
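
A minimal sketch of that ordering (the function and get_fallback_ops are hypothetical; only xm.mark_step() and torch_xla._XLAC._clear_pending_irs are taken from this thread):

import torch_xla
import torch_xla.core.xla_model as xm

def extract_fallback_ops(xla_model, xla_args):
  # Materialize any pending input computations first.
  xm.mark_step()
  fallback_ops = get_fallback_ops(xla_model, xla_args)  # hypothetical probe
  # Drop the IR accumulated by the probe so it doesn't trigger an extra
  # compilation/execution later.
  torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
  return fallback_ops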

@JackCaoG (Collaborator) left a comment:

Mostly LGTM. I think if we call ClearPendingIr in the right place, it should not regress test_dynamo.py.

@wonjoolee95 (Collaborator Author)

@JackCaoG, could you take a look at this again? I addressed the comments from the last review: added ClearPendingIrs to fix the test_dynamo.py regressions and added asserts/metric checks to the DynamoCpuFallbackTest unit tests. These DynamoCpuFallbackTest tests should be enough to cover the correctness of the fallback mechanism, although I still do not completely understand the reasoning behind the excessively increased Execute counters. I'll look into that, but meanwhile this PR should be reviewable.

Comment on lines +142 to +143
self.assertEqual(met.metric_data('CompileTime')[0], 3)
self.assertEqual(met.metric_data('ExecuteTime')[0], 3)
Collaborator:

Hmm, why is it 3 here instead of 4? Wouldn't dynamo_fn fall back and run 2 executions?

Collaborator:

Oh OK, I think I understand why t_xla * 3 would trigger a separate compilation: t_xla * 3 is actually a pending execution, and we call mark_step to materialize the input. If that's the case, I don't understand why ExecuteTime is 3 then; it should be 5, I guess?

Collaborator Author:

So here's what the IR dump looks like when I run DynamoCpuFallbackTest.test_operator_fallback locally: https://gist.github.com/wonjoolee95/1426859ef0a9203dca71ad455e4badc8. This is also what I'm trying to figure out 😢

Another odd thing is that the fallback tests pass when I run them individually, as in python test/dynamo/test_dynamo.py DynamoCpuFallbackTest.test_operator_fallback and python test/dynamo/test_dynamo.py DynamoCpuFallbackTest.test_fallback_multiple_submodules. However, when I run them in a single run via python test/dynamo/test_dynamo.py, they fail on the metric assertions (the same failure as in the CI). This makes me think there are possibly some pending IRs somewhere, but I tried manually invoking torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) at the end of each unit test and I'm still seeing the error.

Collaborator:

@wonjoolee95 I think we need to reset the metrics for each test. Similar to here:

met.clear_all()
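
For instance (a sketch, assuming the unittest-style tests used in this file), the reset could live in setUp:

import unittest
import torch_xla.debug.metrics as met

class DynamoCpuFallbackTest(unittest.TestCase):

  def setUp(self):
    # Reset counters and metrics so tests don't observe each other's
    # CompileTime/ExecuteTime values.
    met.clear_all()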

Collaborator:

Hmm... but I also hit a similar problem when adding clear_all(). The behavior of a single test and a set of tests is different.

Collaborator:

I am still a bit confused about:

> on the 2nd tracing, we can see that both the CompileTime and ExecuteTime remain the same as the 1st tracing because the graph with the fallback op is already captured

ExecuteTime should increase whenever a Dynamo execution happens. It should increase both during the first Dynamo run and during subsequent executions.

Collaborator:

  1. Let's call met.clear_all() instead of met.clear_counters(); clear_counters won't reset the metrics.

I would expect that every call to dynamo_fn will trigger at least 2 executions, since there is a fallback op in the middle and XLA needs to execute the graph before and after the fallback op.

Collaborator Author:

I'm also trying to completely understand why these metrics have these numbers, but here is what I understand so far; please let me know if anything doesn't sound right.

> ExecuteTime should increase whenever a Dynamo execution happens. It should increase both during the first Dynamo run and during subsequent executions.

ExecuteTime increases only when there is an XLA execution. However, in this example there is only a single aten::median op, and it is executed on the CPU, so ExecuteTime doesn't increase on the 2nd tracing.

> I would expect that every call to dynamo_fn will trigger at least 2 executions, since there is a fallback op in the middle and XLA needs to execute the graph before and after the fallback op.

On the 2nd and 3rd tracing, torchxla.py already has a compiled_graph from the 1st tracing that looks like:

class GraphModule(torch.nn.Module):
    def forward(self, L_t_ : torch.Tensor):
        l_t_ = L_t_
        
        # File: wonjoo_2.py:28, code: return torch.median(t)
        median = torch.median(l_t_);  l_t_ = None
        return (median,)

Now, since the only op in this graph is median, which is executed on the CPU, executing this graph with compiled_graph(*args) does not increase ExecuteTime. In the 3rd tracing with t_xla * 3, the only graph that XLA needs to execute before the aten::median op is t_xla * 3, hence ExecuteTime increases only by 1.

This is what I think is happening... let me know if that makes sense, @JackCaoG. And @seanlatias, also let me know if this aligns with your understanding, just to make sure we're on the same page.

Collaborator:

Mostly agree with @wonjoolee95. The following is my explanation for each trace.

  • The first trace includes two parts: checking which ops need the CPU fallback, and running the fallback op with the compiled graph. Both parts involve compiling and executing torch.median(). Thus, the initial CompileTime and ExecuteTime are both 2.
  • The second trace does not introduce any changes: same input and same module. Thus, nothing needs to be compiled or executed again; this is handled by the cached computation in XLA.
  • The third trace does not change the input module, so the module does not need to be recompiled. However, it changes the input for the fallback op, so the fallback op needs to be recompiled and re-executed (remember that the fallback op goes through lazy tensor execution instead of being executed by PyTorch directly). That's why we see an increase in both CompileTime and ExecuteTime. A rough sketch of the three calls follows below.
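
For concreteness, a rough sketch of the call pattern being discussed, reconstructed from this thread (not the exact test body; t_xla and dynamo_fn are as in the simplified test above):

# 1st call: compiles both the fallback-op probe and the surrounding XLA graph,
# and runs torch.median on CPU -> CompileTime and ExecuteTime each reach 2.
xla_res = dynamo_fn(t_xla)
# 2nd call: same input, same module; cached computations are reused, so the
# metrics stay unchanged.
xla_res = dynamo_fn(t_xla)
# 3rd call: the input expression changes, so the lazy graph feeding the
# fallback op is recompiled and re-executed -> both metrics increase by 1.
xla_res = dynamo_fn(t_xla * 3)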

Collaborator:

Makes sense, thanks.

@JackCaoG (Collaborator) left a comment:

Mostly LGTM, I have one question regarding the test.


@wonjoolee95 (Collaborator Author)

@JackCaoG, updated the comments above, should be ready for one more review.

@JackCaoG (Collaborator) left a comment:

Thanks a lot @wonjoolee95 and @seanlatias. This is great work!
