
Commit

merges for container
wdykas committed Dec 8, 2023
2 parents (3b5e4f5 + 419ad62), commit ec38d24
Showing 4 changed files with 12 additions and 10 deletions.
examples/nlp/language_modeling/conf/megatron_gpt_config.yaml (3 additions, 0 deletions)
@@ -187,6 +187,9 @@ model:

   ## Flash Attention
   use_flash_attention: False # Use flash attention in self-attention module, this config does nothing when transformer_engine=True

+  ## Network
+  sharp: False # Enable the use of SHARP for NCCL communications. This is going to be ignored if the network doesn't support SHARP.
+
   data:
     # Path to data must be specified by the user.
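For reference, a minimal sketch (not part of this commit) of toggling the new flag with OmegaConf; it assumes the stock example config path shown above:

from omegaconf import OmegaConf

# Load the example config and enable SHARP for NCCL collectives.
# Per the comment in the YAML, the setting is ignored if the fabric lacks SHARP support.
cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_gpt_config.yaml")
cfg.model.sharp = True
print(OmegaConf.select(cfg, "model.sharp"))  # True

When launching through the Hydra-based example scripts, the same effect typically comes from a command-line override such as model.sharp=True.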
@@ -639,16 +639,10 @@ def training_step(self, dataloader_iter, batch_idx):
         # it should be casted to other pipeline stages for logging.
         # we can avoid this broadcast by updating the PTL log function to accept specific ranks
         if parallel_state.get_pipeline_model_parallel_world_size() > 1:
-            if self.loss_broadcast_src_rank is None:
-                dp_size = parallel_state.get_data_parallel_world_size()
-                tp_size = parallel_state.get_tensor_model_parallel_world_size()
-                pp_size = parallel_state.get_pipeline_model_parallel_world_size()
-                rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size)
-                last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
-                self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
-            torch.distributed.broadcast(
-                loss_mean, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
-            )
+            if torch.distributed.get_rank() == get_last_rank():
+                torch.distributed.send(loss_mean, 0)
+            elif torch.distributed.get_rank() == 0:
+                torch.distributed.recv(loss_mean, get_last_rank())
         self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)

         # (@adithyare) we need to check for the _scaler attribute to enable pp>1 for adapter training
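The hunk above replaces a pipeline-group broadcast of the reduced loss with a point-to-point exchange: only the last pipeline stage (resolved by get_last_rank() in the new code) holds the value, so it sends the scalar to rank 0, the only rank that logs it. Compared with the removed code, this avoids computing a broadcast source rank from the tensor/data/pipeline sizes and only moves the scalar between the two ranks that need it. A standalone sketch of that pattern follows; the process-group setup and rank layout are hypothetical, not taken from this commit:

import torch
import torch.distributed as dist

def exchange_loss(loss_mean: torch.Tensor, last_rank: int) -> None:
    # Last pipeline stage owns the loss; ship it to rank 0 for logging.
    if dist.get_rank() == last_rank:
        dist.send(loss_mean, dst=0)
    elif dist.get_rank() == 0:
        dist.recv(loss_mean, src=last_rank)

if __name__ == "__main__":
    # Hypothetical two-process run, e.g. torchrun --nproc_per_node=2 exchange_loss.py
    dist.init_process_group(backend="gloo")
    last = dist.get_world_size() - 1
    loss = torch.tensor([1.25]) if dist.get_rank() == last else torch.zeros(1)
    exchange_loss(loss, last_rank=last)
    if dist.get_rank() == 0:
        print(loss)  # tensor([1.2500]) once the recv completes
    dist.destroy_process_group()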
nemo/collections/nlp/parts/megatron_trainer_builder.py (1 addition, 0 deletions)
@@ -54,6 +54,7 @@ def _training_strategy(self) -> NLPDDPStrategy:
             gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
             find_unused_parameters=False,
             nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
+            sharp=self.cfg.model.get('sharp', False),
         )

     def _grad_scaler(self) -> GradScaler:
nemo/collections/nlp/parts/nlp_overrides.py (4 additions, 0 deletions)
@@ -82,6 +82,7 @@ class NLPDDPStrategy(DDPStrategy):
         no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2
         with FP32 gradient accumulation.
         nccl_communicator_config_path: Path to the yaml file with NCCL communicator options
+        sharp: Apply SHARP to data-parallel proc groups.
     """

     def __init__(
@@ -91,6 +92,7 @@ def __init__(
         checkpoint_io: Optional[CheckpointIO] = None,
         no_ddp_communication_hook: bool = False,
         nccl_communicator_config_path: Optional[str] = None,
+        sharp: bool = False,
         **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
         if not HAVE_APEX:
@@ -106,6 +108,7 @@ def __init__(

         self.no_ddp_communication_hook = no_ddp_communication_hook
         self.nccl_communicator_config_path = nccl_communicator_config_path
+        self.sharp = sharp

     def setup(self, trainer: "pl.Trainer") -> None:
         """
@@ -199,6 +202,7 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None:
                     virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
                     pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
                     nccl_communicator_config_path=self.nccl_communicator_config_path,
+                    use_sharp=self.sharp,
                 )

                 # assert that fake tp and pp rank match after model parallel init
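Taken together, the NLPDDPStrategy changes plumb the flag from the trainer config down to Megatron's process-group setup: the strategy stores the value and forwards it as use_sharp when model-parallel state is initialized, so SHARP offload is requested for the data-parallel communicator. A hedged construction sketch follows; the argument values are illustrative, assume a working NeMo/Apex installation, and are not copied from MegatronTrainerBuilder:

from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy

# Passing sharp=True makes the strategy forward use_sharp=True when it
# initializes Megatron's model-parallel state, as in the last hunk above.
strategy = NLPDDPStrategy(
    no_ddp_communication_hook=False,
    nccl_communicator_config_path=None,
    sharp=True,
)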
