Latest status points towards all-reduce as the culprit for the OOM: we are not deallocating the tensors produced by all-reduce, so each invocation of all-reduce leaves an ~8 GB chunk allocated that is never freed.
There are two issues with our sharded implementation:
1. An inefficient all_reduce implementation that broadcasts to all devices, causing a lot of data movement through copy/transfer operations. This needs to be fixed at the torch level; Rob is looking into it (see the data-movement sketch below).
2. A bug in the allocation/deallocation path: we have no functionality in place to track the lifetime of non-local (cross-device) tensors. The current implementation only inserts dealloca for local tensors around a single execution region, so non-local tensors are never deallocated, which drives memory usage up until it eventually results in OOM. I am working on an analysis pass that performs whole-program analysis with reference counting and inserts dealloca in the proper places (see the second sketch below).
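To make the data-movement point in (1) concrete, here is a minimal sketch in plain Python (hypothetical list-based "device" buffers, not torch or IREE code) that just counts cross-device transfers: an all_reduce that broadcasts every shard to every device issues n*(n-1) transfers, while a staged reduce-then-broadcast schedule, shown only as one example of a lower-traffic alternative and not necessarily what the torch-level fix will do, issues 2*(n-1).

```python
# Illustrative sketch only: counts cross-device transfers for two all_reduce
# strategies. Buffers are plain Python lists; a "transfer" just means reading
# a shard that lives on a different device. Not IREE or torch code.

def broadcast_all_reduce(shards):
    # Every device pulls every other device's shard and reduces locally:
    # n * (n - 1) cross-device transfers.
    n = len(shards)
    transfers = n * (n - 1)
    total = [sum(vals) for vals in zip(*shards)]
    return [list(total) for _ in range(n)], transfers

def reduce_then_broadcast(shards):
    # Device 0 pulls the other n - 1 shards, reduces, then broadcasts the
    # result back out: 2 * (n - 1) cross-device transfers.
    n = len(shards)
    transfers = 2 * (n - 1)
    total = [sum(vals) for vals in zip(*shards)]
    return [list(total) for _ in range(n)], transfers

if __name__ == "__main__":
    shards = [[i, i + 1] for i in range(8)]  # one shard per "device", tp8-style
    _, naive = broadcast_all_reduce(shards)
    _, staged = reduce_then_broadcast(shards)
    print(f"broadcast-to-all: {naive} transfers, reduce+broadcast: {staged} transfers")
```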
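For the deallocation side in (2), below is a minimal sketch of the whole-program reference-counting idea, assuming a made-up flat op list rather than real IREE IR: count every use of every tensor across the entire program, then walk the program again and insert a dealloca on the owning device as soon as a tensor's remaining-use count reaches zero. Op names and structure are hypothetical; this is illustrative only, not the actual pass.

```python
# Illustrative sketch of whole-program dealloca placement via reference
# counting. The "program" is a flat list of ops; each op names the tensors it
# reads and produces plus the device it runs on. Hypothetical IR, not IREE.

from collections import namedtuple

Op = namedtuple("Op", ["name", "reads", "writes", "device"])

def insert_deallocas(program):
    # Pass 1: record each tensor's producing device and count every use across
    # the whole program, so lifetimes span execution regions and devices.
    remaining_uses = {}
    home_device = {}
    for op in program:
        for t in op.writes:
            home_device[t] = op.device
        for t in op.reads:
            remaining_uses[t] = remaining_uses.get(t, 0) + 1

    # Pass 2: walk the program again, decrementing counts; when a tensor's
    # count hits zero it is dead everywhere, so free it on its owning device.
    rewritten = []
    for op in program:
        rewritten.append(op)
        for t in op.reads:
            remaining_uses[t] -= 1
            if remaining_uses[t] == 0:
                rewritten.append(Op(f"dealloca {t}", [], [], home_device.get(t, op.device)))
    return rewritten

if __name__ == "__main__":
    # Tiny sharded-style example: the all_reduce result (%sum) produced on
    # device 0 is consumed later on device 1. A per-execution-region scheme
    # misses its dealloca; the whole-program walk places it after the last use.
    program = [
        Op("matmul -> %a", [], ["%a"], 0),
        Op("all_reduce %a -> %sum", ["%a"], ["%sum"], 0),
        Op("transfer %sum -> %sum_d1", ["%sum"], ["%sum_d1"], 1),
        Op("matmul %sum_d1 -> %b", ["%sum_d1"], ["%b"], 1),
    ]
    for op in insert_deallocas(program):
        print(f"device {op.device}: {op.name}")
```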
What happened?
I'm seeing an OOM error with Llama 405b tp8 when the input prompt is sufficiently long, in this case an input of 11,520 tokens. This is a prompt from the dataset used for Llama mlperf. I am able to reproduce the issue on MI300X-3 with iree-run-module. A trace of the command can be found here.

I started the iree-run-module invocation with GPU RAM usage at 0% for all 8 devices. Loading the model itself brings RAM usage to 51%; upon invocation, it shoots up to 99% for all devices.
Steps to reproduce your issue
iree-compile 405b_instruct_fp16.mlir -o llama.vmfb --iree-hal-target-device=hip[0] --iree-hal-target-device=hip[1] --iree-hal-target-device=hip[2] --iree-hal-target-device=hip[3] --iree-hal-target-device=hip[4] --iree-hal-target-device=hip[5] --iree-hal-target-device=hip[6] --iree-hal-target-device=hip[7] --iree-hip-target=gfx942 --iree-dispatch-creation-enable-aggressive-fusion=true --iree-global-opt-propagate-transposes=true --iree-opt-aggressively-propagate-transposes=true --iree-opt-data-tiling=false --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' --iree-hal-indirect-command-buffers=true --iree-stream-resource-memory-model=discrete --iree-hal-memoization=true --iree-opt-strip-assertions
Run the compiled module with iree-run-module (the full invocation is captured in the trace linked above).
export AMD_LOG_LEVEL=1 (enables HIP runtime error logging).
What component(s) does this issue relate to?
Runtime
Version information
22b34b5
Additional context
NOTE: Due to #19833, you will need to compile with a version prior to 4b0ca34; however, the bug is still relevant to the most recent runtime.