@@ -40,7 +40,13 @@ def __init__(
         max_scale: float = 2**32,
     ) -> None:
         super().__init__(
-            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+            initial_scale,
+            min_scale,
+            growth_factor,
+            backoff_factor,
+            growth_interval,
+            hysteresis,
+            max_scale,
         )
         self.num_working_param_groups = num_working_param_groups
         self.grad_store = grad_store
@@ -273,11 +279,10 @@ def _create_master_param_current_rank(self, param_list):
     # Backward Reduction Hook #
     ###########################
 
-    def _grad_handler(self, param, group_id, grad):
+    def _grad_handler(self, group_id, param):
         # if run with no_sync context, would not sync grad when backward
         if self.require_grad_sync:
             self._add_to_bucket(param, group_id)
-        return grad
 
     def _attach_reduction_hook(self):
         # we iterate over the working params
@@ -286,7 +291,7 @@ def _attach_reduction_hook(self):
             param_group = self._working_param_groups[group_id]
             for param in param_group:
                 if param.requires_grad:
-                    param.register_hook(partial(self._grad_handler, param, group_id))
+                    param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id))
 
     #######################
     # Reduction Functions #
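Note: the only behavioral change in the two hunks above is the switch from `Tensor.register_hook` to `Tensor.register_post_accumulate_grad_hook` (available since PyTorch 2.1). The old hook receives the gradient and may return a replacement, which is why `_grad_handler` used to take `grad` and return it; the new hook fires after the gradient has been accumulated into `param.grad` and receives the parameter itself. A minimal standalone sketch of the two signatures, not part of this patch:

```python
import torch

p = torch.randn(4, requires_grad=True)

# Old style: the hook receives the gradient and may return a replacement.
def grad_hook(grad: torch.Tensor) -> torch.Tensor:
    print("grad hook, grad mean:", grad.mean().item())
    return grad  # returning grad unchanged keeps the original value

p.register_hook(grad_hook)

# New style (PyTorch >= 2.1): the hook fires after the gradient has been
# accumulated into p.grad; it receives the parameter itself and returns nothing.
def post_accumulate_hook(param: torch.Tensor) -> None:
    print("post-accumulate hook, param.grad mean:", param.grad.mean().item())

p.register_post_accumulate_grad_hook(post_accumulate_hook)

(p * 2).sum().backward()  # both hooks fire during this backward pass
```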
@@ -415,15 +420,22 @@ def _run_reduction(self):
                         recieved_grad = torch.zeros_like(flat_grads_list[0])
                         dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
                         self._update_partitoned_grad(
-                            non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
+                            non_moe_grad_in_bucket_current_rank,
+                            recieved_grad,
+                            group_id,
+                            1,
                         )
 
                     if len(moe_grad_list) > 0:
                         flat_grads_list = list(
                             moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
                         )
                         recieved_grad = torch.zeros_like(flat_grads_list[0])
-                        dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
+                        dist.reduce_scatter(
+                            recieved_grad,
+                            flat_grads_list,
+                            group=self.moe_extra_dp_pg,
+                        )
                         param_slice = self._world_size // self.moe_extra_dp_pg_size
                         recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
                         for split_recieved_grad in recieved_grad:
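For reference, the bucket reduction above relies on `torch.distributed.reduce_scatter`: the flattened gradient is split into `world_size` equal chunks, and after the collective each rank holds only the sum of its own chunk. A toy sketch of that pattern, with illustrative names (not taken from the patch) and assuming an NCCL backend with one GPU per process:

```python
# Run with e.g.: torchrun --nproc_per_node=2 reduce_scatter_demo.py
import torch
import torch.distributed as dist

def main() -> None:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # Every rank holds the full flattened gradient of a bucket.
    flat_grads = torch.full((8,), float(rank + 1), device="cuda")

    # Split into world_size equal chunks; rank i ends up owning chunk i.
    chunks = list(flat_grads.split(flat_grads.numel() // world_size))
    owned = torch.zeros_like(chunks[0])

    # After the call, `owned` holds the element-wise sum of chunk[rank]
    # across all ranks, i.e. this rank's shard of the reduced gradient.
    dist.reduce_scatter(owned, chunks)
    print(f"rank {rank}: {owned.tolist()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```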
@@ -444,14 +456,25 @@ def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List
                 self._add_grad(grad, self._world_size, group_id, param_id, rank)
 
     def _update_partitoned_grad(
-        self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
+        self,
+        origin_grad_list: List,
+        flat_grad: torch.Tensor,
+        group_id: int,
+        partition_num: int,
     ) -> None:
         sync_tensor(flat_grad, origin_grad_list)
         for grad in origin_grad_list:
             param_id = self._bucket_store.get_param_id_of_grad(grad)
             self._add_grad(grad, partition_num, group_id, param_id)
 
-    def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
+    def _add_grad(
+        self,
+        grad: torch.Tensor,
+        partition_num: int,
+        group_id: int,
+        param_id: int,
+        rank: int = 0,
+    ) -> None:
         if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
             self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
         else:
@@ -534,6 +557,7 @@ def zero_grad(self, set_to_none=True):
                     if param.grad is not None:
                         param.grad.detach()
                         param.grad.zero_()
+        self._bucket_store.reset_all()
 
     ####################
     # Update Parameter #
@@ -655,14 +679,20 @@ def step(self, closure=None):
                         for _ in range(self.moe_extra_dp_pg_size)
                     ]
                     dist.all_gather(
-                        all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
+                        all_splited_param,
+                        splited_param.to(device).to(self._dtype),
+                        group=self.moe_extra_dp_pg,
                     )
                 else:
                     all_splited_param = [
                         torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
                         for _ in range(self._world_size)
                     ]
-                    dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
+                    dist.all_gather(
+                        all_splited_param,
+                        splited_param.to(device).to(self._dtype),
+                        group=self.dp_pg,
+                    )
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
 
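The gather path above is the inverse of the earlier reduce-scatter: each rank all-gathers the updated master shards and copies the concatenated (and possibly padded) result back into the working parameter. A hedged, standalone sketch of that pattern, with illustrative names and assuming an already-initialized process group:

```python
from typing import List

import torch
import torch.distributed as dist

def gather_shard_into_param(shard: torch.Tensor, working_param: torch.Tensor) -> None:
    world_size = dist.get_world_size()

    # One receive buffer per rank, each shaped like the local shard.
    gathered: List[torch.Tensor] = [torch.zeros_like(shard) for _ in range(world_size)]
    dist.all_gather(gathered, shard)

    # Concatenate the shards, drop any flattening padding beyond the
    # parameter's true size, and copy the result back in place.
    flat = torch.cat([t.flatten() for t in gathered])
    working_param.data.copy_(flat[: working_param.numel()].reshape_as(working_param))
```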
@@ -685,7 +715,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
             total_norm_cuda = torch.tensor(
-                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+                [float(total_norm)],
+                device=get_accelerator().get_current_device(),
+                dtype=torch.float,
             )
             dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
             total_norm = total_norm_cuda.item()
@@ -698,10 +730,14 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
 
             # Sum across all model parallel GPUs.
             total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
+                [float(total_norm_exponentiated)],
+                device=get_accelerator().get_current_device(),
+                dtype=torch.float,
             )
             torch.distributed.all_reduce(
-                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
+                total_norm_exponentiated_cuda,
+                op=torch.distributed.ReduceOp.SUM,
+                group=self.dp_pg,
             )
             total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
 
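Both `_compute_grad_norm` hunks only reformat the collectives, but the underlying math may be worth spelling out: the inf-norm takes a MAX across ranks, while the p-norm sums |g|^p across ranks before taking the p-th root. A compact sketch of the same computation, illustrative only, assuming an initialized process group and CUDA tensors:

```python
from typing import List

import torch
import torch.distributed as dist

def global_grad_norm(local_grads: List[torch.Tensor], norm_type: float = 2.0) -> float:
    if norm_type == float("inf"):
        # Max over the local shards, then MAX across ranks.
        local = max(g.abs().max() for g in local_grads)
        total = torch.tensor([float(local)], device="cuda")
        dist.all_reduce(total, op=dist.ReduceOp.MAX)
        return total.item()

    # Sum of |g|^p over the local shards, SUM across ranks, then the p-th root.
    local = sum(g.abs().pow(norm_type).sum() for g in local_grads)
    total = torch.tensor([float(local)], device="cuda")
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return total.item() ** (1.0 / norm_type)
```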
@@ -920,5 +956,8 @@ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
 
     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
         if hasattr(self, "moe_master_to_working_map"):
-            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+            return {
+                **self._param_store.master_to_working_param,
+                **self.moe_master_to_working_map,
+            }
         return self._param_store.master_to_working_param