Disable sync_batch_comm in validation_step for GPT (#5397)
* disable sync_batch_comm in validation_step

Signed-off-by: ericharper <complex451@gmail.com>

* Read sync_batch_comm from config or default to False

Signed-off-by: Markel Sanz Ausin <markelsanz14@gmail.com>

* Update megatron_gpt_config to default sync_batch_comm to False to avoid CUDA error

Signed-off-by: Markel Sanz Ausin <markelsanz14@gmail.com>

* Empty

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Comment out test

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

Signed-off-by: ericharper <complex451@gmail.com>
Signed-off-by: Markel Sanz Ausin <markelsanz14@gmail.com>
Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Markel Sanz Ausin <markelsanz14@gmail.com>
Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
5 people authored Nov 16, 2022
1 parent b211849 commit 01cd8b6
Showing 2 changed files with 8 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -105,7 +105,7 @@ model:
onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
-sync_batch_comm: True # Enable stream synchronization after each p2p communication between pipeline stages
+sync_batch_comm: False # Enable stream synchronization after each p2p communication between pipeline stages

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
@@ -196,4 +196,4 @@ model:
name: CosineAnnealing
warmup_steps: 500
constant_steps: 50000
-min_lr: 2e-5
+min_lr: 2e-5
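The first hunk flips the shipped default for model.sync_batch_comm to False; the second hunk appears to change only the trailing newline of the file, which is why the removed and added min_lr lines are identical. Below is a minimal, hypothetical sketch (not part of this commit) showing how the new default could be flipped back with an OmegaConf override, using the config path shown in this diff:

# Hypothetical usage sketch: load the config touched by this commit and override the new default.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_gpt_config.yaml")
print(cfg.model.sync_batch_comm)  # False after this commit

# Re-enable per-p2p stream synchronization, e.g. for debugging, via a dotlist override.
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["model.sync_batch_comm=true"]))
print(cfg.model.sync_batch_comm)  # True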
10 changes: 6 additions & 4 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -310,7 +310,7 @@ def training_step(self, batch, batch_idx):
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
custom_sync_context_handler=custom_sync_context_handler,
sequence_parallel_enabled=self.cfg.get('sequence_parallel', False),
-sync_batch_comm=self.cfg.get('sync_batch_comm', True),
+sync_batch_comm=self.cfg.get('sync_batch_comm', False),
num_micro_batches_with_partial_activation_checkpoints=self.cfg.get(
'num_micro_batches_with_partial_activation_checkpoints', None
),
@@ -541,7 +541,7 @@ def validation_step(self, batch, batch_idx):
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
sequence_parallel_enabled=self.cfg.get('sequence_parallel', False),
-sync_batch_comm=self.cfg.get('sync_batch_comm', True),
+sync_batch_comm=self.cfg.get('sync_batch_comm', False),
)

# only the last stage of the pipeline returns losses
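In both training_step and validation_step the flag is read with self.cfg.get('sync_batch_comm', False), so configs written before this commit, which lack the key, silently pick up the new default. A minimal sketch of that fallback behavior, assuming cfg is an OmegaConf DictConfig as in NeMo (hypothetical, not NeMo code):

# Minimal sketch of the config-or-default pattern used in the hunks above.
from omegaconf import OmegaConf

old_cfg = OmegaConf.create({"sequence_parallel": False})  # config predating this commit
print(old_cfg.get("sync_batch_comm", False))              # False: key absent, default applies

new_cfg = OmegaConf.create({"sync_batch_comm": True})     # explicit setting in the config
print(new_cfg.get("sync_batch_comm", False))              # True: explicit value wins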
@@ -792,7 +792,8 @@ def setup(self, stage=None):
else:
self.model.sync_initial_word_embeddings()

-self.setup_transformer_engine_tp_groups()
+if self.cfg.get('transformer_engine', False):
+    self.setup_transformer_engine_tp_groups()

def setup_training_data(self, cfg):
if hasattr(self, '_train_ds'):
@@ -841,7 +842,8 @@ def dummy():
self.trainer.strategy.launcher.launch(dummy, trainer=self.trainer)
self.trainer.strategy.setup_environment()

-self.setup_transformer_engine_tp_groups()
+if self.cfg.get('transformer_engine', False):
+    self.setup_transformer_engine_tp_groups()

# set the default sampling params if it is None.
# default do greedy sampling
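The last two hunks guard setup_transformer_engine_tp_groups() so that Transformer Engine tensor-parallel group setup only runs when the config opts in via transformer_engine. A minimal, self-contained sketch of the guard pattern (a toy class, not the NeMo model):

# Toy illustration of the guard added in the two hunks above (not the NeMo class).
from omegaconf import OmegaConf


class TinyModel:
    def __init__(self, cfg):
        self.cfg = cfg

    def setup_transformer_engine_tp_groups(self):
        print("setting up Transformer Engine TP groups")

    def setup(self):
        # Mirrors the diff: skip TE-specific setup unless 'transformer_engine' is enabled.
        if self.cfg.get("transformer_engine", False):
            self.setup_transformer_engine_tp_groups()


TinyModel(OmegaConf.create({})).setup()                            # prints nothing: flag absent
TinyModel(OmegaConf.create({"transformer_engine": True})).setup()  # prints the setup message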
