Fix output tensor shape for argmin and argmax where keepdim=True and dim=None #6536

Merged: 8 commits into pytorch:master on Feb 29, 2024

Conversation

mrnikwaws (Contributor)

Current failure looks like this:

root@f825750bf417:/ansible/torch_xla/pytorch/xla# export PJRT_DEVICE="CPU"
root@f825750bf417:/ansible/torch_xla/pytorch/xla# python
Python 3.8.18 (default, Feb  1 2024, 06:10:58) 
[GCC 10.2.1 20210110] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_xla
>>> t = torch.rand((3,4))
>>> t
tensor([[0.9253, 0.5503, 0.4175, 0.7273],
        [0.8306, 0.7907, 0.2054, 0.5639],
        [0.4089, 0.1673, 0.9702, 0.4839]])
>>> t_xla = t.to('xla')
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707952093.212406   88634 cpu_client.cc:404] TfrtCpuClient created.
>>> t_xla
tensor([[0.9253, 0.5503, 0.4175, 0.7273],
        [0.8306, 0.7907, 0.2054, 0.5639],
        [0.4089, 0.1673, 0.9702, 0.4839]], device='xla:0')
>>> torch.argmax(t,dim=None,keepdim=True)
tensor([[10]])
>>> torch.argmax(t_xla,dim=None,keepdim=True)
tensor(10, device='xla:0')
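
For reference, the shape PyTorch is expected to produce here can be checked in eager mode. A minimal sketch, mirroring the CPU call above and using only the torch API shown in this session:

import torch

t = torch.rand((3, 4))

# With dim=None and keepdim=True, argmax/argmin index into the flattened
# tensor but keep every input dimension at size 1, so a 2-D input yields
# a (1, 1) result rather than a 0-d scalar.
out_max = torch.argmax(t, dim=None, keepdim=True)
assert out_max.shape == (1, 1)

out_min = torch.argmin(t, dim=None, keepdim=True)
assert out_min.shape == (1, 1)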

After the fix:

root@f825750bf417:/ansible/torch_xla/pytorch/xla# python
Python 3.8.18 (default, Feb  1 2024, 06:10:58) 
[GCC 10.2.1 20210110] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_xla
>>> t = torch.randn((3,4))
>>> t
tensor([[-1.6839e+00, -5.5569e-01, -1.1452e+00,  4.5730e-01],
        [ 7.5517e-01,  2.3971e+00,  5.8805e-01,  8.4879e-01],
        [-3.3246e-04,  1.4524e-01,  2.0454e-01, -5.7229e-01]])
>>> t_xla = t.to('xla')
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707953861.897586   92883 cpu_client.cc:404] TfrtCpuClient created.
>>> torch.argmax(t,keepdim=True)
tensor([[5]])
>>> torch.argmax(t_xla,keepdim=True)
tensor([[5]], device='xla:0')
>>> torch.argmin(t,keepdim=True)
tensor([[0]])
>>> torch.argmin(t_xla,keepdim=True)
tensor([[0]], device='xla:0')
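
With the fix in place, the CPU and XLA results can be compared directly. A minimal sketch of such a check (it assumes torch_xla is installed and a PJRT device is configured, as in the sessions above):

import torch
import torch_xla  # registers the 'xla' device

t = torch.randn((3, 4))
t_xla = t.to('xla')

for fn in (torch.argmax, torch.argmin):
    cpu_out = fn(t, dim=None, keepdim=True)
    xla_out = fn(t_xla, dim=None, keepdim=True)
    # After the fix, the XLA result matches the CPU result in both
    # shape and value.
    assert cpu_out.shape == xla_out.shape
    assert torch.equal(cpu_out, xla_out.cpu())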

@JackCaoG requested a review from wonjoolee95 on February 15, 2024, 19:21
@wonjoolee95 (Collaborator) left a comment

Thanks! Changes LGTM, can you add a corresponding unit test at https://github.com/pytorch/xla/blob/master/test/cpp/test_aten_xla_tensor_2.cpp#L2077 and fix the linter issues?

@mrnikwaws (Contributor, Author)

Yep, will do.

@wonjoolee95 (Collaborator) left a comment

Thanks! Changes LGTM, let's wait for the CI to verify the tests.

@wonjoolee95 (Collaborator)

Seems like the build is failing:

test/cpp/test_aten_xla_tensor_2.cpp:2189:8: note: ‘virtual void torch_xla::cpp_test::AtenXlaTensorTest_TestArgMaxDimKeep_Test::TestBody()’ previously defined here
 2189 | TEST_F(AtenXlaTensorTest, TestArgMaxDimKeep) {

My guess is that a test with the name TestArgMaxDimKeep already exists.

@mrnikwaws (Contributor, Author)

I misread the test code: the tests need unique names. Fixing now.

@mrnikwaws (Contributor, Author)

The new test failures seem unrelated?

[0 / 1] [Prepa] BazelWorkspaceStatusAction stable-status.txt
[11 / 13] 7 / 8 tests; Compiling test/cpp/test_aten_xla_tensor_2.cpp; 0s local
[11 / 13] 7 / 8 tests; Compiling test/cpp/test_aten_xla_tensor_2.cpp; 11s local
[12 / 13] 7 / 8 tests; checking cached actions
[12 / 13] 7 / 8 tests; Linking test/cpp/test_aten_xla_tensor_2; 4s local
[12 / 13] 7 / 8 tests; Linking test/cpp/test_aten_xla_tensor_2; 10s local
[13 / 14] 7 / 8 tests; [Prepa] Testing //test/cpp:test_aten_xla_tensor_2
[13 / 14] 7 / 8 tests; Testing //test/cpp:test_aten_xla_tensor_2; 11s local
[13 / 14] 8 / 8 tests; Testing //test/cpp:test_aten_xla_tensor_2; 24s local
INFO: Elapsed time: 85.857s, Critical Path: 83.49s
INFO: 7 processes: 4 internal, 3 local.
INFO: Build completed successfully, 7 total actions
//torch_xla/csrc/runtime:cache_test                             (cached) PASSED in 0.7s
//torch_xla/csrc/runtime:env_hash_test                          (cached) PASSED in 0.6s
//torch_xla/csrc/runtime:ifrt_computation_client_test           (cached) PASSED in 1.2s
//torch_xla/csrc/runtime:pjrt_computation_client_test           (cached) PASSED in 0.5s
//torch_xla/csrc/runtime:sys_util_test                          (cached) PASSED in 0.0s
//torch_xla/csrc/runtime:util_test                              (cached) PASSED in 0.1s
//torch_xla/csrc/runtime:xla_util_test                          (cached) PASSED in 0.8s
//test/cpp:test_aten_xla_tensor_2                                        PASSED in 24.9s

Executed 1 out of 8 tests: 8 tests pass.

@wonjoolee95 (Collaborator)

Apologies for the delay, but can you rebase onto the current head? The failing tests have been disabled on head, but I'd like to let the CI complete before merging this. Thanks!

@mrnikwaws (Contributor, Author)

Looks like there are new, unrelated failures.

@wonjoolee95 (Collaborator)

Yeah, it's been a rough day for the head CI... 😢

Let's just wait for the other CIs and merge if they're green; the changes LGTM anyway. Thanks!

@wonjoolee95 (Collaborator)

CI seems to be passing other than the same failure on head. Merging this. Thanks!

@wonjoolee95 merged commit a1ab7fd into pytorch:master on Feb 29, 2024
15 of 17 checks passed
amithrm pushed a commit to amithrm/xla that referenced this pull request Mar 1, 2024