MLPs with FbgemmFp8Linear on Llama-405b-FP8 do not handle batch sizes >1 correctly #32868

Closed
2 of 4 tasks
vgel opened this issue Aug 18, 2024 · 2 comments · Fixed by #33239

@vgel
Contributor

vgel commented Aug 18, 2024

System Info

transformers 4.44.0
torch 2.4.0+cu121
fbgemm_gpu 0.8.0+cu121

Who can help?

@ArthurZucker (also maybe @SunMarc?)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model_name = "meta-llama/Meta-Llama-3.1-405B-FP8"
>>> base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
>>> train_strs = [
...  'The Golden Gate Bridge AI is talking to the user.\n\nUser: Who are you?\n\nAI: That',
...  'An AI is talking to the user.\n\nUser: Who are you?\n\nAI: That',
...  'The Golden Gate Bridge AI is talking to the user.\n\nUser: Who are you?\n\nAI: I',
...  'An AI is talking to the user.\n\nUser: Who are you?\n\nAI: I',
...  'The Golden Gate Bridge AI is talking to the user.\n\nUser: Who are you?\n\nAI: I can'
... ]
>>> encoded_batch = tokenizer(train_strs[:32], padding=True, return_tensors="pt").to(base_model.device)
>>> out = base_model(**encoded_batch)
...
File python3.10/site-packages/transformers/models/llama/modeling_llama.py:751, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
    749 hidden_states = self.post_attention_layernorm(hidden_states)
    750 hidden_states = self.mlp(hidden_states)
--> 751 hidden_states = residual + hidden_states
    753 outputs = (hidden_states,)
    755 if output_attentions:

RuntimeError: The size of tensor a (23) must match the size of tensor b (736) at non-singleton dimension 1
>>> # Note 736 = 32 * 23, i.e. batch size * maximum sequence length

After some digging in pdb, I tracked it down to the quantized MLPs:

>>> emd = base_model.model.embed_tokens(encoded_batch["input_ids"])
>>> emd.shape
torch.Size([32, 23, 16384])
>>> type(base_model.model.layers[0].mlp.up_proj)
torch.nn.modules.linear.Linear
>>> base_model.model.layers[0].mlp(emd).shape
torch.Size([32, 23, 16384])
>>> type(base_model.model.layers[1].mlp.up_proj)
transformers.integrations.fbgemm_fp8.FbgemmFp8Linear
>>> base_model.model.layers[1].mlp(emd).shape
torch.Size([736, 16384]) # <-------------------------------- wrong!!

I was able to patch it with this monkeypatch:

>>> class FixedQuantedMLP(torch.nn.Module):
...     """Wraps a quantized MLP and reshapes its output back to the input's shape."""
...     def __init__(self, mlp):
...         super().__init__()
...         self.mlp = mlp
... 
...     def forward(self, x):
...         shape = x.shape          # (batch, seq_len, hidden)
...         x = self.mlp(x)          # may come back flattened to (batch * seq_len, hidden)
...         return x.reshape(shape)

>>> def fix_layer_mlp(layer):
...     layer.old_mlp = layer.mlp
...     layer.mlp = FixedQuantedMLP(layer.mlp)

>>> for layer in base_model.model.layers: fix_layer_mlp(layer)

...which made model.generate work as expected.
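
A quick batched check along these lines is what "work as expected" means here (the max_new_tokens value and the decode call are illustrative, not the exact commands from my session):

>>> patched_batch = tokenizer(train_strs, padding=True, return_tensors="pt").to(base_model.device)
>>> out_ids = base_model.generate(**patched_batch, max_new_tokens=20)
>>> tokenizer.batch_decode(out_ids, skip_special_tokens=True)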

Expected behavior

The quantized MLP layers should not flatten the batch and sequence dimensions together. I suspect these lines in the FbgemmFp8Linear forward are at fault, but I'm not sure:

x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
    x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
)
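
For illustration, here is a toy sketch of the suspected flatten-without-restore behavior, with a plain matmul standing in for the fp8 kernel and scaled-down dimensions (this is not the actual FbgemmFp8Linear code, just the shape bookkeeping):

import torch

batch, seq_len, hidden = 2, 3, 8            # toy sizes; the real case was 32, 23, 16384
x = torch.randn(batch, seq_len, hidden)
weight = torch.randn(hidden, hidden)        # stands in for the quantized projection weight

x_2d = x.view(-1, x.shape[-1])              # (6, 8): what the kernel sees after the view
out_2d = x_2d @ weight.t()                  # (6, 8): the shape the layer currently returns
out = out_2d.reshape(*x.shape[:-1], -1)     # (2, 3, 8): what the decoder layer expects back
assert out.shape == (batch, seq_len, hidden)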

@vgel vgel added the bug label Aug 18, 2024
@SunMarc
Member

SunMarc commented Aug 19, 2024

Thanks for the detailed report @vgel! This is indeed a bug. I forgot that the view call flattens the activations and that shape is carried through the rest of the forward. Would you like to open a PR to fix this? As you tested, you just need to reshape the tensor back to its original shape after the quantize_fp8_per_row op.

@vgel
Contributor Author

vgel commented Sep 1, 2024

> Thanks for the detailed report @vgel! This is indeed a bug. I forgot that the view call flattens the activations and that shape is carried through the rest of the forward. Would you like to open a PR to fix this? As you tested, you just need to reshape the tensor back to its original shape after the quantize_fp8_per_row op.

Sure, just opened a PR!
