MLPs with FbgemmFp8Linear on Llama-405b-FP8 do not handle batch sizes >1 correctly
#32868
System Info
transformers 4.44.0
torch 2.4.0+cu121
fbgemm_gpu 0.8.0+cu121
Who can help?
@ArthurZucker (also maybe @SunMarc?)
Reproduction
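Roughly the setup (the checkpoint id, prompts, and generate arguments here are illustrative, not the exact script): load the FP8 checkpoint through the fbgemm integration and call model.generate with more than one prompt.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3.1-405B-Instruct-FP8"  # illustrative FP8 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # the FP8 weights go through FbgemmFp8Linear layers
)

prompts = ["Hello, my name is", "The capital of France is"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# Batch size 1 is fine; any batch size > 1 produces broken generations.
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```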
After some digging in pdb, I tracked it down to the quantized MLPs:
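In essence, a check like the following (tensor sizes are illustrative; `model` is the FP8 model loaded above) shows the problem in isolation:

```python
import torch

mlp = model.model.layers[0].mlp          # its gate/up/down projections are FbgemmFp8Linear
hidden_size = model.config.hidden_size   # 16384 for Llama-405B

x = torch.randn(2, 5, hidden_size, dtype=torch.bfloat16, device=model.device)
print(x.shape)        # torch.Size([2, 5, 16384])
print(mlp(x).shape)   # should be torch.Size([2, 5, 16384]), i.e. same leading dims,
                      # but the quantized MLP collapses batch and sequence together
```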
I was able to patch it with this monkeypatch:
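Roughly along these lines (a sketch of the idea rather than the exact snippet): flatten the leading dimensions going into the quantized linear and restore them on the output.

```python
import transformers.integrations.fbgemm_fp8 as fbgemm_fp8

_orig_forward = fbgemm_fp8.FbgemmFp8Linear.forward

def _patched_forward(self, x):
    if x.dim() > 2:
        # collapse (batch, seq, hidden) -> (batch * seq, hidden) before the FP8 matmul
        out = _orig_forward(self, x.reshape(-1, x.shape[-1]))
        # restore the leading dims on the output: (batch, seq, out_features)
        return out.reshape(*x.shape[:-1], out.shape[-1])
    return _orig_forward(self, x)

fbgemm_fp8.FbgemmFp8Linear.forward = _patched_forward
```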
...which made model.generate work as expected.

Expected behavior
The quantized MLP layers should not squish batch size and sequence length together. I suspect these lines are at fault, but I'm not sure:
src/transformers/integrations/fbgemm_fp8.py, lines 50 to 52 (commit 52cb403)
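For comparison, a plain torch.nn.Linear already follows the contract the FP8 layer should match (sizes here are illustrative): only the last dimension changes, and all leading batch/sequence dimensions pass through unchanged.

```python
import torch

ref = torch.nn.Linear(1024, 4096, bias=False, dtype=torch.bfloat16)
x = torch.randn(2, 5, 1024, dtype=torch.bfloat16)
print(ref(x).shape)  # torch.Size([2, 5, 4096]): only the last dim changed
```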