diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index b6d70dfb649e..72c14555a8ad 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -337,6 +337,7 @@ def training_step(self, batch, batch_idx):
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                 custom_sync_context_handler=custom_sync_context_handler,
             )
@@ -349,6 +350,7 @@
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                 custom_sync_context_handler=custom_sync_context_handler,
             )
@@ -657,6 +659,7 @@ def validation_step_logits(self, batch, batch_idx):
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
             )
         else:
@@ -668,6 +671,7 @@
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
             )

@@ -700,6 +704,7 @@ def validation_step(self, batch, batch_idx):
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
             )
         else:
@@ -711,6 +716,7 @@
                 tensor_shape=tensor_shape,
                 decoder_sequence_length=decoder_seq_length,
                 dtype=self.autocast_dtype,
+                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
             )

@@ -951,7 +957,7 @@ def setup_validation_data(self, cfg):
         if hasattr(self, '_validation_ds'):
             consumed_samples = 0
             self._validation_dl = self.build_pretraining_data_loader(
-                self._validation_ds, consumed_samples, num_workers=0
+                self._validation_ds, consumed_samples, num_workers=self._cfg.data.num_workers
             )

     def setup_test_data(self, cfg):
@@ -1042,6 +1048,7 @@ def dummy():
                     tensor_shape=tensor_shape,
                     decoder_sequence_length=encoder_seq_length,
                     dtype=self.autocast_dtype,
+                    sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 )
             else:
                 output_tensor = forward_backward_no_pipelining(
@@ -1052,6 +1059,7 @@
                     tensor_shape=tensor_shape,
                     decoder_sequence_length=encoder_seq_length,
                     dtype=self.autocast_dtype,
+                    sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                 )

             if output_tensor:
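
At every call site touched above, `sync_batch_comm` is read from the model config with a `False` fallback. A minimal sketch of how that lookup resolves, assuming an OmegaConf-based model config (the config contents below are illustrative assumptions, not taken from the diff):

```python
# Minimal sketch of how self.cfg.get('sync_batch_comm', False) resolves.
# The config contents here are assumptions for illustration only.
from omegaconf import OmegaConf

model_cfg = OmegaConf.create({"precision": 16})           # key absent from the config
print(model_cfg.get("sync_batch_comm", False))            # -> False (fallback used)

model_cfg = OmegaConf.create({"sync_batch_comm": True})   # key set explicitly
print(model_cfg.get("sync_batch_comm", False))            # -> True
```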