Fix the init hanging problem and other issues
aoyulong committed Aug 18, 2024
1 parent e46aa71 commit da50f42
Showing 3 changed files with 6 additions and 5 deletions.
megatron/megatron/core/parallel_state.py: 2 changes (1 addition, 1 deletion)
@@ -768,7 +768,7 @@ def generator_wrapper(group_type, **kwargs):
 _POSITION_EMBEDDING_GROUP = group
 _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

-_LAST_RANK_WHEN_USING_PIPELINE = generator_wrapper('pp')[-1][-1]
+_LAST_RANK_WHEN_USING_PIPELINE = decoder_rank_generator.get_ranks('pp')[-1][-1]

 # Build the tensor + data parallel groups.
 global _TENSOR_AND_DATA_PARALLEL_GROUP
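The change above takes the last pipeline-parallel rank directly from `decoder_rank_generator.get_ranks('pp')` instead of invoking `generator_wrapper('pp')` a second time. The sketch below only illustrates the `[-1][-1]` indexing convention; the group layout is a made-up example, not the real `RankGenerator` output.

```python
# Hypothetical pipeline-parallel groups for 8 ranks and 2 pipeline stages.
# The grouping below is an assumption for illustration only.
pp_group_ranks = [
    [0, 4],  # one pipeline: rank 0 on stage 0, rank 4 on the last stage
    [1, 5],
    [2, 6],
    [3, 7],
]

# get_ranks('pp')[-1][-1]: last rank of the last pipeline group,
# i.e. a rank that lives on the final pipeline stage.
last_rank_when_using_pipeline = pp_group_ranks[-1][-1]
print(last_rank_when_using_pipeline)  # 7
```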
@@ -29,7 +29,7 @@ def get_te_version():
     def get_te_version_str():
         if hasattr(te, '__version__'):
-            return str(te.__version__)
+            return str(te.__version__).split('+')[0]
         else:
             return version("transformer-engine")
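The hunk above keeps only the public part of the Transformer Engine version string, dropping any local label after '+'. A small illustration, with made-up version strings, of why that matters when the result is compared against plain release numbers via `packaging`:

```python
from packaging import version

# Example strings only; not taken from the commit.
raw = "1.9.0+untagged.1234.gabcdef"  # e.g. a source build reporting a local label

print(version.Version(raw) > version.Version("1.9.0"))                # True: a local label sorts above the bare release
print(version.Version(raw.split('+')[0]) > version.Version("1.9.0"))  # False once the label is stripped
```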

@@ -895,7 +895,7 @@ def te_checkpoint(
 def get_cpu_offload_context(
     enabled, num_layers, model_layers, activation_offloading, weight_offloading
 ):
-    if _te_version > packaging.version.Version("1.8.0"):
+    if _te_version > packaging.version.Version("1.9.0"):
         context, sync_func = _get_cpu_offload_context(
             enabled, num_layers, model_layers, activation_offloading, weight_offloading
         )
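Raising the gate from 1.8.0 to 1.9.0 works together with the version-string cleanup above: a 1.9.0 build, including a local 1.9.0 build once its '+' suffix is stripped, no longer takes the five-argument call shown in this hunk. A hedged illustration of how the two changes interact (version strings are examples, not from the commit):

```python
import packaging.version

for raw in ("1.8.0", "1.9.0+devbuild", "1.10.0"):
    te_version = packaging.version.Version(raw.split('+')[0])
    use_new_call = te_version > packaging.version.Version("1.9.0")
    print(f"{raw:>16} -> {'new' if use_new_call else 'old'} cpu-offload call path")
```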
megatron/megatron/training/initialize.py: 5 changes (3 additions, 2 deletions)
@@ -263,8 +263,9 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
     'rank': args.rank,
     'timeout': timedelta(minutes=args.distributed_timeout_minutes),
 }
-if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"):
-    init_process_group_kwargs['device_id'] = device_id
+# TODO: @aoyulong the init_process_group will be hanging if the device_id is set
+# if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"):
+#     init_process_group_kwargs['device_id'] = device_id

 torch.distributed.init_process_group(**init_process_group_kwargs)

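The initialize.py change keeps the process-group setup but stops binding it to a device, since passing `device_id` was observed to hang. Below is a minimal, self-contained sketch of that init path; the environment-variable plumbing, the `nccl` backend, and the 10-minute timeout are assumptions standing in for Megatron's parsed args.

```python
import os
from datetime import timedelta

import torch

rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))

init_process_group_kwargs = {
    "backend": "nccl",
    "world_size": world_size,
    "rank": rank,
    "timeout": timedelta(minutes=10),
}

# Newer torch releases (the original gate here was 2.3.0) accept a device_id,
# which binds the default group to one GPU so NCCL can set up its communicator
# eagerly. That is exactly the path this commit disables because it hangs:
# init_process_group_kwargs["device_id"] = torch.device("cuda", local_rank)

torch.distributed.init_process_group(**init_process_group_kwargs)
```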
