System Info
transformers version: 4.44.2

Who can help?
@RhuiDih
@ArthurZucker

Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)

Reproduction
PR #31629 enabled packing with no cross-contamination, without having to construct attention masks for flash-attention-2.
However, the prepare_fa2_from_position_ids function raises an error when training with a batch_size greater than 1.
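For context, a minimal sketch (my own illustration, not taken from the PR) of how this packing scheme works: sequence boundaries are encoded purely by position_ids that restart at 0, so no block-diagonal attention mask is needed.

# Illustration only, values made up: two packed sequences of lengths 3 and 2,
# followed by padding, using bos_token_id=2 and pad_token_id=0 as in the config below.
input_ids    = [2, 11, 57, 34,  2, 99, 18,  0, 0]
position_ids = [0,  1,  2,  3,  0,  1,  2,  0, 0]  # each restart at 0 marks a new sequence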
Below is an end-to-end example to reproduce the error:
import numpy as np
import torch
import lightning
from transformers import MistralForCausalLM, MistralConfig

config = MistralConfig(
    max_position_embeddings=1024,
    hidden_size=1024,
    intermediate_size=3584,
    num_hidden_layers=8,
    pad_token_id=0,
    bos_token_id=2,
    eos_token_id=3,
)
config._attn_implementation = "flash_attention_2"

batch_size = 2


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, max_num_pack_attempts):
        self.dataset = []
        for _ in range(100_000):
            # Generate samples of different lengths
            sample_len = np.random.randint(10, config.max_position_embeddings)
            tokens = np.random.randint(0, config.vocab_size, size=sample_len)
            self.dataset.append(tokens)
        self.max_num_pack_attempts = max_num_pack_attempts

    def get_single_sample(self):
        idx = np.random.randint(0, len(self.dataset))
        tokens = self.dataset[idx]
        return tokens.tolist()

    def generate_pack(self):
        input_ids = []
        labels = [0]  # placeholder for the model to shift right by 1
        position_ids = []
        num_failed_attempts = 0
        while (len(input_ids) < config.max_position_embeddings) and (num_failed_attempts < self.max_num_pack_attempts):
            sample = self.get_single_sample()
            # If there is empty room
            if len(input_ids) + len(sample) + 1 < config.max_position_embeddings:
                input_ids += [config.bos_token_id] + sample
                labels += sample + [config.eos_token_id]
                position_ids += range(len(sample) + 1)
            else:
                num_failed_attempts += 1
        # Pad
        input_ids = input_ids + [config.pad_token_id] * (config.max_position_embeddings - len(input_ids))
        position_ids = position_ids + [config.pad_token_id] * (config.max_position_embeddings - len(position_ids))
        labels = labels + [-100] * (config.max_position_embeddings - len(labels))
        return {
            'input_ids': torch.tensor(input_ids),
            'position_ids': torch.tensor(position_ids),
            'labels': torch.tensor(labels, dtype=torch.long),
        }

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.generate_pack()


class CustomModel(lightning.LightningModule):
    def __init__(self, config, learning_rate):
        super(CustomModel, self).__init__()
        self.model = MistralForCausalLM(config=config)
        self.learning_rate = learning_rate
        num_params = sum(p.numel() for p in self.model.parameters())
        print(f'Number of parameters in Mistral: {num_params:,}')

    def forward(self, input_ids, position_ids, labels=None):
        return self.model(input_ids=input_ids,
                          position_ids=position_ids,
                          labels=labels,
                          use_cache=False)

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        position_ids = batch['position_ids']
        labels = batch['labels']
        outputs = self(input_ids, position_ids, labels)
        loss = outputs.loss
        return loss

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1)
        return {"optimizer": opt}


data_loader = torch.utils.data.DataLoader(CustomDataset(10), batch_size=batch_size, num_workers=1)
model = CustomModel(config, 3e-4)
print(model.model.model.layers[0].self_attn)  # print the model's self attention layer to make sure it uses FA2

# TRAINER
trainer = lightning.Trainer(
    max_steps=2_000,
    accelerator="gpu",
    precision="bf16-mixed",
    limit_train_batches=1_000,
)
trainer.fit(model, train_dataloaders=data_loader)
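Before the error itself, a quick sanity check (mine, not part of the original script) of what a single packed sample looks like:

# Hypothetical check: inspect one item produced by the dataset above.
ds = CustomDataset(10)
item = ds[0]
print(item['input_ids'].shape)  # torch.Size([1024]), i.e. max_position_embeddings
# position_ids restart from 0 at every packed sequence; counting the zeros gives
# the number of packed sequences plus the zero padding at the end.
print((item['position_ids'] == 0).sum().item())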
The error:
Epoch 0: 0%| | 0/1000 [00:00<?, ?it/s]Traceback (most recent call last):
File "/home/ubuntu/fa2_from_position_ids_test.py", line 111, in <module>
trainer.fit(model, train_dataloaders=data_loader)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
call._call_and_handle_interrupt(
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
results = self._run_stage()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
self.fit_loop.run()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
self.advance()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
self.advance(data_fetcher)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 190, in run
self._optimizer_step(batch_idx, closure)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 268, in _optimizer_step
call._call_lightning_module_hook(
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/amp.py", line 75, in optimizer_step
return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 484, in wrapper
out = func(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 89, in _use_grad
ret = func(self, *args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/optim/adamw.py", line 204, in step
loss = closure()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 108, in _wrap_closure
closure_result = closure()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__
self._result = self.closure(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 129, in closure
step_output = self._step_fn()
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 317, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
return self.lightning_module.training_step(*args, **kwargs)
File "/home/ubuntu/fa2_from_position_ids_test.py", line 91, in training_step
outputs = self(input_ids, position_ids, labels)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/fa2_from_position_ids_test.py", line 81, in forward
return self.model(input_ids = input_ids,
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1033, in forward
outputs = self.model(
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 808, in forward
layer_outputs = decoder_layer(
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 549, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 382, in forward
attn_output = _flash_attention_forward(
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 272, in _flash_attention_forward
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
File "/home/ubuntu/.local/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 166, in prepare_fa2_from_position_ids
key = key.view(-1, key.size(-2), key.size(-1))
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
When batch_size is set to 1, training runs without an error.
I ran tests on 8xH100 and 1xA100-40GB with different training strategies, e.g. "ddp" and "deepspeed_stage_2", and hit the same error each time.
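The failing line in prepare_fa2_from_position_ids flattens the key tensor with .view(). Below is a standalone sketch of what I believe is happening (assumption: the key/value tensors arrive transposed and therefore non-contiguous, so .view() cannot merge the batch and sequence dimensions once batch_size > 1):

import torch

# Shapes are illustrative: (batch, num_heads, seq_len, head_dim) transposed to
# (batch, seq_len, num_heads, head_dim), as the FA2 path does before flattening.
key = torch.randn(2, 8, 1024, 64).transpose(1, 2)   # non-contiguous when batch > 1

try:
    key.view(-1, key.size(-2), key.size(-1))          # the call that fails in the traceback
except RuntimeError as e:
    print(e)  # view size is not compatible with input tensor's size and stride ...

key = key.reshape(-1, key.size(-2), key.size(-1))     # reshape copies when needed and succeeds
print(key.shape)                                      # torch.Size([2048, 8, 64])

With batch_size = 1 the leading dimension is 1, so the same .view() happens to be stride-compatible, which matches the behaviour I observed; the error message itself suggests .reshape(...) instead.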
Expected behavior
Training should run without errors for batch_size values greater than 1 as well.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.