diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1b13787007e9..6ec8b88c07ee 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3602,7 +3602,12 @@ def training_step(
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            loss *= self.args.gradient_accumulation_steps
+            if num_items_in_batch is not None:
+                if self.compute_loss_func or self.model_accepts_loss_kwargs:
+                    loss *= self.args.gradient_accumulation_steps
+                # Average tokens across devices is orthogonal to gradient accumulation
+                if self.args.average_tokens_across_devices:
+                    loss *= self.args.world_size
             self.accelerator.backward(loss, **kwargs)
 
         return loss.detach() / self.args.gradient_accumulation_steps
@@ -3617,6 +3622,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             labels = inputs.pop("labels")
         else:
             labels = None
+        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
+            num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
+            num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 485610dd9baa..54398a6bf9f9 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1530,6 +1530,15 @@ class TrainingArguments:
         },
     )
 
+    average_tokens_across_devices: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to "
+            "synchronize num_tokens_in_batch for precise loss calculation. Reference: "
+            "https://github.com/huggingface/transformers/issues/34242"
+        },
+    )
+
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
@@ -1763,6 +1772,19 @@ def __post_init__(self):
         if self.framework == "pt" and is_torch_available():
             self.device
 
+        # Disable average tokens when using a single device
+        if self.average_tokens_across_devices:
+            try:
+                if self.world_size == 1:
+                    logger.warning(
+                        "average_tokens_across_devices is set to True but it is invalid when world size is "
+                        "1. Setting it to False automatically."
+                    )
+                    self.average_tokens_across_devices = False
+            except ImportError as e:
+                logger.warning(f"Cannot determine world size due to {e}. Setting average_tokens_across_devices to False.")
+                self.average_tokens_across_devices = False
+
         if self.torchdynamo is not None:
             warnings.warn(
                 "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
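
For context, the two trainer.py hunks work together: compute_loss swaps the per-rank token count for the global count (via accelerator.gather(...).sum()), and training_step then scales the loss by world_size so that DDP's gradient averaging across ranks does not shrink the globally normalized loss. The sketch below condenses that idea with plain torch.distributed calls; it is not the Trainer implementation, the names globally_normalized_loss, summed_loss, and local_token_count are invented for illustration, and it assumes a distributed process group is already initialized.

import torch
import torch.distributed as dist


def globally_normalized_loss(summed_loss: torch.Tensor, local_token_count: int) -> torch.Tensor:
    """Normalize a summed loss by the token count aggregated over all ranks (illustrative only)."""
    counts = torch.tensor(local_token_count, device=summed_loss.device, dtype=torch.long)
    # Same effect as accelerator.gather(num_items_in_batch_tensor).sum() in the diff:
    # every rank ends up holding the global number of tokens in the batch.
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    loss = summed_loss / counts
    # DDP averages gradients over the ranks (an implicit division by world_size), so
    # multiplying here keeps the effective gradient at sum(per-token losses) / global token count.
    return loss * dist.get_world_size()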
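
On the configuration side, the feature is opt-in through the new TrainingArguments field. A hypothetical way to enable it is shown below; every value except average_tokens_across_devices is an arbitrary placeholder, and per the __post_init__ hunk the flag is reset to False with a warning when the world size is 1.

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                      # placeholder path
    per_device_train_batch_size=8,         # placeholder value
    gradient_accumulation_steps=4,         # placeholder value
    average_tokens_across_devices=True,    # flag introduced in this diff
)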