
Commit 87d09ab

ver217 authored and wangbluo committed
[zero] support multiple (partial) backward passes (hpcaitech#5596)
* [zero] support multiple (partial) backward passes
* [misc] update requirements
1 parent 8743f6c commit 87d09ab

File tree

3 files changed: +56 -15 lines changed

colossalai/zero/low_level/bookkeeping/bucket_store.py
colossalai/zero/low_level/low_level_optim.py
requirements/requirements.txt

colossalai/zero/low_level/bookkeeping/bucket_store.py

+2

@@ -11,7 +11,9 @@
class BucketStore(BaseStore):
    def __init__(self, torch_pg: ProcessGroup):
        super().__init__(torch_pg)
+        self.reset_all()

+    def reset_all(self) -> None:
        # init
        self.current_group_id = 0
        self._num_elements_in_bucket = 0
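Because __init__ now just delegates to reset_all(), the bucket bookkeeping can be restored to its pristine state at any time, not only at construction; zero_grad() in low_level_optim.py (below) calls the same method so that bucket state does not leak from one optimizer step into the next when a step is built from several backward passes. A minimal sketch of the pattern in plain Python (only the two counters visible in this diff are used; a real BucketStore keeps more state):

# Initialize-by-reset: construction and later resets share one code path.
class TinyBucketStore:
    def __init__(self) -> None:
        self.reset_all()

    def reset_all(self) -> None:
        # Fields taken from the diff above; everything else is omitted.
        self.current_group_id = 0
        self._num_elements_in_bucket = 0

store = TinyBucketStore()
store._num_elements_in_bucket += 128  # buckets fill during a backward pass
store.reset_all()                     # wiped again before the next step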

colossalai/zero/low_level/low_level_optim.py

+53 -14

@@ -40,7 +40,13 @@ def __init__(
        max_scale: float = 2**32,
    ) -> None:
        super().__init__(
-            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+            initial_scale,
+            min_scale,
+            growth_factor,
+            backoff_factor,
+            growth_interval,
+            hysteresis,
+            max_scale,
        )
        self.num_working_param_groups = num_working_param_groups
        self.grad_store = grad_store
@@ -273,11 +279,10 @@ def _create_master_param_current_rank(self, param_list):
    # Backward Reduction Hook #
    ###########################

-    def _grad_handler(self, param, group_id, grad):
+    def _grad_handler(self, group_id, param):
        # if run with no_sync context, would not sync grad when backward
        if self.require_grad_sync:
            self._add_to_bucket(param, group_id)
-        return grad

    def _attach_reduction_hook(self):
        # we iterate over the working params
@@ -286,7 +291,7 @@ def _attach_reduction_hook(self):
            param_group = self._working_param_groups[group_id]
            for param in param_group:
                if param.requires_grad:
-                    param.register_hook(partial(self._grad_handler, param, group_id))
+                    param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id))

    #######################
    # Reduction Functions #
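This hook swap is the heart of the change: Tensor.register_hook fires every time a gradient for the tensor is computed and must pass the grad through, while Tensor.register_post_accumulate_grad_hook (added in PyTorch 2.1) fires only after the gradient has been accumulated into param.grad and receives the parameter itself, which is why _grad_handler no longer takes or returns a grad. A minimal plain-PyTorch sketch, outside ColossalAI and assuming torch>=2.1:

import torch

w = torch.randn(4, requires_grad=True)

def on_grad_accumulated(param: torch.Tensor) -> None:
    # By the time this runs, param.grad already holds the accumulated gradient,
    # so a ZeRO-style handler can bucket it without returning anything.
    print("accumulated grad norm:", param.grad.norm().item())

w.register_post_accumulate_grad_hook(on_grad_accumulated)

# Two partial backward passes: the hook fires after each one, and the second
# pass accumulates on top of the gradient left by the first.
(w * 2).sum().backward()
(w * 3).sum().backward()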
@@ -415,15 +420,22 @@ def _run_reduction(self):
                    recieved_grad = torch.zeros_like(flat_grads_list[0])
                    dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
                    self._update_partitoned_grad(
-                        non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
+                        non_moe_grad_in_bucket_current_rank,
+                        recieved_grad,
+                        group_id,
+                        1,
                    )

                if len(moe_grad_list) > 0:
                    flat_grads_list = list(
                        moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
                    )
                    recieved_grad = torch.zeros_like(flat_grads_list[0])
-                    dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
+                    dist.reduce_scatter(
+                        recieved_grad,
+                        flat_grads_list,
+                        group=self.moe_extra_dp_pg,
+                    )
                    param_slice = self._world_size // self.moe_extra_dp_pg_size
                    recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
                    for split_recieved_grad in recieved_grad:
@@ -444,14 +456,25 @@ def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List
                self._add_grad(grad, self._world_size, group_id, param_id, rank)

    def _update_partitoned_grad(
-        self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
+        self,
+        origin_grad_list: List,
+        flat_grad: torch.Tensor,
+        group_id: int,
+        partition_num: int,
    ) -> None:
        sync_tensor(flat_grad, origin_grad_list)
        for grad in origin_grad_list:
            param_id = self._bucket_store.get_param_id_of_grad(grad)
            self._add_grad(grad, partition_num, group_id, param_id)

-    def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
+    def _add_grad(
+        self,
+        grad: torch.Tensor,
+        partition_num: int,
+        group_id: int,
+        param_id: int,
+        rank: int = 0,
+    ) -> None:
        if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
            self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
        else:
@@ -534,6 +557,7 @@ def zero_grad(self, set_to_none=True):
                if param.grad is not None:
                    param.grad.detach()
                    param.grad.zero_()
+        self._bucket_store.reset_all()

    ####################
    # Update Parameter #
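Together with the post-accumulate hooks, this reset is what lets a single optimizer step be assembled from several (partial) backward passes: gradients accumulate across the passes, one step() consumes them, and zero_grad() then clears both the gradients and the bucket bookkeeping. A plain-PyTorch illustration of the calling pattern (not ColossalAI code; the ZeRO optimizer additionally resets its bucket store inside zero_grad):

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

batch_a = torch.randn(4, 8)
batch_b = torch.randn(4, 8)

model(batch_a).mean().backward()  # first partial backward pass
model(batch_b).mean().backward()  # second pass; grads accumulate in .grad
optimizer.step()                  # one update from the combined gradients
optimizer.zero_grad()             # the ZeRO version also calls bucket_store.reset_all()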
@@ -655,14 +679,20 @@ def step(self, closure=None):
                        for _ in range(self.moe_extra_dp_pg_size)
                    ]
                    dist.all_gather(
-                        all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
+                        all_splited_param,
+                        splited_param.to(device).to(self._dtype),
+                        group=self.moe_extra_dp_pg,
                    )
                else:
                    all_splited_param = [
                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
                        for _ in range(self._world_size)
                    ]
-                    dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
+                    dist.all_gather(
+                        all_splited_param,
+                        splited_param.to(device).to(self._dtype),
+                        group=self.dp_pg,
+                    )
                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
            self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

@@ -685,7 +715,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
        if norm_type == inf:
            total_norm = max(grad.data.abs().max() for grad in gradients)
            total_norm_cuda = torch.tensor(
-                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+                [float(total_norm)],
+                device=get_accelerator().get_current_device(),
+                dtype=torch.float,
            )
            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
            total_norm = total_norm_cuda.item()
@@ -698,10 +730,14 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:

            # Sum across all model parallel GPUs.
            total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
+                [float(total_norm_exponentiated)],
+                device=get_accelerator().get_current_device(),
+                dtype=torch.float,
            )
            torch.distributed.all_reduce(
-                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
+                total_norm_exponentiated_cuda,
+                op=torch.distributed.ReduceOp.SUM,
+                group=self.dp_pg,
            )
            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

@@ -920,5 +956,8 @@ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:

    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
        if hasattr(self, "moe_master_to_working_map"):
-            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+            return {
+                **self._param_store.master_to_working_param,
+                **self.moe_master_to_working_map,
+            }
        return self._param_store.master_to_working_param

requirements/requirements.txt

+1 -1

@@ -8,7 +8,7 @@ click
fabric
contexttimer
ninja
-torch>=1.12
+torch>=2.1.0
safetensors
einops
pydantic
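The floor moves from torch>=1.12 to torch>=2.1.0 because register_post_accumulate_grad_hook only exists in PyTorch 2.1 and later. A small defensive check one could add when running against an older build (not part of this commit):

import torch

# Fail fast with a clear message if the post-accumulate hook API is missing.
if not hasattr(torch.Tensor, "register_post_accumulate_grad_hook"):
    raise RuntimeError(f"ZeRO reduction hooks require torch>=2.1.0, found {torch.__version__}")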
