Skip to content

Commit

Permalink
Revert "Update mcore parallelism initialization in nemo2 (#10643)"
Browse files (browse the repository at this point in the history)
This reverts commit 85e14ca.
  • Loading branch information
yaoyu-33 authored Oct 30, 2024
1 parent d441dca commit 53948d5
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 282 deletions.
116 changes: 17 additions & 99 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ def initialize_model_parallel_for_nemo(
virtual_pipeline_model_parallel_size=None,
pipeline_model_parallel_split_rank=None,
context_parallel_size=1,
encoder_tensor_model_parallel_size=0,
encoder_pipeline_model_parallel_size=0,
micro_batch_size=None,
global_batch_size=None,
rampup_batch_size=None,
Expand All @@ -122,8 +120,6 @@ def initialize_model_parallel_for_nemo(
app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
app_state.context_parallel_size = context_parallel_size
app_state.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size
app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
app_state.use_fp8 = use_fp8
app_state.init_mpi_proc_group = init_mpi_proc_group
(
Expand All @@ -143,8 +139,6 @@ def initialize_model_parallel_for_nemo(
pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
context_parallel_size_=context_parallel_size,
expert_model_parallel_size_=expert_model_parallel_size,
encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size,
use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
)

Expand All @@ -155,14 +149,12 @@ def initialize_model_parallel_for_nemo(
set_expert_model_parallel_world_size(app_state.expert_model_parallel_size)
set_expert_model_parallel_rank(app_state.expert_model_parallel_rank)

set_pipeline_model_parallel_world_size(
app_state.pipeline_model_parallel_size + app_state.encoder_pipeline_model_parallel_size
)
set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)
set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank)
if HAVE_INTERLEAVED:
set_virtual_pipeline_model_parallel_world_size(app_state.virtual_pipeline_model_parallel_size)
set_virtual_pipeline_model_parallel_rank(app_state.virtual_pipeline_model_parallel_rank)
set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size)
set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)

tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)
if seed is not None:
Expand Down Expand Up @@ -255,8 +247,6 @@ def fake_initialize_model_parallel(
virtual_pipeline_model_parallel_size_=None,
expert_model_parallel_size_=1,
context_parallel_size_=1,
encoder_tensor_model_parallel_size_=0,
encoder_pipeline_model_parallel_size_=0,
use_tp_pp_dp_mapping=False,
):
"""
Expand Down Expand Up @@ -293,109 +283,37 @@ def fake_initialize_model_parallel(
model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
context_parallel_size = min(context_parallel_size_, world_size)

if encoder_pipeline_model_parallel_size_ is None:
encoder_pipeline_model_parallel_size = 0
else:
encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size_

if encoder_tensor_model_parallel_size_ == 0 and encoder_pipeline_model_parallel_size_ > 0:
encoder_tensor_model_parallel_size = tensor_model_parallel_size
else:
encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size_

if encoder_tensor_model_parallel_size > 0:
assert encoder_pipeline_model_parallel_size > 0
assert (
encoder_tensor_model_parallel_size <= tensor_model_parallel_size
), "We do not support encoders with more TP than the decoder."

encoder_model_size = (
encoder_tensor_model_parallel_size * encoder_pipeline_model_parallel_size * context_parallel_size
)
decoder_model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
total_model_size = encoder_model_size + decoder_model_size

assert world_size % total_model_size == 0, (
f'world_size: {world_size} must be divisible by total world_size: '
f'(decoder_)tensor_model_parallel_size {tensor_model_parallel_size} '
f'* (decoder_)pipeline_model_parallel_size {pipeline_model_parallel_size} '
f'* (decoder_)context_parallel_size {context_parallel_size} + '
f'encoder_tensor_model_parallel_size {encoder_tensor_model_parallel_size} '
f'* encoder_pipeline_model_parallel_size {encoder_pipeline_model_parallel_size} '
f'* context_parallel_size {context_parallel_size}'
assert (
world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) == 0
), f'world_size: {world_size} must be divisible by tensor_model_parallel_size: {tensor_model_parallel_size} times pipeline_model_parallel_size {pipeline_model_parallel_size} times context_parallel_size {context_parallel_size}'
data_parallel_size = world_size // (
tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
)
data_parallel_size = world_size // total_model_size

encoder_world_size = encoder_model_size * data_parallel_size
decoder_world_size = decoder_model_size * data_parallel_size
assert encoder_world_size + decoder_world_size == world_size
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

virtual_pipeline_model_parallel_rank = None
if virtual_pipeline_model_parallel_size_ is not None:
virtual_pipeline_model_parallel_rank = 0

if encoder_world_size > 0:
encoder_rank_generator = RankGenerator(
tp=encoder_tensor_model_parallel_size,
ep=1,
dp=data_parallel_size,
pp=encoder_pipeline_model_parallel_size,
cp=context_parallel_size,
order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
rank_offset=0,
)
else:
encoder_rank_generator = None

decoder_rank_generator = RankGenerator(
rank_generator = RankGenerator(
tp=tensor_model_parallel_size,
ep=expert_model_parallel_size_,
dp=data_parallel_size,
pp=pipeline_model_parallel_size,
cp=context_parallel_size,
order='tp-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
rank_offset=encoder_world_size,
)

def generator_wrapper(group_type, **kwargs):
from itertools import cycle

"""The `RankGenerator` class produces a hyper-rectangle for a given set of
tensor, pipeline, data, expert, and context parallelism. If we have an encoder,
in addition to the default decoder, we essentially instantiate two `RankGenerator`
classes to construct the parallelism for each module separately, and we then have
to stitch them together for the right groups. For now, this means pp and tp-pp."""
d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
if encoder_rank_generator is None:
for x in d_ranks:
yield x
return
e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
if group_type == 'pp':
# Map 1 encoder tp rank to several decoder tp ranks, because
# these won't be the same size.
for x, y in zip(cycle(e_ranks), d_ranks):
yield x + y
elif group_type == 'tp-pp':
# For this group, we can just return the concatenated
# groups together, because their sizes are the same.
assert len(e_ranks) == len(d_ranks)
for x, y in zip(e_ranks, d_ranks):
yield x + y
else:
for x in e_ranks:
yield x
for x in d_ranks:
yield x

# Build the data-parallel groups.
all_data_parallel_group_ranks_with_cp = []
for ranks in generator_wrapper('dp'):
for ranks in rank_generator.get_ranks('dp'):
if rank in ranks:
data_parallel_group = list(ranks)
logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}')

for ranks_with_cp in generator_wrapper('dp-cp'):
for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
all_data_parallel_group_ranks_with_cp.append(ranks_with_cp)
if rank in ranks_with_cp:
data_parallel_group_with_cp = ranks_with_cp
Expand All @@ -411,7 +329,7 @@ def generator_wrapper(group_type, **kwargs):

# Build the context-parallel groups.
all_context_parallel_group_ranks = []
for ranks in generator_wrapper('cp'):
for ranks in rank_generator.get_ranks('cp'):
all_context_parallel_group_ranks.append(ranks)
if rank in ranks:
context_parallel_group = ranks
Expand All @@ -423,7 +341,7 @@ def generator_wrapper(group_type, **kwargs):

# Build the model-parallel groups.
all_model_parallel_group_ranks = []
for ranks in generator_wrapper('tp-pp'):
for ranks in rank_generator.get_ranks('tp-pp'):
all_model_parallel_group_ranks.append(ranks)
if rank in ranks:
logging.info(f'Rank {rank} has model parallel group: {list(ranks)}')
Expand All @@ -432,7 +350,7 @@ def generator_wrapper(group_type, **kwargs):
# Build the tensor model-parallel groups.
all_tensor_model_parallel_group_ranks = []
tensor_model_parallel_group = None
for ranks in generator_wrapper('tp'):
for ranks in rank_generator.get_ranks('tp'):
all_tensor_model_parallel_group_ranks.append(ranks)
if rank in ranks:
tensor_model_parallel_group = ranks
Expand All @@ -446,7 +364,7 @@ def generator_wrapper(group_type, **kwargs):
# EP rank
expert_model_parallel_rank = 0
if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
for ranks in generator_wrapper('ep', independent_ep=True):
for ranks in rank_generator.get_ranks('ep', independent_ep=True):
if rank in ranks:
expert_model_parallel_rank = list(ranks).index(rank)

Expand All @@ -457,7 +375,7 @@ def generator_wrapper(group_type, **kwargs):
pipeline_model_parallel_group = None
embedding_group = None
embedding_rank = None
for ranks in generator_wrapper('pp'):
for ranks in rank_generator.get_ranks('pp'):
all_pipeline_model_parallel_group_ranks.append(ranks)
if rank in ranks:
pipeline_model_parallel_group = ranks
Expand Down
4 changes: 0 additions & 4 deletions nemo/lightning/_strategy_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ def init_parallel_ranks(
pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
context_parallel_size=parallel_config.context_parallel_size,
encoder_tensor_model_parallel_size=getattr(parallel_config, "encoder_tensor_model_parallel_size", 0),
encoder_pipeline_model_parallel_size=getattr(parallel_config, "encoder_pipeline_model_parallel_size", 0),
seed=seed,
pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None),
use_fp8=fp8,
Expand Down Expand Up @@ -115,8 +113,6 @@ def init_model_parallel(model: Optional[nn.Module] = None) -> None:
pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
encoder_pipeline_model_parallel_size=app_state.encoder_pipeline_model_parallel_size,
encoder_tensor_model_parallel_size=app_state.encoder_tensor_model_parallel_size,
context_parallel_size=app_state.context_parallel_size,
expert_model_parallel_size=app_state.expert_model_parallel_size,
)
Expand Down
Loading

0 comments on commit 53948d5

Please sign in to comment.