From 9f1eb26d70a786b99848e42cab89cd814b544c93 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Wed, 16 Nov 2022 10:59:37 -0800
Subject: [PATCH] Revert workaround for T5 that sets number of workers to 0 &
 sync_batch_comm=False (#5420) (#5433)

* Revert workers workaround

Signed-off-by: MaximumEntropy

* Fix in config

Signed-off-by: MaximumEntropy

* Fix

Signed-off-by: MaximumEntropy

Signed-off-by: MaximumEntropy
Co-authored-by: Oleksii Kuchaiev
Signed-off-by: MaximumEntropy
Co-authored-by: Sandeep Subramanian
Co-authored-by: Oleksii Kuchaiev
Signed-off-by: andrusenkoau
---
 .../megatron_lm_encoder_decoder_model.py      | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index f3722753e571..89c54bff4fc4 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -337,6 +337,7 @@ def training_step(self, batch, batch_idx):
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                 custom_sync_context_handler=custom_sync_context_handler,
             )
@@ -349,6 +350,7 @@
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                 custom_sync_context_handler=custom_sync_context_handler,
             )
@@ -657,6 +659,7 @@ def validation_step_logits(self, batch, batch_idx):
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
             )
         else:
@@ -668,6 +671,7 @@
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
             )
 
@@ -700,6 +704,7 @@ def validation_step(self, batch, batch_idx):
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
             )
         else:
@@ -711,6 +716,7 @@
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
             )
 
@@ -955,7 +961,7 @@ def setup_validation_data(self, cfg):
         if hasattr(self, '_validation_ds'):
             consumed_samples = 0
             self._validation_dl = self.build_pretraining_data_loader(
-                self._validation_ds, consumed_samples, num_workers=0
+                self._validation_ds, consumed_samples, num_workers=self._cfg.data.num_workers
             )
 
     def setup_test_data(self, cfg):
@@ -1046,6 +1052,7 @@ def dummy():
                     tensor_shape=tensor_shape,
                     decoder_sequence_length=encoder_seq_length,
                     dtype=self.autocast_dtype,
+                    sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 )
             else:
                 output_tensor = forward_backward_no_pipelining(
@@ -1056,6 +1063,7 @@ def dummy():
                     tensor_shape=tensor_shape,
                     decoder_sequence_length=encoder_seq_length,
                     dtype=self.autocast_dtype,
+                    sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 )
 
             if output_tensor:
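
Note (illustration, not part of the patch): the restored calls read sync_batch_comm through OmegaConf's dictionary-style get with a default, so configs that never define the key still resolve to False, while the validation dataloader's num_workers again comes from the model's data config. A minimal sketch of that lookup pattern, with hypothetical config values and assuming only the omegaconf package:

# Minimal sketch (illustrative, not from the patch) of the config lookup
# pattern used above: DictConfig.get() returns the supplied default when the
# key is absent, so sync_batch_comm falls back to False for older configs.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "precision": 16,
        "data": {"num_workers": 4},  # hypothetical value from the model config
        # "sync_batch_comm" intentionally left undefined
    }
)

sync_batch_comm = cfg.get("sync_batch_comm", False)  # -> False (key missing)
num_workers = cfg.data.num_workers                   # -> 4

print(sync_batch_comm, num_workers)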