Fix multi gpu loss sync condition, add doc and test #35743

Merged · 4 commits · Feb 12, 2025
6 changes: 5 additions & 1 deletion src/transformers/trainer.py
@@ -3684,7 +3684,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         # We don't use .loss here since the model may return tuples instead of ModelOutput.
         loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

-        if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
+        if (
+            self.args.average_tokens_across_devices
+            and (self.model_accepts_loss_kwargs or self.compute_loss_func)
+            and num_items_in_batch is not None
+        ):
             loss *= self.accelerator.num_processes

         return (loss, outputs) if return_outputs else loss
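For intuition, here is a small arithmetic sketch (illustrative numbers only, not the Trainer implementation) of why the loss is scaled by `num_processes` once `num_items_in_batch` has been all-reduced across devices:

```python
# Illustrative numbers only -- a sketch of the scaling, not the Trainer code.
world_size = 2                               # accelerator.num_processes
tokens_per_rank = [30, 10]                   # unpadded tokens seen by each device
sum_loss_per_rank = [60.0, 20.0]             # sum of per-token losses on each device

# With average_tokens_across_devices, num_items_in_batch is all-reduced,
# so every rank divides its summed loss by the *global* token count.
global_tokens = sum(tokens_per_rank)         # 40
per_rank_loss = [s / global_tokens for s in sum_loss_per_rank]        # [1.5, 0.5]

# DDP gradient averaging then divides by world_size again, giving 1.0
# instead of the true global per-token loss (60 + 20) / 40 = 2.0.
ddp_average = sum(per_rank_loss) / world_size                         # 1.0

# Multiplying each rank's loss by num_processes cancels the extra division.
corrected = sum(l * world_size for l in per_rank_loss) / world_size   # 2.0
```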
5 changes: 5 additions & 0 deletions src/transformers/training_args.py
@@ -811,6 +811,11 @@ class TrainingArguments:
             Whether enable [Liger](https://github.com/linkedin/Liger-Kernel) Kernel for LLM model training.
             It can effectively increase multi-GPU training throughput by ~20% and reduces memory usage by ~60%, works out of the box with
             flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models.
+
+        average_tokens_across_devices (`bool`, *optional*, defaults to `False`):
+            Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize
+            num_tokens_in_batch for precise loss calculation. Reference:
+            https://github.com/huggingface/transformers/issues/34242
     """

     framework = "pt"
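As a usage sketch (the output directory and batch size are placeholder values), the flag documented above is passed like any other `TrainingArguments` field:

```python
from transformers import TrainingArguments

# Placeholder values apart from the flag itself.
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    average_tokens_across_devices=True,  # all_reduce num_items_in_batch across ranks
)
```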
103 changes: 103 additions & 0 deletions tests/trainer/test_trainer_distributed_loss.py
@@ -0,0 +1,103 @@
import json

import datasets
import torch

from tests.trainer.test_trainer import StoreLossCallback
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    get_torch_dist_unique_port,
    require_torch_multi_gpu,
)


class TestTrainerDistributedLoss(TestCasePlus):
    @require_torch_multi_gpu
    def test_trainer(self):
        device_count = torch.cuda.device_count()
        min_bs = 1
        output_dir = self.get_auto_remove_tmp_dir()
        # Run the training script three times:
        # - "base": single process, with the same global batch size as the multi-GPU runs
        # - "broken": multi-GPU with average_tokens_across_devices disabled
        # - "fixed": multi-GPU with average_tokens_across_devices enabled
        for gpu_num, enable, bs, name in (
            (1, True, min_bs * device_count, "base"),
            (device_count, False, min_bs, "broken"),
            (device_count, True, min_bs, "fixed"),
        ):
            distributed_args = f"""--nproc_per_node={gpu_num}
                --master_port={get_torch_dist_unique_port()}
                {self.test_file_dir}/test_trainer_distributed_loss.py
            """.split()
            args = f"--output_dir {output_dir}/{name} --per_device_train_batch_size {bs} --average_tokens_across_devices {enable}".split()
            cmd = ["torchrun"] + distributed_args + args
            execute_subprocess_async(cmd, env=self.get_env())
        with open(f"{output_dir}/base_losses.json") as f:
            base_loss = json.load(f)
        with open(f"{output_dir}/broken_losses.json") as f:
            broken_loss = json.load(f)
        with open(f"{output_dir}/fixed_losses.json") as f:
            fixed_loss = json.load(f)

        # Per-step deviation of each multi-GPU run from the single-process reference.
        broken_diff = [abs(base_loss[i] - broken_loss[i]) for i in range(len(base_loss))]
        fixed_diff = [abs(base_loss[i] - fixed_loss[i]) for i in range(len(base_loss))]
        sum_base = sum(base_loss)
        sum_broken = sum(broken_diff)
        relative_broken = abs(sum_base - sum_broken) / max(sum_base, sum_broken)

        self.assertGreater(max(broken_diff), 0.5)
        self.assertLess(max(fixed_diff), 0.005)
        self.assertLess(relative_broken, 0.1)


def run_distributed_training(training_args):
    set_seed(42)
    model_name = "nickypro/tinyllama-15M"
    dataset_name = "wikitext"
    dataset_config = "wikitext-2-raw-v1"
    dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:17]")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples["text"], max_length=16, padding="max_length", truncation=True)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    model = AutoModelForCausalLM.from_pretrained(model_name)

    loss_callback = StoreLossCallback()

    training_args.logging_steps = 1
    training_args.max_steps = 10
    training_args.learning_rate = 3e-4
    training_args.disable_tqdm = True
    training_args.dataloader_drop_last = True
    training_args.report_to = []

    trainer = Trainer(
        model,
        training_args,
        train_dataset=tokenized_dataset,
        callbacks=[loss_callback],
        data_collator=data_collator,
    )
    trainer.train()
    with open(training_args.output_dir + "_losses.json", "w") as f:
        json.dump(loss_callback.losses, f)


if __name__ == "__main__":
    parser = HfArgumentParser((TrainingArguments,))
    training_args = parser.parse_args_into_dataclasses()[0]
    run_distributed_training(training_args)
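The test imports `StoreLossCallback` from `tests/trainer/test_trainer.py` to record the per-step training loss. A minimal sketch of such a callback, assuming only the standard `TrainerCallback.on_log` hook (the real helper may differ in detail), looks like this:

```python
from transformers import TrainerCallback


class StoreLossCallbackSketch(TrainerCallback):
    """Hypothetical stand-in for StoreLossCallback: collects the logged loss at every step."""

    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        # With logging_steps=1 the Trainer logs a dict containing "loss" after every step.
        if logs is not None and "loss" in logs:
            self.losses.append(logs["loss"])
```

With `logging_steps=1` and `max_steps=10`, each subprocess writes ten loss values to `<output_dir>_losses.json`, which `test_trainer` then compares across the base, broken, and fixed runs.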