The main (a441a3f) branch of the AQLM repository does not support Flash Attention 2. The error occurs because QuantizedWeight does not have a weight attribute (closed issue #31).
In issue #31 a solution was proposed in the pre_quantization_dtype branch. Its last commit, ebe8ece, tries to fix the Flash Attention 2 problem for mistral-type models, but it does not help. I tried doing the same thing in the main (a441a3f) branch for llama-type models, i.e.
if model.config.model_type == "llama":
layer.self_attn.config._pre_quantization_dtype = torch.float32
and I get the error `RuntimeError: FlashAttention only support fp16 and bf16 data type`. See the full traceback at the very bottom.
You can quickly reproduce the experiment by executing bash flash_attn2_problem.sh, where
flash_attn2_problem.sh
git clone https://github.com/Vahe1994/AQLM.git
SHA1_HEAD=$(git --git-dir=AQLM/.git rev-parse HEAD)
echo "SHA1_HEAD ${SHA1_HEAD}"
if [[ "$SHA1_HEAD" != "a441a3f0ece4cbaa2a91a3421c95a8b7432e4d99" ]]
then
    echo "The main branch has been modified. Issue #128 may not be up to date."
    exit 1
fi
sed -n '1,296p' AQLM/main.py > tmp.txt
echo '        if model.config.model_type == "llama":' >> tmp.txt
echo '            layer.self_attn.config._pre_quantization_dtype = torch.float32' >> tmp.txt
sed -n '297,918p' AQLM/main.py >> tmp.txt
mv tmp.txt AQLM/main.py
pip install -r AQLM/requirements.txt
pip install flash-attn --no-build-isolation
MODEL_NAME=meta-llama/Llama-2-7b-hf
DATASET_NAME=pajama
python AQLM/main.py $MODEL_NAME $DATASET_NAME \
--nsamples 1 \
--model_seqlen 16 \
--nbits_per_codebook 4 \
--init_max_iter 1 \
--max_epochs 1 \
--steps_per_epoch 1 \
--finetune_max_epochs 1 \
--attn_implementation "flash_attention_2"
The full output of the error that occurs:
PREPARING TO FINETUNE
LlamaDecoderLayer(
(self_attn): LlamaFlashAttention2(
(q_proj): QuantizedLinear(
(quantized_weight): QuantizedWeight(self.out_features=4096, self.in_features=4096, bits_per_parameter=0.5040283203125)
)
(k_proj): QuantizedLinear(
(quantized_weight): QuantizedWeight(self.out_features=4096, self.in_features=4096, bits_per_parameter=0.5040283203125)
)
(v_proj): QuantizedLinear(
(quantized_weight): QuantizedWeight(self.out_features=4096, self.in_features=4096, bits_per_parameter=0.5040283203125)
)
(o_proj): QuantizedLinear(
(quantized_weight): QuantizedWeight(self.out_features=4096, self.in_features=4096, bits_per_parameter=0.5040283203125)
)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): QuantizedLinear(
(quantized_weight): QuantizedWeight(self.out_features=11008, self.in_features=4096, bits_per_parameter=0.5039516715116279)
)
(up_proj): QuantizedLinear(
(quantized_weight): QuantizedWeight(self.out_features=11008, self.in_features=4096, bits_per_parameter=0.5039516715116279)
)
(down_proj): QuantizedLinear(
(quantized_weight): QuantizedWeight(self.out_features=4096, self.in_features=11008, bits_per_parameter=0.5014989098837209)
)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
Fine-tuning 51584 parameters
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float32.
Traceback (most recent call last):
File "/root/flash_attn2_problem/AQLM/main.py", line 920, in <module>
main()
File "/root/flash_attn2_problem/AQLM/main.py", line 894, in main
quantize_model(model, args)
File "/root/flash_attn2_problem/AQLM/main.py", line 59, in quantize_model
results = quantize_aq(model, train_data, val_data, args)
File "/root/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/flash_attn2_problem/AQLM/main.py", line 300, in quantize_aq
layer = finetune_groupwise(
File "/root/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/flash_attn2_problem/AQLM/src/finetune.py", line 143, in finetune_groupwise
loss = _compute_mse_on_batch(layer, train_batch_iterators[0], **kwargs)
File "/root/flash_attn2_problem/AQLM/src/finetune.py", line 268, in _compute_mse_on_batch
outs_prediction, *_unused = layer(inps_batch, **kwargs)
File "/root/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 741, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/root/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 490, in forward
attn_output = self._flash_attention_forward(
File "/root/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 555, in _flash_attention_forward
attn_output = flash_attn_func(
File "/root/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 880, in flash_attn_func
return FlashAttnFunc.apply(
File "/root/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/root/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 546, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
File "/root/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
Hi, I found the reason why FlashAttention2 (`--attn_implementation "flash_attention_2"`) is not working.
On line 296 of main.py the block is cast to the float32 data type, and on line 257 of src/finetune.py the input data for the block is cast to float32 as well. But FlashAttention2 only works with the bfloat16 and float16 data types. Therefore, if you use FlashAttention2, you need to cast both the block and its input data to bfloat16 (or float16). That alone is enough to make it work. Keep in mind that casting to float16 may break training due to NaNs, so bfloat16 is preferable.
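The fix described above can be sketched as a small helper (a hypothetical function, not code from the AQLM repo): pick the dtype based on the attention implementation, then cast the block and its inputs together before finetuning.

```python
import torch


def cast_for_attn_impl(layer: torch.nn.Module,
                       inps: torch.Tensor,
                       attn_implementation: str):
    """Cast a transformer block and its inputs to a dtype compatible
    with the chosen attention implementation.

    FlashAttention2 kernels only accept fp16/bf16 tensors; bf16 is
    preferred over fp16 because fp16 finetuning can produce NaNs.
    For other implementations, keep the original float32 behavior.
    """
    if attn_implementation == "flash_attention_2":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    return layer.to(dtype), inps.to(dtype)


# Usage: instead of unconditionally casting to float32 before
# finetune_groupwise, route both the layer and the cached inputs
# through the same dtype decision.
layer = torch.nn.Linear(4096, 4096)
inps = torch.zeros(1, 16, 4096)
layer, inps = cast_for_attn_impl(layer, inps, "flash_attention_2")
```

This keeps the layer and its inputs in the same dtype, which also avoids the "input hidden states seems to be silently casted in float32" warning seen in the log above.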