
PyTorch XLA .data assignment fails when the new tensor is a different shape #3502

Closed
ronghanghu opened this issue Apr 14, 2022 · 6 comments

ronghanghu commented Apr 14, 2022

🐛 Bug

In the latest nightly 20220413 PyTorch XLA build, the shape assignment example in #3392 (comment) is broken again. This now breaks the XLA FSDP implementation in #3431.

To Reproduce

  1. Allocate a v3-8 TPU VM with the tpu-vm-pt-1.10 runtime and install the nightly 20220413 environment:
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl
  2. Run the example below:
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

x3 = torch.zeros(10, device=device)
y3 = x3.view(-1)
# This should NOT update y3 because `x3.data` is not in-place modified
x3.data = y3[:5] + 1

print(f"y3: {y3}")

which gives

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 655, in __format__
    return object.__format__(self, format_spec)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 338, in __repr__
    return torch._tensor_str._str(self)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor_str.py", line 439, in _str
    return _str_intern(self)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor_str.py", line 325, in _str_intern
    self = self.to('cpu')
RuntimeError: INVALID_ARGUMENT: From /job:localservice/replica:0/task:0:
2 root error(s) found.
  (0) INVALID_ARGUMENT: Run-time shape mismatch for XRTExecute argument[0] (4423548877118753). Expected element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
  tiles {
    dimensions: 256
  }
}
is_dynamic_dimension: false
; got element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
}
is_dynamic_dimension: false

         [[{{node XRTExecute}}]]
         [[XRTExecute_G12]]
  (1) INVALID_ARGUMENT: Run-time shape mismatch for XRTExecute argument[0] (4423548877118753). Expected element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
  tiles {
    dimensions: 256
  }
}
is_dynamic_dimension: false
; got element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
}
is_dynamic_dimension: false

         [[{{node XRTExecute}}]]
0 successful operations.
0 derived errors ignored.
Recent warning and error logs:
  OP_REQUIRES failed at tpu_execute_op.cc:266 : INVALID_ARGUMENT: Run-time shape mismatch for XRTExecute argument[0] (4423548877118753). Expected element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
  tiles {
    dimensions: 256
  }
}
is_dynamic_dimension: false
; got element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
}
is_dynamic_dimension: false

Expected behavior

On the previous nightly 20220408 build

sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl

this example worked correctly and printed

y3: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='xla:1')

as expected (consistent with standard PyTorch behavior on CPU/GPU).
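For reference, here is a minimal CPU sketch of the expected semantics: reassigning `x3.data` repoints `x3` at a new tensor and must not touch the storage that `y3` still views.

```python
import torch

x3 = torch.zeros(10)
y3 = x3.view(-1)  # y3 shares x3's storage
# Reassigning x3.data makes x3 refer to a brand-new tensor;
# the old storage that y3 views is left unmodified.
x3.data = y3[:5] + 1

print(f"y3: {y3}")  # y3 stays all zeros, shape (10,)
print(f"x3: {x3}")  # x3 is now a 5-element tensor of ones
```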

Environment

  • Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM
  • torch_xla version: nightly 20220413

Additional context

It would be great to add a test case for the example above to guard against future regressions.
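A minimal sketch of such a regression test (the function name and structure are my assumptions, not torch_xla's test conventions; the device is a parameter so the same check also runs on CPU, while on a TPU VM one would pass `xm.xla_device()`):

```python
import torch

def test_data_assignment_with_shape_change(device="cpu"):
    # On a TPU VM one would pass device=xm.xla_device() instead of "cpu".
    x = torch.zeros(10, device=device)
    y = x.view(-1)
    # Reassigning x.data must not modify the storage that y still views,
    # even though the new tensor has a different shape.
    x.data = y[:5] + 1
    assert torch.equal(y.cpu(), torch.zeros(10))
    assert torch.equal(x.cpu(), torch.ones(5))

test_data_assignment_with_shape_change()
```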

cc: @JackCaoG


ronghanghu commented Apr 15, 2022

Update: I found that this error is caused by the nightly 20220413 version of libtpu; it is independent of torch, torch_xla, and torchvision.

If we use the 20220413 versions of torch, torchvision, and torch_xla but keep the 20220408 version of libtpu, everything still works:

# torch, torchvision and torch_xla 20220413
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220413-cp38-cp38-linux_x86_64.whl

# libtpu 20220408
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl

gives

y3: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='xla:1')

in the example above.

@ronghanghu

Update: closing this issue, as the error went away with the nightly 20220415 version of libtpu.

The example above now works with

# torch, torchvision, torch_xla and libtpu 20220415
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220415-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220415-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220415-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220415-py3-none-any.whl

It seems that the error was due solely to a problematic 20220413 version of libtpu and was mitigated in newer libtpu releases.


ronghanghu commented Apr 16, 2022

Update: there seem to be other problems with the 20220415 version of libtpu that break things (possibly related to the TensorFlow version issue). I'll keep this issue open and update it later. For now, the 20220408 version of libtpu works well.

@JackCaoG

@ronghanghu Let us know whether the TensorFlow update on our side fixed the issue; otherwise I can take another pass at this bug.

@ronghanghu

Thanks! I'll check out the nightly 20220422 build tomorrow and get back to you.


ronghanghu commented Apr 30, 2022

Update: I tried out the latest torch and torch_xla nightly 20220430 wheels, which contain #3535 and #3541, and tested them against libtpu dev20220413. Everything now works and this error is gone. So it seems the issue was previously caused by an outdated TensorFlow version.

Thanks for your help on this!

Also, note that the latest public libtpu nightly wheel (as of the morning of 04/30/2022 PDT) is only dev20220420. This newer libtpu prints a warning, `Missing key: 'ALT' in 'tpu-env' (yaml) instance metadata`, but otherwise works well.
