Skip to content

Commit

Permalink
Fix: can't find an op in fwd
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Feb 5, 2025
1 parent d6bc8da commit 93b2f10
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
8 changes: 7 additions & 1 deletion src/nanotron/parallel/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@ def clear_all():
def is_async_comm(x):
import re

NON_ASYNC_HANDLE_IDX = ["bwd.layer_mlp_{}_batch_1", "bwd.layer_attn_{}_batch_0"]
NON_ASYNC_HANDLE_IDX = [
# "fwd.layer_attn_{}_batch_0",
# "fwd.layer_mlp_{}_batch_0",
# "fwd.layer_mlp_{}_batch_1",
"bwd.layer_mlp_{}_batch_1",
"bwd.layer_attn_{}_batch_0",
]

patterns = [p.replace("{}", r"\d+") for p in NON_ASYNC_HANDLE_IDX] # Replace {} with regex for numbers
regex = re.compile("^(" + "|".join(patterns) + ")$") # Combine patterns into a single regex
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,13 @@ def forward(

id(tensor)
if async_all_reduce is True:
if isinstance(handle_idx, str):
do_async = is_last_batch_of_attn(handle_idx) is False
else:
do_async = async_all_reduce
# if isinstance(handle_idx, str):
# do_async = is_last_batch_of_attn(handle_idx) is False
# else:
# do_async = async_all_reduce
from nanotron.parallel.comm import is_async_comm

do_async = is_async_comm(handle_idx)

handle = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=do_async)
if do_async:
Expand Down
6 changes: 5 additions & 1 deletion src/nanotron/parallel/tensor_parallel/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,13 +603,17 @@ def row_linear(
# out, work = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce)
id(out)
# NOTE: why doesn't id(out) here match the id(out) from before the all_reduce?
if handle_idx == "fwd.layer_attn_0_batch_0":
assert 1 == 1

out = differentiable_all_reduce_sum(out, group=group, async_all_reduce=async_all_reduce, handle_idx=handle_idx)
if async_all_reduce:
from nanotron.parallel.comm import AsyncCommBucket

# work = AsyncCommBucket.get(orig_out_id)
# work = AsyncCommBucket.pop(orig_out_id)
if handle_idx == "fwd.layer_mlp_1_batch_0":
# if handle_idx == "fwd.layer_mlp_1_batch_0":
if handle_idx == "fwd.layer_attn_0_batch_0":
assert 1 == 1

work = AsyncCommBucket.pop(handle_idx)
Expand Down

0 comments on commit 93b2f10

Please sign in to comment.