[Core ATen Opset] Lower aten_pixel_shuffle #5886
As @zpcore mentioned in an offline sync, the behavior of this op seems a bit odd: for the aten op pixel_shuffle, when dispatched in XLA, the op shows different behavior when called through C++ (xla/test/cpp/test_aten_xla_tensor_4.cpp, line 1224 at 657b692) versus through Python.
Below is the sample code I use for testing pixel_shuffle in Python (import torch ...). The output is always a tensor of all zeros, which is different from running it through C++.
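A minimal sketch of that kind of Python repro (the shapes and the exact call pattern here are assumptions, not the code from the report):

```python
import torch
import torch_xla.core.xla_model as xm

# Hypothetical repro: an (N, C*r^2, H, W) input with upscale factor r=2 on the XLA device.
device = xm.xla_device()
x = torch.randn(1, 4, 3, 3, device=device)

out = torch.pixel_shuffle(x, upscale_factor=2)

# Copying back to CPU materializes the result; the report says the Python path
# returns an all-zero tensor here, unlike the C++ test.
print(out.cpu())
```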
We got the following HLO (from the C++ path and the Python path, respectively):
As we can see, the C++ and Python paths generate different HLO graphs. The op
This has something to do with
@bdhirsh to see why. Update 03/19: I forgot to mention that, for some reason, the input tensor is dropped, as shown in
Any updates on this issue?
No, let me follow up with @bdhirsh.
cc @jiawenliu64, as this op appears to fail tests when the functionalization flag is enabled.
Hey! The API runs at xla/torch_xla/csrc/aten_xla_type.cpp, lines 3639 to 3645 (at 5a113af); pixel_shuffle is registered there.
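One way to double-check what the dispatcher has registered for this op from Python (a debugging sketch; torch._C._dispatch_dump is an internal, unstable helper, and using it here is a suggestion rather than something from the thread):

```python
import torch
import torch_xla  # importing torch_xla loads the C++ extension that registers the XLA kernels

# Dump the dispatch table for pixel_shuffle; if the op is registered in
# aten_xla_type.cpp, an XLA entry should show up in this output.
print(torch._C._dispatch_dump("aten::pixel_shuffle"))
```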
My guess for the python vs C++ difference is that there's a C++ decomposition registered in core here (that secretly runs some view ops like And when you use python, In theory... both of those decomps should be correct. Does one of the IR's look wrong? Alternatively: you can definitely remove the |
Thanks for the response, @bdhirsh! The original reason we wanted to lower this was that it is marked as a "core" op in PyTorch's native_functions.yaml, and we wanted torch_xla to support lowerings for all core ops. My initial understanding was that a "core" op would not be decomposed into any further ops; is that the case? According to https://github.com/pytorch/pytorch/blob/58047205ed098c04ec045e66fc39dcc70b60600b/torch/_refs/nn/functional/__init__.py#L1169, it appears to have some decompositions.
@wonjoolee95 the way I would categorize the decomps we have:

(1) There are a ton of decomps, for many (most?) ops in ATen. A compiler backend to torch.compile can specify which of those decomps it does or doesn't want to run, and lower the remaining primitive ops that show up in the graph directly. For example, inductor has its own set of decomps that it uses (basically core ATen + a few other decomps: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/decomposition.py#L75).

(2) There is a canonical "core ATen opset" that backends can choose to target, and you can get out a graph of core ATen IR by specifying that you only want to run the core ATen decomps. So if you e.g. use torch.export and run with the core ATen decomps set, you'll get a graph of core ATen IR. But the eager-mode XLA integration doesn't necessarily run the same set of decomps as core ATen (although you can change which ops you choose to decompose vs. lower directly).
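As a concrete illustration of point (2), here is a sketch of getting core ATen IR out of torch.export; it assumes a recent PyTorch where ExportedProgram.run_decompositions() defaults to the core ATen decomposition table:

```python
import torch
from torch.export import export

class PixelShuffleModule(torch.nn.Module):
    def forward(self, x):
        return torch.pixel_shuffle(x, upscale_factor=2)

ep = export(PixelShuffleModule(), (torch.randn(1, 4, 3, 3),))

# With no explicit table, run_decompositions applies the default (core ATen)
# decompositions; the printed graph shows what pixel_shuffle ends up as.
core_ep = ep.run_decompositions()
print(core_ep.graph_module.code)
```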
Thanks for the explanation, that makes a lot of sense. @zpcore, I'll remove this issue from the scope of the "core aten opset". In
Hi @bdhirsh, @wonjoolee95, thanks for following up. I checked the decomposition trace, and it turns out that if we move the tensor to the XLA device, it will use a different decomposition. Below is the example code.
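A minimal sketch of such a CPU-vs-XLA comparison (the shapes are assumptions; the IR dump uses the internal torch_xla._XLAC._get_xla_tensors_text helper):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.randn(1, 4, 3, 3)

# CPU path: pixel_shuffle goes through the eager CPU kernel.
print(torch.pixel_shuffle(x, 2))

# XLA path: the same call may be routed through a different decomposition.
y = torch.pixel_shuffle(x.to(xm.xla_device()), 2)

# Internal debugging helper: dumps the lazy IR recorded for the given tensors.
print(torch_xla._XLAC._get_xla_tensors_text([y]))
```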
It will enter _meta_registrations.py. If we remove
Either way, I didn't see the Python registration being called.
In order for PyTorch/XLA to support the PyTorch core ATen opset, each core ATen op needs a lowering in PyTorch/XLA. This issue tracks the PyTorch/XLA lowering for aten_pixel_shuffle.
Here are some general guidelines to lowering this op:
- Remove the @unittest.skip or @unittest.expectedFailure decorator and run the unit test at test_core_aten_ops.py, e.g.: pytest test/test_core_aten_ops.py -k test_aten_pixel_shuffle_0
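For reference, pixel_shuffle for an (N, C*r^2, H, W) input can be expressed with reshape/permute ops; the sketch below is a plain-PyTorch reference, not the actual decomposition or lowering used in this repo:

```python
import torch

def pixel_shuffle_reference(x: torch.Tensor, r: int) -> torch.Tensor:
    """Rearrange (N, C*r*r, H, W) -> (N, C, H*r, W*r), matching torch.pixel_shuffle."""
    n, c_r2, h, w = x.shape
    c = c_r2 // (r * r)
    x = x.reshape(n, c, r, r, h, w)    # split channels into (C, r, r)
    x = x.permute(0, 1, 4, 2, 5, 3)    # (N, C, H, r, W, r)
    return x.reshape(n, c, h * r, w * r)

x = torch.randn(2, 8, 3, 3)
assert torch.allclose(pixel_shuffle_reference(x, 2), torch.pixel_shuffle(x, 2))
```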
For any questions, feel free to leave a comment in this PR.