prepare_fa2_from_position_ids error in training with batch_size > 1 #33268

meliksahturker opened this issue Sep 2, 2024 · 5 comments

meliksahturker commented Sep 2, 2024

System Info

  • transformers version: 4.44.2
  • Platform: Linux-6.2.0-37-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.24.6
  • Safetensors version: 0.4.4
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Tensorflow version (GPU?): 2.13.1 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: "ddp", "deepspeed_stage_2"
  • Using GPU in script?: tested on 8xH100 and 1xA100-40GB
  • GPU type: NVIDIA A100-SXM4-40GB

Who can help?

@RhuiDih
@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

PR #31629 made it possible to pack several sequences into one row for flash-attention-2 without cross-contamination between them and without having to build attention masks: the sequence boundaries are inferred from position_ids instead.
However, the prepare_fa2_from_position_ids function raises an error when training with a batch_size greater than 1.
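
For context, the traceback below shows that prepare_fa2_from_position_ids returns cumulative sequence lengths (cu_seq_lens) for the variable-length FlashAttention-2 kernel. The following is a minimal sketch of how such lengths can be derived from packed position_ids, assuming every packed sequence starts at position 0. `cu_seqlens_from_position_ids` is a hypothetical helper written for illustration, not the transformers implementation. Note that the batch dimension is flattened first, which is exactly the step where batch and sequence dimensions get merged.

```python
import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor):
    """Hypothetical helper: cumulative sequence lengths from packed position_ids.

    Assumes each packed sequence starts at position 0, so a new sequence begins
    wherever the flattened position_ids equal 0. Flattening the batch dimension
    means the first token of every row also starts a new sequence.
    """
    flat = position_ids.reshape(-1)  # (batch * seq_len,)
    # Flat indices where a new sequence starts.
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten()
    total = torch.tensor([flat.numel()], device=flat.device, dtype=starts.dtype)
    cu_seqlens = torch.cat([starts, total]).to(torch.int32)
    # Length of the longest packed sequence.
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return cu_seqlens, max_seqlen


# Two rows (batch_size = 2), each holding two packed sequences.
position_ids = torch.tensor([[0, 1, 2, 0, 1],
                             [0, 1, 0, 1, 2]])
cu_seqlens, max_seqlen = cu_seqlens_from_position_ids(position_ids)
print(cu_seqlens.tolist(), max_seqlen)  # [0, 3, 5, 7, 10] 3
```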

Below is an end-to-end example to reproduce the error:

import numpy as np
import torch
import lightning

from transformers import MistralForCausalLM, MistralConfig


config = MistralConfig(max_position_embeddings = 1024,
                       hidden_size = 1024,
                       intermediate_size = 3584,
                       num_hidden_layers = 8,
                       pad_token_id = 0,
                       bos_token_id = 2,
                       eos_token_id = 3)
config._attn_implementation = "flash_attention_2"
batch_size = 2

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, max_num_pack_attempts):
        self.dataset = []
        for _ in range(100_000):
            # Generate samples of different lengths
            sample_len = np.random.randint(10, config.max_position_embeddings)
            tokens = np.random.randint(0, config.vocab_size, size = sample_len)
            self.dataset.append(tokens)
        self.max_num_pack_attempts = max_num_pack_attempts


    def get_single_sample(self):
        idx = np.random.randint(0, len(self.dataset))
        tokens = self.dataset[idx]
        return tokens.tolist()

    def generate_pack(self):
        input_ids = []
        labels = [0] # placeholder for the model to shift right by 1
        position_ids = []
        num_failed_attempts = 0
        
        while (len(input_ids) < config.max_position_embeddings) and (num_failed_attempts < self.max_num_pack_attempts):
            sample = self.get_single_sample()
        
            # If there is empty room
            if len(input_ids) + len(sample) + 1 < config.max_position_embeddings:
                input_ids += [config.bos_token_id] + sample
                labels += sample + [config.eos_token_id]
                position_ids += range(len(sample) + 1)
            else:
                num_failed_attempts += 1
        
        # Pad
        input_ids = input_ids + [config.pad_token_id] * (config.max_position_embeddings - len(input_ids))
        position_ids = position_ids + [config.pad_token_id] * (config.max_position_embeddings - len(position_ids))
        labels = labels + [-100] * (config.max_position_embeddings - len(labels))
    
        return {
                'input_ids': torch.tensor(input_ids),
                'position_ids': torch.tensor(position_ids),
                'labels': torch.tensor(labels, dtype=torch.long)
               }

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.generate_pack()


class CustomModel(lightning.LightningModule):
    def __init__(self, config, learning_rate):
        super(CustomModel, self).__init__()
        self.model = MistralForCausalLM(config = config)
        self.learning_rate = learning_rate

        num_params = sum(p.numel() for p in self.model.parameters())
        print(f'Number of parameters in Mistral: {num_params:,}')

    def forward(self, input_ids, position_ids, labels = None):
        return self.model(input_ids = input_ids,
                          position_ids = position_ids,
                          labels = labels,
                          use_cache=False)

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        position_ids = batch['position_ids']
        labels = batch['labels']
        
        outputs = self(input_ids, position_ids, labels)
        loss = outputs.loss
        return loss

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, betas = (0.9, 0.95), eps = 1e-8, weight_decay = 0.1)
        return {"optimizer": opt}

data_loader = torch.utils.data.DataLoader(CustomDataset(10), batch_size=batch_size, num_workers=1)
model = CustomModel(config, 3e-4)
print(model.model.model.layers[0].self_attn) # print the model's self attention layer name to make sure it uses FA2

# TRAINER
trainer = lightning.Trainer(
    max_steps = 2_000,
    accelerator="gpu",
    precision = "bf16-mixed",
    limit_train_batches = 1_000
)

trainer.fit(model, train_dataloaders=data_loader)
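
To make the row layout produced by generate_pack() concrete, here is a deterministic sketch that packs a fixed list of samples using the same fit rule and the same special token ids as the config above (bos_token_id = 2, eos_token_id = 3, pad_token_id = 0). `pack` is a hypothetical helper for illustration only; it is not part of the repro script or of transformers.

```python
def pack(samples, max_len, bos_id=2, eos_id=3, pad_id=0):
    """Hypothetical deterministic variant of generate_pack(): packs samples in order."""
    input_ids, labels, position_ids = [], [0], []  # labels[0] is the shift placeholder
    for sample in samples:
        # Same strict fit check as the repro script; samples that do not fit are skipped.
        if len(input_ids) + len(sample) + 1 < max_len:
            input_ids += [bos_id] + sample
            labels += sample + [eos_id]
            position_ids += list(range(len(sample) + 1))  # restart at 0 per sample
    # Pad to max_len; padded slots get position id 0 (pad_token_id in the repro).
    input_ids += [pad_id] * (max_len - len(input_ids))
    position_ids += [0] * (max_len - len(position_ids))
    labels += [-100] * (max_len - len(labels))
    return input_ids, labels, position_ids


ids, labels, pos = pack([[5, 6, 7], [8, 9]], max_len=8)
# ids    -> [2, 5, 6, 7, 2, 8, 9, 0]
# labels -> [0, 5, 6, 7, 3, 8, 9, 3]
# pos    -> [0, 1, 2, 3, 0, 1, 2, 0]
```

Padded slots also get position id 0, so a rule that starts a new sequence at every 0 treats each pad token as its own length-1 sequence.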

The error:

Epoch 0:   0%|          | 0/1000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/ubuntu/fa2_from_position_ids_test.py", line 111, in <module>
    trainer.fit(model, train_dataloaders=data_loader)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
    self.fit_loop.run()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 190, in run
    self._optimizer_step(batch_idx, closure)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 268, in _optimizer_step
    call._call_lightning_module_hook(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/amp.py", line 75, in optimizer_step
    return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 484, in wrapper
    out = func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 89, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/adamw.py", line 204, in step
    loss = closure()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 108, in _wrap_closure
    closure_result = closure()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 129, in closure
    step_output = self._step_fn()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 317, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "/home/ubuntu/fa2_from_position_ids_test.py", line 91, in training_step
    outputs = self(input_ids, position_ids, labels)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/fa2_from_position_ids_test.py", line 81, in forward
    return self.model(input_ids = input_ids,
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1033, in forward
    outputs = self.model(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 808, in forward
    layer_outputs = decoder_layer(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 549, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 382, in forward
    attn_output = _flash_attention_forward(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 272, in _flash_attention_forward
    query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 166, in prepare_fa2_from_position_ids
    key = key.view(-1, key.size(-2), key.size(-1))
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

When batch_size is set to 1, training runs without errors.
I ran the test on 8xH100 and on 1xA100-40GB with different training strategies (e.g., "ddp" and "deepspeed_stage_2") and got the same error in every case.
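
The error message, together with the fact that batch_size=1 works, is consistent with a stride problem in `view()`. The sketch below reproduces the same RuntimeError with plain PyTorch, assuming the key tensor reaching prepare_fa2_from_position_ids is a non-contiguous (batch, seq_len, num_heads, head_dim) tensor produced by a transpose; its exact layout inside transformers 4.44.2 has not been verified here. `flatten_batch` and `make_key` are hypothetical helpers for illustration.

```python
import torch


def flatten_batch(key: torch.Tensor, use_reshape: bool) -> torch.Tensor:
    """Merge the batch and seq dims of a (batch, seq, heads, head_dim) tensor."""
    if use_reshape:
        return key.reshape(-1, key.size(-2), key.size(-1))  # copies if strides require it
    return key.view(-1, key.size(-2), key.size(-1))  # never copies; needs compatible strides


def make_key(batch_size: int, seq_len: int = 3, num_heads: int = 4, head_dim: int = 8):
    # Allocate in (batch, heads, seq, head_dim) layout, then transpose, which gives
    # a non-contiguous (batch, seq, heads, head_dim) view of the same storage.
    return torch.randn(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)


# batch_size == 1: the leading size-1 dim imposes no stride constraint, so view() works.
print(flatten_batch(make_key(1), use_reshape=False).shape)  # torch.Size([3, 4, 8])

# batch_size == 2: batch and seq dims cannot be merged without a copy, so view() fails.
try:
    flatten_batch(make_key(2), use_reshape=False)
except RuntimeError as err:
    print("view failed:", err)

# reshape() falls back to a copy and works for any batch size.
print(flatten_batch(make_key(2), use_reshape=True).shape)  # torch.Size([6, 4, 8])
```

This suggests that using `.reshape(...)` instead of `.view(...)`, as the error message itself recommends, would avoid the crash at the cost of a possible copy. It is not confirmed here that this is the change proposed in the PR mentioned below.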

Expected behavior

Training should run without errors for any batch_size value.

@ArthurZucker (Collaborator)

Thanks for opening a PR as well, will have a look!

github-actions bot commented Oct 3, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@meliksahturker (Author)

This issue is not stale and the related PR still awaits merging.

@ArthurZucker (Collaborator)

Yep, waiting for a test to be added! 🤗

github-actions bot commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
