diff --git a/nemo/lightning/fabric/strategies.py b/nemo/lightning/fabric/strategies.py index 04069f4aba16..73ef81fb1ecb 100644 --- a/nemo/lightning/fabric/strategies.py +++ b/nemo/lightning/fabric/strategies.py @@ -74,6 +74,7 @@ def __init__( no_ddp_communication_hook: bool = True, output_data_idx: bool = False, pipeline_dtype: Optional[torch.dtype] = None, + init_model_parallel: bool = True, **kwargs: Any, ) -> None: super().__init__( @@ -97,6 +98,7 @@ def __init__( self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size self.sequence_parallel = sequence_parallel self.pipeline_dtype = pipeline_dtype + self._init_model_parallel = init_model_parallel self.no_ddp_communication_hook = no_ddp_communication_hook self.megatron_callbacks = CallbackConnector() @@ -180,6 +182,9 @@ def setup_module(self, module: Module) -> MegatronParallel: convert_module_fn=convert_module_fn, ) + if self._init_model_parallel: + megatron_parallel.init_model_parallel() + if not self.ddp_config: from megatron.core import mpu