Flash attention 2 doesn't work #128

Open
ArtemBiliksin opened this issue Sep 28, 2024 · 1 comment

ArtemBiliksin commented Sep 28, 2024

Hello!

The main (a441a3f) branch of the AQLM repository does not support Flash Attention 2. The error occurs because `QuantizedWeight` has no `weight` attribute (see the closed issue #31). For example, the error occurs in these implementations:

Issue #31 proposed a solution in the `pre_quantization_dtype` branch. Its latest commit, ebe8ece, tries to fix the Flash Attention 2 problem for Mistral-type models, but it doesn't help. I tried the same thing on the main (a441a3f) branch for a Llama-type model, i.e.

if model.config.model_type == "llama":
    layer.self_attn.config._pre_quantization_dtype = torch.float32

and got the error `RuntimeError: FlashAttention only support fp16 and bf16 data type`. The full traceback is at the very bottom.
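The warning in the log below ("We will cast back the input in torch.float32") hints at why the override does not help. The following is a simplified, hypothetical re-creation (not the actual transformers code) of how, as far as we can tell, the FlashAttention-2 attention path in the transformers version used here chooses the dtype to cast float32 query/key/value states back to. The names `select_target_dtype` and `flash_attn_check` are illustrative, and dtypes are plain strings so the sketch runs without torch or a GPU.

```python
from types import SimpleNamespace

FLASH_ATTN_DTYPES = {"float16", "bfloat16"}


def select_target_dtype(input_dtype, config, q_proj, autocast_dtype=None):
    """Hypothetical sketch of the cast-back step before calling flash_attn_func."""
    if input_dtype != "float32":
        return input_dtype  # already half precision, nothing to cast back
    if autocast_dtype is not None:
        return autocast_dtype
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype  # the override tried in this issue
    # A quantized projection has no `.weight`, so this raises AttributeError (issue #31)
    return q_proj.weight.dtype


def flash_attn_check(dtype):
    """Mimics the dtype check that fails inside flash_attn_cuda.fwd."""
    if dtype not in FLASH_ATTN_DTYPES:
        raise RuntimeError("FlashAttention only support fp16 and bf16 data type")
    return dtype


if __name__ == "__main__":
    quantized_q_proj = SimpleNamespace(quantized_weight=None)  # stand-in without `.weight`
    try:
        select_target_dtype("float32", SimpleNamespace(), quantized_q_proj)
    except AttributeError as err:
        print("main branch:", err)
    config = SimpleNamespace(_pre_quantization_dtype="float32")
    target = select_target_dtype("float32", config, quantized_q_proj)
    try:
        flash_attn_check(target)
    except RuntimeError as err:
        print("with _pre_quantization_dtype=float32:", err)
```

Under this reading, setting `_pre_quantization_dtype` to `torch.float32` avoids the missing-`weight` lookup but hands float32 tensors straight to FlashAttention, which rejects them.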

You can quickly reproduce the experiment by running `bash flash_attn2_problem.sh`, where `flash_attn2_problem.sh` is:

# Clone the repository and make sure main is still at commit a441a3f
git clone https://github.com/Vahe1994/AQLM.git

SHA1_HEAD=$(git --git-dir=AQLM/.git rev-parse HEAD)
echo "SHA1_HEAD ${SHA1_HEAD}"
if [[ $SHA1_HEAD != "a441a3f0ece4cbaa2a91a3421c95a8b7432e4d99" ]]
then
  echo "The main branch has been modified. Issue #128 may not be up to date."
  exit 1
fi

# Insert the _pre_quantization_dtype override between lines 296 and 297 of main.py
sed -n '1,296p' AQLM/main.py > tmp.txt
echo '            if model.config.model_type == "llama":' >> tmp.txt
echo '                layer.self_attn.config._pre_quantization_dtype = torch.float32' >> tmp.txt
sed -n '297,918p' AQLM/main.py >> tmp.txt
mv tmp.txt AQLM/main.py

# Install AQLM requirements and flash-attn
pip install -r AQLM/requirements.txt
pip install flash-attn --no-build-isolation

MODEL_NAME=meta-llama/Llama-2-7b-hf
DATASET_NAME=pajama

# Run a tiny quantization job with FlashAttention 2 enabled
python AQLM/main.py $MODEL_NAME $DATASET_NAME \
  --nsamples 1 \
  --model_seqlen 16 \
  --nbits_per_codebook 4 \
  --init_max_iter 1 \
  --max_epochs 1 \
  --steps_per_epoch 1 \
  --finetune_max_epochs 1 \
  --attn_implementation "flash_attention_2"

The full error output:

PREPARING TO FINETUNE
LlamaDecoderLayer(
  (self_attn): LlamaFlashAttention2(
    (q_proj): QuantizedLinear(
      (quantized_weight): QuantizedWeight(self.out_features=4096, self.in_features=4096, bits_per_parameter=0.5040283203125)
    )
    (k_proj): QuantizedLinear(
      (quantized_weight): QuantizedWeight(self.out_features=4096, self.in_features=4096, bits_per_parameter=0.5040283203125)
    )
    (v_proj): QuantizedLinear(
      (quantized_weight): QuantizedWeight(self.out_features=4096, self.in_features=4096, bits_per_parameter=0.5040283203125)
    )
    (o_proj): QuantizedLinear(
      (quantized_weight): QuantizedWeight(self.out_features=4096, self.in_features=4096, bits_per_parameter=0.5040283203125)
    )
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (mlp): LlamaMLP(
    (gate_proj): QuantizedLinear(
      (quantized_weight): QuantizedWeight(self.out_features=11008, self.in_features=4096, bits_per_parameter=0.5039516715116279)
    )
    (up_proj): QuantizedLinear(
      (quantized_weight): QuantizedWeight(self.out_features=11008, self.in_features=4096, bits_per_parameter=0.5039516715116279)
    )
    (down_proj): QuantizedLinear(
      (quantized_weight): QuantizedWeight(self.out_features=4096, self.in_features=11008, bits_per_parameter=0.5014989098837209)
    )
    (act_fn): SiLU()
  )
  (input_layernorm): LlamaRMSNorm()
  (post_attention_layernorm): LlamaRMSNorm()
)
Fine-tuning 51584 parameters
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float32.
Traceback (most recent call last):
  File "/root/flash_attn2_problem/AQLM/main.py", line 920, in <module>
    main()
  File "/root/flash_attn2_problem/AQLM/main.py", line 894, in main
    quantize_model(model, args)
  File "/root/flash_attn2_problem/AQLM/main.py", line 59, in quantize_model
    results = quantize_aq(model, train_data, val_data, args)
  File "/root/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/root/flash_attn2_problem/AQLM/main.py", line 300, in quantize_aq
    layer = finetune_groupwise(
  File "/root/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/root/flash_attn2_problem/AQLM/src/finetune.py", line 143, in finetune_groupwise
    loss = _compute_mse_on_batch(layer, train_batch_iterators[0], **kwargs)
  File "/root/flash_attn2_problem/AQLM/src/finetune.py", line 268, in _compute_mse_on_batch
    outs_prediction, *_unused = layer(inps_batch, **kwargs)
  File "/root/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 741, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/root/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 490, in forward
    attn_output = self._flash_attention_forward(
  File "/root/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 555, in _flash_attention_forward
    attn_output = flash_attn_func(
  File "/root/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 880, in flash_attn_func
    return FlashAttnFunc.apply(
  File "/root/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 546, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  File "/root/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type

ArtemBiliksin (Author) commented:

Hi, I found why FlashAttention2 (`--attn_implementation "flash_attention_2"`) is not working.

On line 296 of `main.py`, the block is cast to float32, and on line 257 of `src/finetune.py`, the block's input data is cast to float32 as well. However, FlashAttention2 only works with bfloat16 and float16. Therefore, when using FlashAttention2, both the block and its input data must be cast to bfloat16 (or float16). That is enough to make it work. Keep in mind that casting to float16 may break training because of NaNs, so bfloat16 is the better choice.
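A minimal sketch of the fix described above, assuming a hypothetical helper `cast_for_finetuning` that would be called where `main.py` and `src/finetune.py` currently cast the block and its inputs to float32 (those call sites are not reproduced here). It switches to bfloat16 only when FlashAttention 2 is requested and keeps float32 otherwise.

```python
import torch
from torch import nn

FLASH_ATTN_2 = "flash_attention_2"


def finetune_dtype(attn_implementation: str, half_dtype: torch.dtype = torch.bfloat16) -> torch.dtype:
    """Dtype for the block and its inputs during fine-tuning (hypothetical helper)."""
    if attn_implementation == FLASH_ATTN_2:
        # FlashAttention-2 accepts only fp16/bf16; bf16 is the safer default because
        # fp16's narrow range (max 65504) can overflow to inf and then NaN.
        if half_dtype not in (torch.float16, torch.bfloat16):
            raise ValueError("FlashAttention-2 needs torch.float16 or torch.bfloat16")
        return half_dtype
    return torch.float32  # keep the current float32 behaviour for other backends


def cast_for_finetuning(layer: nn.Module, inps: torch.Tensor, attn_implementation: str):
    """Cast the block and its input batch to the same fine-tuning dtype."""
    dtype = finetune_dtype(attn_implementation)
    # Module.to(dtype) casts only floating-point parameters and buffers;
    # integer-valued ones keep their dtype.
    return layer.to(dtype), inps.to(dtype)


if __name__ == "__main__":
    layer, inps = cast_for_finetuning(nn.Linear(8, 8), torch.randn(2, 8), FLASH_ATTN_2)
    print(layer.weight.dtype, inps.dtype)  # torch.bfloat16 torch.bfloat16
    print(torch.finfo(torch.float16).max, torch.finfo(torch.bfloat16).max)
```

Casting the block and its inputs together matters: if only one of them is bfloat16, matrix multiplications inside the block would mix dtypes.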

[Screenshots: BlockDtype, InputBlockDtype]
