[Dynamo] Refine CPU fallback for TD+XLA #5000
Conversation
Okay, so simply removing the fallback assertions (on the master branch) does not cause failures, but it may produce wrong results. Take a look at this example: on the master branch, it produces a wrong result on the third run, while with this PR the result is correct.
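A minimal sketch of this kind of repro (assuming torch.median as the un-lowered op and the torchxla_trace_once backend used in the tests below; the exact snippet may have differed):

import torch
import torch_xla.core.xla_model as xm

def fn_fallback(t):
  return torch.median(t)  # not lowered by PyTorch/XLA as of 05/18/2023

device = xm.xla_device()
dynamo_fn = torch.compile(fn_fallback, backend="torchxla_trace_once")

t = torch.randn(5)
t_xla = t.to(device)
print(dynamo_fn(t_xla))      # 1st run
print(dynamo_fn(t_xla))      # 2nd run, same input
print(dynamo_fn(t_xla * 3))  # 3rd run, new input -- wrong on master, correct with this PR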
As for next steps, let me review and update the metrics in the failing unit tests, as per the last discussion. cc @seanlatias, let me know if this makes sense.
@wonjoolee95 Thanks Wonjoo. This is an interesting finding. It also explains why my previous accuracy run was correct: I didn't try different inputs. Please go ahead and add those metrics for testing. I have some new local changes that I'd like to push to further optimize the process. I'll do that once you finish your editing.
@seanlatias, just updated the metrics for the failing unit tests and added some comments. Please take a look to see if they make sense. I also just realized that the newly added test …
Ah okay, the problem seems to be that for in-place tests, the execution we do to fetch the fallback ops actually updates the tensors.
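For instance, a hypothetical in-place case of the kind being described (the op choice and function name are assumptions, not the test that was added):

import torch

# If a detection pass actually executes the lazy ops, the in-place update
# mutates the input tensor before the "real" compiled run happens.
def fn_inplace_fallback(t):
  t.add_(1.0)             # in-place update on the input tensor
  return torch.median(t)  # torch.median is not lowered by PyTorch/XLA (per above)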
The latest commit should fix this.
The CPU CI is green, but the GPU CI fails for some reason due to precision.
@wonjoolee95 I'll push my fix today. I'm facing some issues setting up the environment with the new code. Will let you know once I'm done.
Sounds good, thanks @seanlatias. Just curious, what is your fix about? I wanted to update some fallback unit tests too, so just want to make sure our changes don't conflict.
My fix will be about adding the metric checks in the unit tests. Also, in the dynamo bridge, we should also check …
BTW, it seems I don't have access to push to the branch. Am I missing something?
Just sent a collaborator invite to your account. You should be able to push directly to this branch/PR after accepting the invitation.
Great points. We're mostly okay with this amount of slowdown for now; we can move the perf improvements to future PRs as you said.
test/dynamo/test_dynamo.py
Outdated
self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count)
# One graph for fetching the fallback ops.
# Another graph for the resnet18 inference.
self.assertEqual(met.metric_data('CompileTime')[0], 2)
hmm, haven't read through the code, but we don't need to compile the HLO in order to determine if there is a fallback
Reverted this back with the latest commit that adds ClearPendingIrs.
test/dynamo/test_dynamo.py
Outdated
# Another graph for the resnet18 inference.
self.assertEqual(met.metric_data('CompileTime')[0], 2)
# Again, +1 offset in ExecuteTime for fetching the fallback ops.
self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count + 1)
ditto, we should not introduce additional execution for non-fallback graphs.
Reverted this back with the latest commit that adds ClearPendingIrs.
test/dynamo/test_fallback.py
Outdated
xla_mat2 = mat2.to(xm.xla_device())

cpu_res = fn_fallback(M, mat1, mat2)
xla_res = dynamo_fn(xla_M, xla_mat1, xla_mat2)
can we check counters for CompileTime and ExecuteTime here?
Handling this with the comment below, working on adding metric checks for all our fallback unit tests.
test/dynamo/test_fallback.py
Outdated
cpu_res = fn_fallback(M, mat1, mat2, 0.5)
xla_res = dynamo_fn(M, mat1, mat2, 0.5)

self.assertTrue(torch.allclose(cpu_res, xla_res.cpu()))
ditto, counter checks for compilation and execution. I think we also want to check that no aten:: counter is accumulated. Let me keep reading and figure out how the fallback op is being executed.
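A rough sketch of the kind of metric/counter checks being requested (the helper name and expected values are placeholders, not the actual test code):

import torch_xla.debug.metrics as met

def check_fallback_metrics(test, expected_compiles, expected_executes):
  test.assertEqual(met.metric_data('CompileTime')[0], expected_compiles)
  test.assertEqual(met.metric_data('ExecuteTime')[0], expected_executes)
  # The fallback op itself should show up as an aten:: counter; a graph with
  # no fallback should accumulate no aten:: counters at all.
  aten_counters = [c for c in met.counter_names() if c.startswith('aten::')]
  test.assertGreater(len(aten_counters), 0)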
After reading through the PR, it is still unclear to me how the fallback op is being handled. Will it be executed on CPU on the PyTorch end, or will it be executed lazily and go through our fallback op handling logic? I hope it is the former.
This is a good point. I think currently it is the latter. We do not move the tensors back and forth between the CPU and XLA devices in our fallback logic, so we still get aten counters when executing the partitioned graph. Or do you think we should move the tensors explicitly instead of letting the lazy execution do it?
cc @wonjoolee95
To verify, we do see the xla::_to_cpu counter.
I think that also explains why the CompileTime metric increases by one when seeing a fallback op.
To clarify:
- CompileTime increases by one for an unsupported op when testing whether an op goes through CPU fallback or not.
- ExecuteTime increases:
  - whenever an unsupported op is executed through lazy CPU fallback (e.g., in FallBackNodeCollector and InputCollector), and
  - whenever a compiled subgraph is executed.

But with these, the final metric numbers still do not match. I'm still looking into the problem.
Yeah, we should see the aten:: metrics when we execute the unsupported ops.
@seanlatias, which final metric numbers do not match, the metric numbers in the fallback unit tests?
For test_operator_fallback, I see CompileTime to be 2, which makes sense. However, ExecuteTime is 5, which I still can't explain.
I was playing a bit with the unit test case and made it into a simpler one. For what it's worth, with a simpler unit test like the one below:
def test_operator_fallback(self):
  def fn_fallback(t):
    # As of 05/18/2023, torch.median is not lowered by PyTorch/XLA
    return torch.median(t)

  torch._dynamo.reset()
  met.clear_counters()
  device = xm.xla_device()
  dynamo_fn = torch.compile(fn_fallback, backend="torchxla_trace_once")
  t = torch.randn(5)
  t_xla = t.to(device)
  cpu_res = fn_fallback(t)
  xla_res = dynamo_fn(t_xla)
  print('CompileTime:', met.metric_data('CompileTime')[0])
  print('ExecuteTime:', met.metric_data('ExecuteTime')[0])
  self.assertTrue(torch.allclose(cpu_res, xla_res.cpu()))
I was able to see:
CompileTime: 2
ExecuteTime: 2
Let me look into the existing test_operator_fallback with the cummin op.
for xla_arg, cloned_xla_arg in zip(xla_args, cloned_xla_args):
  if isinstance(xla_arg, torch.Tensor):
    xla_arg.copy_(cloned_xla_arg)
I think if you call ClearPendingIr here, you will avoid the unnecessary CompileTime and ExecuteTime.
Updated the code to call ClearPendingIr here, and also moved the xm.mark_step to the beginning of this function instead. Now the metrics in the unit tests are left unchanged, as expected.
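To make that concrete, a hedged sketch of the overall pattern (the helper name is made up and this is not the actual bridge code):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

def collect_fallbacks_without_side_effects(fx_module, xla_args):
  # Snapshot tensor inputs, then materialize everything (inputs and clones)
  # before the detection-only run.
  cloned_xla_args = [
      arg.clone() if isinstance(arg, torch.Tensor) else arg for arg in xla_args
  ]
  xm.mark_step()
  fx_module(*xla_args)  # detection-only run; unsupported ops hit CPU fallback
  # Discard the IR traced during detection so it never compiles or executes.
  torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
  # Undo any in-place updates the detection run applied to the inputs.
  for xla_arg, cloned_xla_arg in zip(xla_args, cloned_xla_args):
    if isinstance(xla_arg, torch.Tensor):
      xla_arg.copy_(cloned_xla_arg)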
mostly lgtm, I think if we can call ClearPendingIr in the right place, it should not regress test_dynamo.py.
@JackCaoG, could you take a look at this again? Addressed the comments from the last review.
self.assertEqual(met.metric_data('CompileTime')[0], 3)
self.assertEqual(met.metric_data('ExecuteTime')[0], 3)
hmm, why is it 3 here instead of 4? Wouldn't dynamo_fn fall back and trigger 2 executions?
oh ok, I think I understand why t_xla * 3 would trigger a separate compilation. t_xla * 3 is actually a pending execution and we will call mark_step to materialize the input. If that's the case, though, I don't understand why ExecuteTime is 3; it should be 5, I guess?
So here's how the IR dump looks when I run DynamoCpuFallbackTest.test_operator_fallback locally: https://gist.github.com/wonjoolee95/1426859ef0a9203dca71ad455e4badc8. This is also what I'm trying to figure out 😢
Another odd thing is that the fallback tests pass when I run them individually, as python test/dynamo/test_dynamo.py DynamoCpuFallbackTest.test_operator_fallback and python test/dynamo/test_dynamo.py DynamoCpuFallbackTest.test_fallback_multiple_submodules. However, when I run them in a single run with python test/dynamo/test_dynamo.py, they fail due to metric assertions (same as the failure in the CI). This makes me think there are possibly some pending IRs somewhere, but I tried to manually invoke torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) at the end of each unit test and am still seeing the error.
@wonjoolee95 I think we need to reset the metrics for each test, similar to here:
xla/test/dynamo/test_dynamo.py, line 88 (commit 3150573):
met.clear_all()
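For reference, a minimal sketch of what that per-test reset could look like (assuming a unittest-style test class; not the exact code in the PR):

import unittest
import torch_xla.debug.metrics as met

class DynamoCpuFallbackTest(unittest.TestCase):

  def setUp(self):
    # clear_all() resets both metrics (CompileTime/ExecuteTime) and counters,
    # so assertions in one test don't see state left over from another.
    met.clear_all()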
Hmmm... but I also ran into a similar problem when adding clear_all(). The behavior of a single test and a set of tests is different.
I am still a bit confused about:
"on the 2nd tracing, we can see that both the CompileTime and ExecuteTime remain the same as the 1st tracing because the graph with the fallback op is already captured"
ExecuteTime should increase when a dynamo execution happens. It should increase both during the first dynamo run and subsequent executions.
- Let's call met.clear_all() instead of met.clear_counters; clear_counters won't reset the metrics.
- I would expect that every call to dynamo_fn will trigger at least 2 executions, since there is a fallback op in the middle and XLA needs to execute the graph before and after the fallback op.
I'm also still trying to completely understand why these metrics have these numbers, but here is what I understand -- please let me know if there is anything that doesn't sound right.
"ExecuteTime should increase when a dynamo execution happens. It should increase both during the first dynamo run and subsequent executions."
ExecuteTime increases only when there is an XLA execution. However, in this example, there is only a single aten::median op, and it is executed by the CPU, so ExecuteTime doesn't increase on the 2nd tracing.
"I would expect that every call to dynamo_fn will trigger at least 2 executions, since there is a fallback op in the middle and XLA needs to execute the graph before and after the fallback op."
On the 2nd and 3rd tracings, the torchxla.py bridge already has a compiled_graph from the 1st tracing that looks like:
class GraphModule(torch.nn.Module):
  def forward(self, L_t_ : torch.Tensor):
    l_t_ = L_t_
    # File: wonjoo_2.py:28, code: return torch.median(t)
    median = torch.median(l_t_);  l_t_ = None
    return (median,)
Now, since the only op in this graph is median, which is executed by the CPU, executing this graph with compiled_graph(*args) does not increase ExecuteTime. In the 3rd tracing with t_xla * 3, the only graph that XLA needs to execute before the aten::median op is the t_xla * 3 itself, hence ExecuteTime increases only by 1.
This is what I think is happening, let me know if that makes sense, @JackCaoG. And @seanlatias, also let me know if this aligns with your understanding, just to make sure we're on the same page.
Mostly agree with @wonjoolee95. Following is my explanation for each trace.
- The first trace includes two parts: checking the CPU fallback OP for compilation and running the fallback OP with the compiled graph. Both parts involve compiling and executing torch.median(). Thus, the initial CompileTime and ExecuteTime are 2.
- The second trace does not introduce any changes: same input & same module. Thus, nothing needs to be compiled and executed again. This is handled by the cached computation in XLA.
- The third trace does not change the input module, so the module does not need to be recompiled. However, it changes the input for the fallback OP. Thus, the fallback OP needs to be recompiled and re-executed (remember that we said the fallback OP goes through the lazy tensor execution instead of being executed by PyTorch directly). That's why we see an increase in both CompileTime and ExecuteTime.
Makes sense, thanks.
mostly lgtm, has one question regarding the test
"mostly lgtm, has one question regarding the test"
@JackCaoG, updated the comments above, should be ready for one more review.
Thanks a lot @wonjoolee95 and @seanlatias . This is great work!
Picks up #4935.
Supports falling back unsupported ops in PyTorch/XLA + Dynamo by utilizing CapacityBasedPartitioner.
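For context on the approach, a hedged sketch of how an FX graph can be split around unsupported ops with torch.fx's CapabilityBasedPartitioner (which I take to be the partitioner referred to above); the operator-support rule and the function are made up for illustration:

import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

# Illustrative support rule: pretend torch.median is the only op that must
# fall back to CPU; everything else is "supported" by XLA.
class FakeXlaOperatorSupport(OperatorSupport):

  def is_node_supported(self, submodules, node):
    return not (node.op == "call_function" and node.target is torch.median)

def fn(t):
  x = t * 3
  m = torch.median(x)  # unsupported op stays outside the fused partitions
  return m + 1

gm = torch.fx.symbolic_trace(fn)
partitioner = CapabilityBasedPartitioner(
    gm, FakeXlaOperatorSupport(), allows_single_node_partition=True)
partitions = partitioner.propose_partitions()
fused_gm = partitioner.fuse_partitions(partitions)
print(fused_gm.graph)  # supported regions become fused_* submodules

The supported regions end up as fused submodules that can be compiled by XLA, while the unsupported op runs through the normal fallback path in between.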