Adding init_model_parallel to FabricMegatronStrategy (#10733)
Signed-off-by: Marc Romeijn <mromeijn@nvidia.com>
marcromeyn authored and monica-sekoyan committed Oct 11, 2024
1 parent 142dad2 commit 0903a0b
Showing 1 changed file with 5 additions and 0 deletions.
nemo/lightning/fabric/strategies.py (5 additions, 0 deletions):

@@ -74,6 +74,7 @@ def __init__(
         no_ddp_communication_hook: bool = True,
         output_data_idx: bool = False,
         pipeline_dtype: Optional[torch.dtype] = None,
+        init_model_parallel: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -97,6 +98,7 @@ def __init__(
         self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
         self.sequence_parallel = sequence_parallel
         self.pipeline_dtype = pipeline_dtype
+        self._init_model_parallel = init_model_parallel

         self.no_ddp_communication_hook = no_ddp_communication_hook
         self.megatron_callbacks = CallbackConnector()
@@ -180,6 +182,9 @@ def setup_module(self, module: Module) -> MegatronParallel:
             convert_module_fn=convert_module_fn,
         )

+        if self._init_model_parallel:
+            megatron_parallel.init_model_parallel()
+
         if not self.ddp_config:
             from megatron.core import mpu

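For context on what the change does: with the default init_model_parallel=True, setup_module() calls init_model_parallel() on the wrapped MegatronParallel automatically, as before; passing False defers that call so the caller can trigger it explicitly. A minimal usage sketch follows; the import path and the tensor_model_parallel_size argument are assumptions for illustration, not part of this commit:

    from nemo.lightning.fabric.strategies import FabricMegatronStrategy

    # Default behavior: setup_module() initializes model parallelism itself.
    auto_strategy = FabricMegatronStrategy(tensor_model_parallel_size=2)

    # Opting out defers distributed model-parallel initialization to the caller:
    deferred_strategy = FabricMegatronStrategy(
        tensor_model_parallel_size=2,  # assumed constructor argument
        init_model_parallel=False,     # new flag added in this commit
    )

    # ...later, after setup_module() returns the MegatronParallel wrapper:
    # megatron_parallel = deferred_strategy.setup_module(module)
    # megatron_parallel.init_model_parallel()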