
[Codegen] llama 8b fp8 with attention vector distribute fail #19991

Open
AmosLewis opened this issue Feb 14, 2025 · 6 comments · May be fixed by #20014
Labels: bug 🐞 Something isn't working

Comments

@AmosLewis (Contributor)

What happened?

Follow up of [ROCm][Codegen] llama 8b fp8 with attention segfault #19921

A new codegen issue (log: llama_f8_attn_bug_log_0213.txt) appeared after I rebased IREE to

commit 0ff26a7bef803edf3e22588f3e69a51c9335a79b (HEAD -> main, upstream/main)
Author: Prashant Kumar <pk5561@gmail.com>
Date:   Thu Feb 13 23:26:59 2025 +0530
    [Codegen] Add support to emulate unsupported float type (#19943)
f8_attn_chi_castf32_roctorch.mlir:45778:10: error: 'func.func' op failed to distribute
    %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed, %collapsed_1, %collapsed_2, %extracted, %arg4 : tensor<32x?x128xf8E4M3FNUZ>, tensor<32x?x128xf8E4M3FNUZ>, tensor<32x?x128xf8E4M3FNUZ>, f32, tensor<?x?xf8E4M3FNUZ>) outs(%cast : tensor<32x?x128xf32>) {
         ^
f8_attn_chi_castf32_roctorch.mlir:2706:12: note: called from
    %914 = util.call @sharktank_masked_flash_attention_1_32_128_128_f8E4M3FNUZ_f32_f32(%909, %910, %911, %913, %912) : (tensor<1x32x?x128xf8E4M3FNUZ>, tensor<1x32x?x128xf8E4M3FNUZ>, tensor<1x32x?x128xf8E4M3FNUZ>, tensor<f32>, tensor<?x?xf8E4M3FNUZ>) -> tensor<1x32x?x128xf32>
           ^
f8_attn_chi_castf32_roctorch.mlir:45778:10: note: see current operation:

Steps to reproduce your issue

  1. Compile IREE:
cmake -G Ninja -B ../iree-build  -S . \
    -DCMAKE_BUILD_TYPE=Debug \
    -DIREE_ENABLE_ASSERTIONS=ON \
    -DCMAKE_C_COMPILER=clang \
    -DCMAKE_CXX_COMPILER=clang++ \
    -DIREE_ENABLE_RUNTIME_TRACING=ON \
    -DIREE_BUILD_TRACY=OFF \
    -DIREE_ENABLE_LLD=ON \
    -DIREE_BUILD_PYTHON_BINDINGS=ON \
    -DPython3_EXECUTABLE="$(which python3)" \
    -DIREE_TARGET_BACKEND_CUDA=OFF \
    -DIREE_HAL_DRIVER_HIP=ON \
    -DIREE_TARGET_BACKEND_ROCM=ON .
cmake --build ../iree-build
  2. Download the input MLIR here: f8_attn_chi_castf32_roctorch.mlir

Optional: export f8_attn_chi_castf32_roctorch.mlir manually with nod-ai/shark-ai#907.

  3. Run the following command:

 /home/chi/src/iree-build/tools/iree-compile f8_attn_chi_castf32_roctorch.mlir \
  --iree-hip-target=gfx942 \
  -o=f8_attn_chi_castf32_roctorch.vmfb \
  --iree-hal-target-device=hip \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions

What component(s) does this issue relate to?

Compiler

Version information

commit 0ff26a7 (HEAD -> main, upstream/main)
Author: Prashant Kumar pk5561@gmail.com
Date: Thu Feb 13 23:26:59 2025 +0530
[Codegen] Add support to emulate unsupported float type (#19943)

Additional context

No response

@pashu123 (Contributor)

Arising from this dispatch: https://gist.github.com/pashu123/e21bc74fafbc4ce3ae23b0adf3ac75b5

@pashu123 (Contributor)

iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmgpu-vector-distribute,canonicalize,cse))" --split-input-file before_vector_distribute.mlir
https://gist.github.com/pashu123/e22dba342a9bf78c0ee5ccb0522d3855

@IanWood1 (Contributor)

IanWood1 commented Feb 14, 2025

I think the problem is the tensor.expand_shape that gets pulled into the dispatch because it is applied to the attention mask; this is similar to the problem fixed by #19838.
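As an illustration of the pattern being described (the shapes and value names below are hypothetical, not taken from the failing IR): a reshape on the `?x?` mask that introduces a leading unit dimension, sitting inside the dispatch in front of the attention op, would look roughly like this.

```mlir
// Hypothetical sketch: the dynamic 2-D mask picks up a leading unit
// dimension via tensor.expand_shape before feeding iree_linalg_ext.attention.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = tensor.dim %mask, %c0 : tensor<?x?xf8E4M3FNUZ>
%d1 = tensor.dim %mask, %c1 : tensor<?x?xf8E4M3FNUZ>
%expanded = tensor.expand_shape %mask [[0, 1], [2]]
    output_shape [1, %d0, %d1]
    : tensor<?x?xf8E4M3FNUZ> into tensor<1x?x?xf8E4M3FNUZ>
```

If the vector-distribution patterns for attention do not expect that extra unit dimension on the mask operand, a distribution failure like the one above is plausible.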

cc @MaheshRavishankar

@dan-garvey (Contributor)

Did this get triaged in the codegen sync? Is someone on it?

@MaheshRavishankar MaheshRavishankar self-assigned this Feb 17, 2025
@MaheshRavishankar (Contributor)

OK, I looked at this, and it hits a pretty hairy unimplemented part of propagating reshapes across the mask operation. It is also true that a unit dimension is being added here. I haven't tracked down what that unit dimension is; it can't be the batch, since the batch is not part of the mask AFAIK. So some extra unit dimension is being added that I haven't fully tracked down. I'll fix the core issue, but there might be some intermediate WAR that would be faster to land and unblock.

@MaheshRavishankar (Contributor)

#20014 seems to fix the compilation, but I need to see what impact it has on the test suite, etc. At least the fix works.
