Skip to content

Commit

Permalink
fix 2D parallel crash caused by all-reduce on 2D world_mesh
Browse files Browse the repository at this point in the history
ghstack-source-id: 1c5bf790d7473f6a24124051fcfa1fd2585a56f9
Pull Request resolved: #105
  • Loading branch information
tianyu-l committed Mar 2, 2024
1 parent ce048cd commit 117acf6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion torchtrain/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
if parallel_dims.sp_enabled:
# First we apply Sequence Parallelism if it's enabled
tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
sp_degree = job_config.training.sequence_parallelism_degree
sp_degree = job_config.training.sequence_parallel_degree
# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. shard the first layer of transformer block
Expand Down
10 changes: 5 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def main(job_config: JobConfig):

# build dataloader
# need dp world size and rank
# TODO: dp might not always be 0 so we need to handle that more carefully
dp_degree = world_mesh.size(0)
dp_rank = world_mesh.get_local_rank(0)
dp_mesh = world_mesh["dp"]
dp_degree = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()
build_dataloader_fn = dataloader_fn[job_config.training.dataset]
data_loader = build_dataloader_fn(
tokenizer,
Expand Down Expand Up @@ -253,8 +253,8 @@ def main(job_config: JobConfig):
np.max(losses_since_last_log),
)
global_avg_loss, global_max_loss = (
dist_mean(avg_loss, world_mesh),
dist_max(max_loss, world_mesh),
dist_mean(avg_loss, dp_mesh),
dist_max(max_loss, dp_mesh),
)

time_delta = timer() - time_last_log
Expand Down

0 comments on commit 117acf6

Please sign in to comment.