
Llama_405b_tp8 OOM w/ Long Input Prompt #19832

Open · stbaione opened this issue Jan 28, 2025 · 2 comments
Labels: bug 🐞 Something isn't working

stbaione commented Jan 28, 2025

What happened?

I'm seeing an OOM error with Llama 405b tp8 when the input prompt is sufficiently long, in this case an input of 11,520 tokens. The prompt comes from the dataset used for Llama MLPerf.

I am able to reproduce the issue on MI300X-3 with iree-run-module.

A trace of the command can be found here.

The iree-run-module invocation started with GPU RAM at 0% across all 8 devices.

Loading the model itself brings RAM usage to 51%:

(screenshot: GPU VRAM at 51% on all 8 devices)

Upon invocation, it shoots up to 99% for all devices:

(screenshot: GPU VRAM at 99% on all 8 devices)

NOTE: Due to #19833, you will need to compile with a version prior to 4b0ca34; however, the bug is still relevant to the most recent runtime.

Steps to reproduce your issue

  1. Download the 405b MLIR
  2. Download the inputs with the following script
#!/bin/bash

mkdir inputs_long
cd inputs_long

wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_405b_tp8/inputs/prefill/long/tokens.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_405b_tp8/inputs/prefill/long/seq_ids.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_405b_tp8/inputs/prefill/long/seq_block_ids.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_405b_tp8/inputs/prefill/long/cache_state_shard_0.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_405b_tp8/inputs/prefill/long/cache_state_shard_1.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_405b_tp8/inputs/prefill/long/cache_state_shard_2.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_405b_tp8/inputs/prefill/long/cache_state_shard_3.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_405b_tp8/inputs/prefill/long/cache_state_shard_4.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_405b_tp8/inputs/prefill/long/cache_state_shard_5.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_405b_tp8/inputs/prefill/long/cache_state_shard_6.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_405b_tp8/inputs/prefill/long/cache_state_shard_7.npy
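As a quick sanity check that the downloads are intact, the .npy files can be loaded and their shapes printed (a minimal sketch, assuming NumPy is installed; this script is illustrative and not part of the original reproduction steps):

# check_inputs.py -- verify the downloaded .npy inputs load cleanly.
import numpy as np

names = ["tokens", "seq_ids", "seq_block_ids"] + [
    f"cache_state_shard_{i}" for i in range(8)
]
for name in names:
    arr = np.load(f"inputs_long/{name}.npy")
    print(f"{name}: shape={arr.shape}, dtype={arr.dtype}, "
          f"{arr.nbytes / 2**20:.1f} MiB")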
  3. Compile
iree-compile 405b_instruct_fp16.mlir -o llama.vmfb \
  --iree-hal-target-device=hip[0] \
  --iree-hal-target-device=hip[1] \
  --iree-hal-target-device=hip[2] \
  --iree-hal-target-device=hip[3] \
  --iree-hal-target-device=hip[4] \
  --iree-hal-target-device=hip[5] \
  --iree-hal-target-device=hip[6] \
  --iree-hal-target-device=hip[7] \
  --iree-hip-target=gfx942 \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions
  4. Invoke iree-run-module
ROCR_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 iree-run-module \
  --hip_use_streams=true \
  --module=llama.vmfb \
  --parameters=model=/data/llama3.1/weights/405b/fp16/tp8/llama3.1_405b_fp16_tp8_parameters.irpa \
  --parameters=model=/data/llama3.1/weights/405b/fp16/tp8/llama3.1_405b_fp16_tp8_parameters.rank0.irpa \
  --parameters=model=/data/llama3.1/weights/405b/fp16/tp8/llama3.1_405b_fp16_tp8_parameters.rank1.irpa \
  --parameters=model=/data/llama3.1/weights/405b/fp16/tp8/llama3.1_405b_fp16_tp8_parameters.rank2.irpa \
  --parameters=model=/data/llama3.1/weights/405b/fp16/tp8/llama3.1_405b_fp16_tp8_parameters.rank3.irpa \
  --parameters=model=/data/llama3.1/weights/405b/fp16/tp8/llama3.1_405b_fp16_tp8_parameters.rank4.irpa \
  --parameters=model=/data/llama3.1/weights/405b/fp16/tp8/llama3.1_405b_fp16_tp8_parameters.rank5.irpa \
  --parameters=model=/data/llama3.1/weights/405b/fp16/tp8/llama3.1_405b_fp16_tp8_parameters.rank6.irpa \
  --parameters=model=/data/llama3.1/weights/405b/fp16/tp8/llama3.1_405b_fp16_tp8_parameters.rank7.irpa \
  --device=hip://0 --device=hip://1 --device=hip://2 --device=hip://3 \
  --device=hip://4 --device=hip://5 --device=hip://6 --device=hip://7 \
  --function=prefill_bs4 \
  --input=@inputs_long/tokens.npy \
  --input=@inputs_long/seq_ids.npy \
  --input=@inputs_long/seq_block_ids.npy \
  --input=@inputs_long/cache_state_shard_0.npy \
  --input=@inputs_long/cache_state_shard_1.npy \
  --input=@inputs_long/cache_state_shard_2.npy \
  --input=@inputs_long/cache_state_shard_3.npy \
  --input=@inputs_long/cache_state_shard_4.npy \
  --input=@inputs_long/cache_state_shard_5.npy \
  --input=@inputs_long/cache_state_shard_6.npy \
  --input=@inputs_long/cache_state_shard_7.npy
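To watch VRAM climb during the invocation, a polling loop like the following can be run in a second terminal (a minimal sketch; it assumes rocm-smi is on PATH and simply shells out to it once per second):

# watch_vram.py -- poll rocm-smi to observe per-device VRAM usage over time.
import subprocess
import time

while True:
    result = subprocess.run(
        ["rocm-smi", "--showmeminfo", "vram"],
        capture_output=True, text=True,
    )
    print(result.stdout)
    time.sleep(1)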
  5. With export AMD_LOG_LEVEL=1 set, you should see an error similar to the one below:
:1:hip_memory.cpp           :329 : 608260162041d us:  Allocation failed : Device memory : required :9075425280 | free :1430257664 | total :206141652992
:1:memory.cpp               :358 : 608260162067d us:  Video memory allocation failed!
:1:memory.cpp               :318 : 608260162157d us:  Can't allocate memory size - 0x1CF00000 bytes!
:1:rocdevice.cpp            :2443: 608260162166d us:  failed to create a svm hidden buffer!
:1:rocdevice.cpp            :2388: 608260162155d us:  Fail allocation local memory
:1:rocdevice.cpp            :2107: 608260162187d us:  Failed creating memory
:1:memory.cpp               :358 : 608260162199d us:  Video memory allocation failed!
:1:memory.cpp               :318 : 608260162210d us:  Can't allocate memory size - 0x1CF00000 bytes!
:1:rocdevice.cpp            :2388: 608260162217d us:  Fail allocation local memory
:1:rocdevice.cpp            :2107: 608260162232d us:  Failed creating memory
:1:memory.cpp               :358 : 608260162242d us:  Video memory allocation failed!
:1:memory.cpp               :318 : 608260162254d us:  Can't allocate memory size - 0x1CF00000 bytes!
:1:rocdevice.cpp            :2443: 608260162263d us:  failed to create a svm hidden buffer!
:1:memory.cpp               :1534: 608260162274d us:  Unable to allocate aligned memory
:1:hip_memory.cpp           :329 : 608260162290d us:  Allocation failed : Device memory : required :9075425280 | free :1430257664 | total :206141652992
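
The numbers in the failing log line are consistent with a single ~8.5 GiB allocation being requested on a nearly full device (a quick back-of-the-envelope check using the values from the log):

# Interpret the allocation-failure line from the AMD_LOG_LEVEL=1 output.
required = 9_075_425_280    # bytes requested by the failed allocation
free     = 1_430_257_664    # bytes still free on the device
total    = 206_141_652_992  # total device memory

GiB = 2**30
print(f"required: {required / GiB:.2f} GiB")  # ~8.45 GiB
print(f"free:     {free / GiB:.2f} GiB")      # ~1.33 GiB
print(f"total:    {total / GiB:.2f} GiB")     # ~192 GiB (MI300X)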

What component(s) does this issue relate to?

Runtime

Version information

22b34b5

Additional context

NOTE: Due to #19833, you will need to compile with a version prior to 4b0ca34; however, the bug is still relevant to the most recent runtime.

stbaione added the bug 🐞 Something isn't working label Jan 28, 2025
stbaione (Author) commented Feb 5, 2025

Latest status points towards all-reduce as the culprit. We're not deallocating the tensors produced by all-reduce, so each all-reduce invocation leaves a roughly 8 GB chunk of device memory allocated and never freed.

Detailed discussion can be found here
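
If each un-freed all-reduce result holds roughly the 8.45 GiB seen in the failed allocation above, the OOM point is easy to estimate (a rough calculation; the per-buffer size and the 51% load figure come from this issue, everything else is an assumption):

# Rough estimate of how many leaked all-reduce buffers exhaust one MI300X.
total_gib   = 192.0              # per-device VRAM from the log (~192 GiB)
weights_gib = 0.51 * total_gib   # model load alone reaches 51% usage
leak_gib    = 8.45               # assumed size of each un-freed buffer

headroom = total_gib - weights_gib
print(f"headroom after load: {headroom:.1f} GiB")
print(f"leaked buffers until OOM: ~{int(headroom // leak_gib)}")  # ~11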

drprajap commented Feb 7, 2025

There are two issues with our sharded implementation:

  1. An inefficient all_reduce implementation that broadcasts to all devices, causing a lot of data movement from copy/transfer operations. This needs to be fixed at the torch level; Rob is looking into it.
  2. A bug in the allocation/deallocation path: we have no functionality in place to track the lifetime of non-local (cross-device) tensors. The current implementation inserts dealloca for local tensors around a single execution region only, so non-local tensors are never deallocated, which drives memory usage up until it eventually results in OOM. I am working on an analysis pass that performs whole-program analysis with reference counting and inserts dealloca in the proper places (see the toy sketch after this list).

Both issues can be tackled independently.
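
To illustrate the whole-program lifetime tracking described in point 2, a toy reference-counting scheme might look like the sketch below (conceptual only; the real fix is an IREE compiler pass, and all names here are hypothetical):

# Toy model of reference-counted tensor lifetimes across execution regions.
# Hypothetical names; illustrates the idea, not IREE's actual pass.
class TensorLifetimes:
    def __init__(self):
        self.refcounts = {}

    def allocate(self, tensor_id):
        self.refcounts[tensor_id] = 0

    def add_use(self, tensor_id):
        self.refcounts[tensor_id] += 1

    def end_use(self, tensor_id):
        self.refcounts[tensor_id] -= 1
        if self.refcounts[tensor_id] == 0:
            # Last use anywhere in the program: safe to insert a dealloca,
            # even if the tensor crossed device or region boundaries.
            print(f"dealloca {tensor_id}")
            del self.refcounts[tensor_id]

lt = TensorLifetimes()
lt.allocate("allreduce_out")  # produced on device 0
lt.add_use("allreduce_out")   # consumed on device 1 (non-local use)
lt.add_use("allreduce_out")   # consumed on device 2
lt.end_use("allreduce_out")
lt.end_use("allreduce_out")   # refcount hits 0 -> "dealloca allreduce_out"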
