Multi groups for broadcast of sharding stage 2 #46894

Merged 7 commits on Oct 12, 2022

Changes from all commits
@@ -187,7 +187,10 @@ def _set_reduce_overlap(self, reduce_overlap):
         # Enable gradients' reduces overlap with backward calculation.
         self._reduce_overlap = reduce_overlap
 
-    def _set_broadcast_overlap(self, broadcast_overlap, layers=None):
+    def _set_broadcast_overlap(self,
+                               broadcast_overlap,
+                               layers=None,
+                               num_groups=None):
         # Enable post optimizer broadcasts overlap with the forward calculation of next batch.
         self._broadcast_overlap = broadcast_overlap
         if self._broadcast_overlap:
@@ -205,6 +208,27 @@ def _set_broadcast_overlap(self, broadcast_overlap, layers=None):
                     "overlap broadcast may harm the performance.")
                 self._broadcast_order_params = self._local_params
 
+        if num_groups is None or num_groups > len(self._broadcast_order_params):
+            warnings.warn(
+                "The num_groups for broadcast is larger than the number of params to be broadcast. "
+                "It will set to default value: 1 (use the default sharding group)."
+            )
+            num_groups = 1
+
+        assert isinstance(
+            num_groups,
+            int) and num_groups > 0, "num_groups should be a positive integer"
+
+        self._number_of_broadcast_groups = num_groups
+        self._broadcast_groups = [
+            None for _ in range(self._number_of_broadcast_groups)
+        ]
+        self._broadcast_groups[0] = self._group
+
+        ranks = self._group.ranks
+        for i in range(1, self._number_of_broadcast_groups):
+            self._broadcast_groups[i] = new_group(ranks)
+
     def _generate_master_params(self, trainable_params):
         if self.offload:
             for param in trainable_params:
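What the block added above does, in short: `num_groups` falls back to 1 (i.e. keep using the existing sharding group) whenever it is omitted or larger than the number of parameters to broadcast; otherwise extra communication groups are created over the same ranks with `new_group`, with slot 0 reserved for the default sharding group. Below is a self-contained sketch of just that fallback rule, with the collective-group creation left out; the helper name `_normalize_num_groups` is invented for illustration and is not part of this PR.

import warnings


def _normalize_num_groups(num_groups, num_broadcast_params):
    # Mirror the fallback rule from the diff: fall back to a single group
    # (the default sharding group) when num_groups is omitted or larger
    # than the number of parameters to be broadcast.
    if num_groups is None or num_groups > num_broadcast_params:
        warnings.warn(
            "The num_groups for broadcast is larger than the number of params "
            "to be broadcast. It will set to default value: 1.")
        num_groups = 1
    assert isinstance(num_groups, int) and num_groups > 0, \
        "num_groups should be a positive integer"
    return num_groups


print(_normalize_num_groups(None, 10))  # 1: only the default sharding group
print(_normalize_num_groups(4, 10))     # 4: three extra groups would be created
print(_normalize_num_groups(16, 10))    # 1: request exceeds the param count

Note that, as in the diff, the warning also fires when `num_groups` is simply not passed, since `None` takes the same fallback path.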
@@ -487,14 +511,17 @@ def __impl__(x, y):
     def _broadcast_params_overlap_forward(self):
         # Exchange all the shards with the other ranks,
         # but overlap the broadcast with next batch's calculation.
+        group_idx = 0
+
         param2task = {}
         for x in self._broadcast_order_params:
             if x.trainable:
-                task = broadcast(
-                    tensor=x,
-                    src=self._group.ranks[self._param2rank[x.name]],
-                    group=self._group,
-                    sync_op=False)
+                group = self._broadcast_groups[group_idx]
+                group_idx = (group_idx + 1) % self._number_of_broadcast_groups
+                task = broadcast(tensor=x,
+                                 src=group.ranks[self._param2rank[x.name]],
+                                 group=group,
+                                 sync_op=False)
                 assert x.name not in param2task
                 param2task[x.name] = task
 
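The change to `_broadcast_params_overlap_forward` then cycles through those pre-created groups, so consecutive non-blocking broadcasts are issued on different communicators instead of all queuing on the single sharding group. A minimal standalone sketch of that round-robin bookkeeping follows; the parameter names and group labels are invented for illustration.

# Round-robin assignment of parameters to broadcast groups, mirroring the
# group_idx bookkeeping in the diff. Groups and params here are plain strings
# standing in for the real communication groups and tensors.
broadcast_groups = ["default_sharding_group", "extra_group_1", "extra_group_2"]
broadcast_order_params = ["linear_0.w_0", "linear_0.b_0", "linear_1.w_0",
                          "linear_1.b_0", "linear_2.w_0"]

group_idx = 0
assignment = {}
for name in broadcast_order_params:
    group = broadcast_groups[group_idx]
    group_idx = (group_idx + 1) % len(broadcast_groups)
    assignment[name] = group

for name, group in assignment.items():
    print(f"{name} -> {group}")
# linear_0.w_0 -> default_sharding_group
# linear_0.b_0 -> extra_group_1
# linear_1.w_0 -> extra_group_2
# linear_1.b_0 -> default_sharding_group
# linear_2.w_0 -> extra_group_1

Since every extra group is built over the same ranks as the default sharding group, the parameter-to-rank mapping is unchanged; the additional groups presumably just give the runtime independent communicators on which the asynchronous broadcasts can be scheduled.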