
How to avoid compilation in a section of code? #7622

Open

Jiayi-Pan opened this issue Jul 3, 2024 · 9 comments

Comments

Jiayi-Pan commented Jul 3, 2024

❓ Questions and Help

We are using PyTorch/XLA with TPUs to train a multi-modal language model.

We can write most of the code, such as the image encoding and the forward pass of the LLM backbone, with static shapes, which XLA handles well. However, making the part that fuses image and text embeddings into the input embeddings static is extremely challenging.

Currently, we use mark_step to isolate that section from the rest of the code, allowing it to recompile each time. Although this part is computationally very light, the recompilation is extremely slow and often consumes the majority of training time.

Documentation on this issue is very hard to find, and we are exploring better solutions, such as running that part on the CPU, running it in eager mode, or not saving that part of the graph to avoid OOM errors during long training runs. Do you have any suggestions/pointers on how to work around this inefficiency?

The following pseudocode illustrates our problem:

import torch_xla.core.xla_model as xm

for batch in dataloader:  # loading data
    # these tensors have static shapes; XLA works great on them
    image_embeddings = image_encoder(raw_image_tensor)
    text_embeddings = get_text_embedding(text_token_idxs)

    xm.mark_step()
    # this part is very light in compute, but dynamic; we currently just
    # recompile this graph every single time :(
    input_embeddings = fuse_embedding(image_embeddings, text_embeddings, sequence_info_dict)
    xm.mark_step()

    # these tensors have static shapes; XLA works great on them
    output_logits = llm(input_embeddings)
    # loss compute / backward / optimizer step omitted
JackCaoG (Collaborator) commented Jul 3, 2024

Great question. I have a couple of questions and a couple of suggestions.

Questions

  1. It seems that even though fuse_embedding is dynamic, the shape of input_embeddings is static? That would explain why the llm HLO is static.
  2. How dynamic is fuse_embedding? For example, are there ~100 different shape combinations possible in total, or can there literally be thousands of different shape combinations?

Suggestions

  1. Have you tried persistent caching? If not, please take a look at https://github.com/pytorch/xla/blob/master/API_GUIDE.md#compilation-caching. If there is relatively little dynamism in your code, enabling persistent caching would fix the issue (you can compile and remember all possible combinations); see the sketch after this list.
  2. Maybe try eager mode. This is an experimental feature, so you will need a nightly build. Take a look at https://github.com/pytorch/xla/blob/master/examples/eager/train_decoder_only_eager.py#L10. You can enable eager mode in the dynamic region and disable it right after. Or you can do something similar to https://github.com/pytorch/xla/blob/master/examples/eager/train_decoder_only_eager_with_compile.py, which turns on eager mode by default and manually picks the region to compile. Eager + compile is the UX I want to make the default next year, so I'd appreciate any feedback you have.
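For suggestion 1, a minimal sketch of enabling the persistent compilation cache described in the linked API guide (the cache directory path is an arbitrary choice):

import torch_xla.runtime as xr

# Initialize the persistent compilation cache before any graph is compiled;
# compiled programs are written to disk and reused across runs, so each
# shape combination only pays the compilation cost once.
xr.initialize_cache('/tmp/xla_compile_cache', readonly=False)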

For the nightly build you can try

pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly+20240701-cp310-cp310-linux_x86_64.whl
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

since last night's nightly seems to be broken.

Eager mode pretty much just compiles op by op. It compiles each op once per input shape, so the overall compile time is usually lower. Let me know how the above two suggestions work for you.
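For reference, here is a sketch of the eager + compile pattern from the second linked example (train_decoder_only_eager_with_compile.py); step_fn and its arguments are placeholders, and torch_xla.experimental.compile is the entry point assumed from that nightly example, so the exact name may differ across builds:

import torch_xla
import torch_xla.experimental

# Eager by default: each op is compiled once per input shape and executed immediately.
torch_xla.experimental.eager_mode(True)

def step_fn(model, batch, optimizer):
    # static-shaped hot path (forward / backward / optimizer step) goes here
    ...

# Explicitly compile just the static region; everything else stays eager.
compiled_step_fn = torch_xla.experimental.compile(step_fn)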

Jiayi-Pan (Author) commented Jul 3, 2024

Thank you for the instructions!
Re Q1: that's correct! We deliberately pad both raw_image_tensor and input_embeddings to make their shapes static (see the padding sketch below). Only fuse_embedding is recompiled, while llm and image_encoder, where most of the compute happens, stay static.
Re Q2: unfortunately it's very dynamic; the number of shape combinations should be at least on the order of thousands.
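For illustration, a hypothetical sketch of the kind of padding we mean; pad_to_length and MAX_SEQ_LEN are made-up names for this sketch, not our actual code:

import torch
import torch.nn.functional as F

MAX_SEQ_LEN = 2048  # illustrative fixed upper bound

def pad_to_length(x: torch.Tensor, length: int) -> torch.Tensor:
    # Right-pad the sequence dimension of a (batch, seq, hidden) tensor with
    # zeros so every batch presents the same static shape to XLA.
    return F.pad(x, (0, 0, 0, length - x.shape[1]))

# input_embeddings = pad_to_length(input_embeddings, MAX_SEQ_LEN)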

Eager mode looks very promising; however, I'm unable to install the nightly:

tpu-vm:~$ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly': Expected end or semicolon (after name and no valid version specifier)
    torch-xla==nightly

JackCaoG (Collaborator) commented Jul 3, 2024

Hmm, that's weird. Can you access https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl directly? If I click on the link it just downloads the whl file for me. Also, is your Python version 3.10?

Jiayi-Pan (Author) commented

I can access it, and my Python version is 3.10, but the issue is still there:

jiayipan@t1v-n-f6802337-w-0:~$ python --version
Python 3.10.12
jiayipan@t1v-n-f6802337-w-0:~$ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly+20240701': Expected end or semicolon (after name and no valid version specifier)
    torch-xla==nightly+20240701
             ^
jiayipan@t1v-n-f6802337-w-0:~$ wget https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
--2024-07-03 17:29:57--  https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.12.207, 173.194.217.207, 74.125.26.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.12.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 83362771 (80M) [application/octet-stream]
Saving to: ‘torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl’

torch_xla-nightly+2 100%[===================>]  79.50M  88.9MB/s    in 0.9s

2024-07-03 17:29:58 (88.9 MB/s) - ‘torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl’ saved [83362771/83362771]

jiayipan@t1v-n-f6802337-w-0:~$ pip install
.bash_history
.bash_logout
.bashrc
.cache/
.config/
.local/
.profile
.ssh/
.viminfo
buckets/
prismatic-video-lms/
torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
torch_xla-nightly-cp310-cp310-linux_x86_64.whl
jiayipan@t1v-n-f6802337-w-0:~$ pip install torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl
Defaulting to user installation because normal site-packages is not writeable
ERROR: Invalid requirement: 'torch-xla==nightly+20240701': Expected end or semicolon (after name and no valid version specifier)
    torch-xla==nightly+20240701
             ^
jiayipan@t1v-n-f6802337-w-0:~$

JackCaoG (Collaborator) commented Jul 3, 2024

Hmm, I can't reproduce this, which is a bit weird. Maybe manually rename the whl? Something like:

mv torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl torch_xla-nightly-cp310-cp310-linux_x86_64.whl

Jiayi-Pan (Author) commented

Some updates.

On reproducing the installation issue
It turns out that the installation error only happens after

python3 -m pip install --upgrade pip

Given a clean tpu-v3 VM with ubuntu-22.04, you should be able to reproduce the error with:

python3 -m pip install --upgrade pip
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly+20240701-cp310-cp310-linux_x86_64.whl

(Presumably newer pip releases enforce PEP 440 version specifiers, and nightly+20240701 is not a valid version, which would explain why upgrading pip triggers the error.)

Jiayi-Pan (Author) commented

On Eager Mode
I tried eager mode! The code structure is basically as shown here:

for batch in dataloader:  # loading data
    # these tensors have static shapes; XLA works great on them
    image_embeddings = image_encoder(raw_image_tensor)
    text_embeddings = get_text_embedding(text_token_idxs)

    xm.mark_step()
    # this part is very light in compute, but dynamic; run it eagerly
    torch_xla.experimental.eager_mode(True)
    input_embeddings = fuse_embedding(image_embeddings, text_embeddings, sequence_info_dict)
    torch_xla.experimental.eager_mode(False)
    xm.mark_step()

    # these tensors have static shapes; XLA works great on them
    output_logits = llm(input_embeddings)
    # loss compute / backward / optimizer step omitted

Unfortunately, the code hangs and never reaches output_logits = llm(input_embeddings). (It still works fine on the nightly when I disable eager mode.)
Do you have any suggestions for debugging? There are a few mark_steps around/within fuse_embedding; I'm not sure if they cause any trouble.

JackCaoG (Collaborator) commented Jul 8, 2024

Can you run with PT_XLA_DEBUG_LEVEL=1? This will print a message for every compilation if you are using the nightly. I am wondering whether it just keeps recompiling, or whether eager compilation (compiling each op) is too slow.
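One way to set it, assuming the variable is read when torch_xla initializes (setting it in the shell before launching the script works equally well):

import os

# Must be set before torch_xla is imported so the debug hooks pick it up.
os.environ["PT_XLA_DEBUG_LEVEL"] = "1"

import torch_xla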

fellhorn commented
(Quoting Jiayi-Pan's earlier comment about the failing nightly install.)

As I was facing the same issue with uv, I created a separate issue for the broken nightly filenames:
#7697
