Skip to content

Commit

Permalink
Revert workaround for T5 that sets number of workers to 0 & sync_batc…
Browse files Browse the repository at this point in the history
…h_comm=False (NVIDIA#5420) (NVIDIA#5433)

* Revert workers workaround

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix in config

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
  • Loading branch information
3 people authored and andrusenkoau committed Jan 5, 2023
1 parent df43dc5 commit 9f1eb26
Showing 1 changed file with 9 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def training_step(self, batch, batch_idx):
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_seq_length,
dtype=self.autocast_dtype,
sync_batch_comm=self.cfg.get('sync_batch_comm', False),
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
custom_sync_context_handler=custom_sync_context_handler,
)
Expand All @@ -349,6 +350,7 @@ def training_step(self, batch, batch_idx):
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_seq_length,
dtype=self.autocast_dtype,
sync_batch_comm=self.cfg.get('sync_batch_comm', False),
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
custom_sync_context_handler=custom_sync_context_handler,
)
Expand Down Expand Up @@ -657,6 +659,7 @@ def validation_step_logits(self, batch, batch_idx):
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_seq_length,
dtype=self.autocast_dtype,
sync_batch_comm=self.cfg.get('sync_batch_comm', False),
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
)
else:
Expand All @@ -668,6 +671,7 @@ def validation_step_logits(self, batch, batch_idx):
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_seq_length,
dtype=self.autocast_dtype,
sync_batch_comm=self.cfg.get('sync_batch_comm', False),
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
)

Expand Down Expand Up @@ -700,6 +704,7 @@ def validation_step(self, batch, batch_idx):
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_seq_length,
dtype=self.autocast_dtype,
sync_batch_comm=self.cfg.get('sync_batch_comm', False),
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
)
else:
Expand All @@ -711,6 +716,7 @@ def validation_step(self, batch, batch_idx):
tensor_shape=tensor_shape,
decoder_sequence_length=decoder_seq_length,
dtype=self.autocast_dtype,
sync_batch_comm=self.cfg.get('sync_batch_comm', False),
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
)

Expand Down Expand Up @@ -955,7 +961,7 @@ def setup_validation_data(self, cfg):
if hasattr(self, '_validation_ds'):
consumed_samples = 0
self._validation_dl = self.build_pretraining_data_loader(
self._validation_ds, consumed_samples, num_workers=0
self._validation_ds, consumed_samples, num_workers=self._cfg.data.num_workers
)

def setup_test_data(self, cfg):
Expand Down Expand Up @@ -1046,6 +1052,7 @@ def dummy():
tensor_shape=tensor_shape,
decoder_sequence_length=encoder_seq_length,
dtype=self.autocast_dtype,
sync_batch_comm=self.cfg.get('sync_batch_comm', False),
)
else:
output_tensor = forward_backward_no_pipelining(
Expand All @@ -1056,6 +1063,7 @@ def dummy():
tensor_shape=tensor_shape,
decoder_sequence_length=encoder_seq_length,
dtype=self.autocast_dtype,
sync_batch_comm=self.cfg.get('sync_batch_comm', False),
)

if output_tensor:
Expand Down

0 comments on commit 9f1eb26

Please sign in to comment.