Delay reduce-scatter for ZeRO3 leaf modules (deepspeedai#5008)

ZeRO3 sets hooks on parameters to run reduce-scatter. This is often
problematic for MoE models: data-parallel processes may activate
different sets of experts, and a parameter's hook fires only when its
expert is activated in the forward pass. In that case, reduce-scatter
is launched on only some of the processes.

This PR addresses the issue by delaying reduce-scatter for ZeRO3 leaf modules
(see deepspeedai#4966).
We no longer set reduce-scatter hooks on the parameters of leaf modules.
Instead, we launch reduce-scatter on all parameters belonging to a
leaf module when exiting that module during the backward pass.
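
For illustration, a minimal usage sketch (not part of this commit; `model` and
`MoeBlock` are placeholders for your own model and its expert-containing module class):

from deepspeed.utils import set_z3_leaf_modules

# Flag every MoeBlock instance as a ZeRO3 leaf module. With this PR, their
# parameters get no per-parameter reduce-scatter hooks; instead, gradients are
# reduce-scattered together when the backward pass leaves each flagged module.
set_z3_leaf_modules(model, [MoeBlock])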

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2 people authored and amaurya committed Feb 17, 2024
1 parent 698a961 commit 6fe2176
Showing 6 changed files with 240 additions and 93 deletions.
82 changes: 14 additions & 68 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -8,88 +8,31 @@
from collections import OrderedDict
from deepspeed.utils import z3_leaf_module
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.runtime.zero.utils import apply_to_tensors_only
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import _init_external_params
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator

FWD_MODULE_STACK = list()


def is_builtin_type(obj):
# https://stackoverflow.com/a/17795199
return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"


def isinstance_namedtuple(obj: object) -> bool:
"""
Is this an instance of namedtuple/NamedTuple?
From: https://stackoverflow.com/a/62692640
Args:
obj (object): An object.
Returns:
bool: True if namedtuple/NamedTuple else False.
"""
return isinstance(obj, tuple) and hasattr(obj, '_asdict') and hasattr(obj, '_fields')


# ensure we only warn once, otherwise every iteration will trigger a warning
warned = False


def _apply_to_tensors_only(module, functional, backward_function, outputs):
"""
Apply a torch.autograd.Function that calls a `backward_function` to every Tensor in `outputs`.
Args:
module (torch.nn.Module): A torch module
functional (Type[torch.autograd.Function]): The function class to apply.
backward_function (Callable[[torch.nn.Module], None]): A backward_function to pass to
`functional.apply`.
outputs (Any): The output of `module`.
def _apply_backward_to_tensors_only(module, functional, backward_function, outputs):

Returns:
Any: The output of `module`.
"""
if isinstance(outputs, (tuple, list)):
touched_outputs = []
for output in outputs:
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
touched_outputs.append(touched_output)

if isinstance_namedtuple(outputs):
# namedtuples require a slightly different syntax.
return outputs.__class__(*touched_outputs)

return outputs.__class__(touched_outputs)
elif isinstance(outputs, dict):
# apply inplace to avoid recreating dict inherited objects
for key in outputs.keys():
outputs[key] = _apply_to_tensors_only(module, functional, backward_function, outputs[key])
return outputs

elif isinstance(outputs, torch.Tensor):
def apply_to_tensor_fn(tensor):
# this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter
touched_outputs = functional.apply(module, backward_function, outputs)
touched_outputs = functional.apply(module, backward_function, tensor)

# restore zero param attributes if those get stripped by `backward_function`
if not is_zero_param(touched_outputs) and is_zero_param(outputs):
touched_outputs.ds_param_alias = outputs
if not is_zero_param(touched_outputs) and is_zero_param(tensor):
touched_outputs.ds_param_alias = tensor
return touched_outputs
else:
if not is_builtin_type(outputs):
global warned
if not warned and dist.get_rank() == 0:
logger.warning(
f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
"The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
"output tensors and therefore may not get triggered properly.")
warned = True
return outputs

return apply_to_tensors_only(apply_to_tensor_fn, outputs)


#for each tensor in outputs run the forward_function and register backward_function as hook
@@ -384,7 +327,10 @@ def _register_hooks_recursively(self, module, count=[0]):

#print(f"{module.__class__} : {module.id}")

if not z3_leaf_module(module):
if z3_leaf_module(module):
for param in module.parameters():
param.ds_z3_leaf_module = module
else:
for child in module.children():
count[0] = count[0] + 1
self._register_hooks_recursively(child, count=count)
@@ -448,7 +394,7 @@ def _run_before_backward_function(sub_module):
sub_module.applied_pre_backward_ref_cnt -= 1
#print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

return _apply_to_tensors_only(module, PreBackwardFunction, _run_before_backward_function, output)
return _apply_backward_to_tensors_only(module, PreBackwardFunction, _run_before_backward_function, output)

#This is an alternate to doing _post_backward_module_hook
#it uses tensor.register_hook instead of using torch.autograd.Function
@@ -478,7 +424,7 @@ def _run_after_backward_function(sub_module):
if sub_module.ds_grads_remaining == 0:
self.post_sub_module_backward_function(sub_module)

return _apply_to_tensors_only(module, PostBackwardFunction, _run_after_backward_function, inputs)
return _apply_backward_to_tensors_only(module, PostBackwardFunction, _run_after_backward_function, inputs)

# Pre forward hook
self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook))
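As a side note, a standalone sketch (simplified names, not DeepSpeed's classes) of the
pattern `_apply_backward_to_tensors_only` builds on: wrap each output tensor in a
`torch.autograd.Function` whose backward invokes a callback for the module before
gradients flow further back.

import torch

class CallbackBeforeBackward(torch.autograd.Function):
    """Identity in forward; runs `callback(module)` when backward reaches this tensor."""

    @staticmethod
    def forward(ctx, module, callback, tensor):
        ctx.module = module
        ctx.callback = callback
        # Detach and return; the result is re-attached to the graph through this Function.
        return tensor.detach()

    @staticmethod
    def backward(ctx, grad_output):
        ctx.callback(ctx.module)          # e.g. re-gather partitioned params
        return None, None, grad_output    # no grads for `module` and `callback`

# Usage: the callback fires before the wrapped module's own backward runs.
lin = torch.nn.Linear(3, 3)
out = lin(torch.randn(1, 3))
out = CallbackBeforeBackward.apply(lin, lambda m: print("pre-backward:", type(m).__name__), out)
out.sum().backward()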
52 changes: 50 additions & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -20,14 +20,16 @@
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
from deepspeed.runtime.zero.utils import apply_to_tensors_only
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper
from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper
from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER
from deepspeed.accelerator import get_accelerator
import time
from deepspeed.utils import z3_leaf_parameter

# Toggle this to true to enable correctness test
# with gradient partitioning and without
pg_correctness_test = False
@@ -377,6 +379,7 @@ def __init__(
#creates backward hooks for gradient partitioning
###Calls all gather param
self._grad_acc_hooks = []
self._leaf_module_hooks = []
self.create_reduce_and_remove_grad_hooks()

#exit(0)
@@ -399,6 +402,8 @@ def destroy(self):
self.parameter_offload.destroy()
for hook in self._grad_acc_hooks:
hook.remove()
for hook in self._leaf_module_hooks:
hook.remove()
print_rank_0("Removed grad acc hooks", force=False)
del self.__ipg_bucket_flat_buffer

@@ -1112,6 +1117,7 @@ def overlapping_partition_gradients_reduce_epilogue(self):
def create_reduce_and_remove_grad_hooks(self):
print_rank_0(f'[Begin] Create gradient reduction hooks')
self.grad_accs = []
self.leaf_parameters = defaultdict(list)
for i, param_group in enumerate(self.fp16_groups):
for param in param_group:
if param.requires_grad:
@@ -1134,10 +1140,52 @@ def reduce_partition_and_remove_grads(*notneeded):
self.grad_accs.append(grad_acc)

#print(f"param grad fn {param.expand_as(param).grad_fn}")
wrapper(param)
if z3_leaf_parameter(param):
self.leaf_parameters[param.ds_z3_leaf_module].append(param)
else:
wrapper(param)

# Partition the parameter after creating the hook
param.partition()

# We delay reduce-scatter for all gradients in the leaf modules until the backward pass of the leaf module is done
for leaf_module, leaf_parameters in self.leaf_parameters.items():

def wrapper(params):

def forward_pre_hook(module, input):
"""Pre-forward hook to set backward hook on input tensors to the leaf module"""
module._leaf_module_inputs_remaining = 0

@instrument_w_nvtx
def reduce_leaf_module_grads(grad):
module._leaf_module_inputs_remaining -= 1
# Make sure everything is done in the leaf module
if module._leaf_module_inputs_remaining == 0:
for param in params:
if param.grad is None:
param.grad = torch.zeros_like(param)
self.reduce_ready_partitions_and_remove_grads(param)

def set_module_bwd_hook(tensor):
if tensor.requires_grad:
module._leaf_module_inputs_remaining += 1
tensor.register_hook(reduce_leaf_module_grads)
return tensor

output = apply_to_tensors_only(set_module_bwd_hook, input)

if module._leaf_module_inputs_remaining == 0:
raise RuntimeError(
"A module cannot be set as a leaf module when it does not have any input tensors that require gradients"
)

return output

return forward_pre_hook

self._leaf_module_hooks.append(leaf_module.register_forward_pre_hook(wrapper(leaf_parameters)))

print_rank_0(f'[End] Create gradient reduction hooks')

def get_param_id(self, param):
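The forward pre-hook added above counts the leaf module's inputs that require grad,
registers a tensor hook on each, and only when the last of those hooks fires does it
reduce-scatter every parameter of the module (the real code calls
`self.reduce_ready_partitions_and_remove_grads(param)` for each parameter). A
standalone sketch of the same counting pattern in plain PyTorch, with placeholder
names, not the DeepSpeed implementation:

import torch
import torch.nn as nn

def install_delayed_grad_flush(module: nn.Module, flush_fn):
    """Call `flush_fn(module)` once gradients for all of the module's inputs exist."""

    def forward_pre_hook(mod, inputs):
        mod._inputs_remaining = 0

        def on_input_grad(grad):
            mod._inputs_remaining -= 1
            if mod._inputs_remaining == 0:
                # Gradients for every input have been computed, i.e. backward
                # has passed through the entire module.
                flush_fn(mod)
            return grad

        for t in inputs:
            if isinstance(t, torch.Tensor) and t.requires_grad:
                mod._inputs_remaining += 1
                t.register_hook(on_input_grad)

    return module.register_forward_pre_hook(forward_pre_hook)

# Usage: the flush fires once per backward pass, after leaving the block's backward.
block = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
handle = install_delayed_grad_flush(block, lambda m: print("flushing grads of", type(m).__name__))
x = torch.randn(2, 4, requires_grad=True)
block(x).sum().backward()
handle.remove()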
62 changes: 62 additions & 0 deletions deepspeed/runtime/zero/utils.py
@@ -87,3 +87,65 @@ def assert_ints_same_as_other_ranks(ints: List[int]) -> None:
if ints != rank0_ints:
raise RuntimeError(f"disagreement between rank0 and rank{dist.get_rank()}: "
f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}")


def is_builtin_type(obj):
# https://stackoverflow.com/a/17795199
return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"


def isinstance_namedtuple(obj: object) -> bool:
"""
Is this an instance of namedtuple/NamedTuple?
From: https://stackoverflow.com/a/62692640
Args:
obj (object): An object.
Returns:
bool: True if namedtuple/NamedTuple else False.
"""
return isinstance(obj, tuple) and hasattr(obj, '_asdict') and hasattr(obj, '_fields')


def apply_to_tensors_only(function, value, warning_msg_fn=None):
"""
Apply `function` to every Tensor in `value`.
Args:
function (Callable[[torch.Tensor], Any]): The function to apply to each tensor found in `value`.
value (Any): A tensor, or an arbitrarily nested tuple/list/namedtuple/dict containing tensors.
warning_msg_fn (Optional[Callable[[Any], str]]): Builds a one-time warning message when an
unrecognized (neither builtin nor tensor) object is encountered.
Returns:
Any: `value` with `function` applied to every tensor it contains.
"""
if isinstance(value, (tuple, list)):
touched_outputs = []
for elem in value:
touched_output = apply_to_tensors_only(function, elem)
touched_outputs.append(touched_output)

if isinstance_namedtuple(value):
# namedtuples require a slightly different syntax.
return value.__class__(*touched_outputs)

return value.__class__(touched_outputs)
elif isinstance(value, dict):
# apply inplace to avoid recreating dict inherited objects
for key in value.keys():
value[key] = apply_to_tensors_only(function, value[key])
return value

elif isinstance(value, torch.Tensor):
# this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter
return function(value)
else:
if not is_builtin_type(value):
global warned
if warning_msg_fn and not warned and dist.get_rank() == 0:
logger.warning(warning_msg_fn(value))
warned = True
return value
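
A quick usage sketch of the helper above (assuming DeepSpeed is installed; note that
dicts are updated in place, as the implementation shows):

import torch
from deepspeed.runtime.zero.utils import apply_to_tensors_only

nested = {"logits": torch.randn(2, 3), "aux": (torch.randn(2), [torch.randn(1)])}
# Detach every tensor while preserving the dict/tuple/list structure.
detached = apply_to_tensors_only(lambda t: t.detach(), nested)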
2 changes: 1 addition & 1 deletion deepspeed/utils/__init__.py
@@ -16,7 +16,7 @@
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, z3_leaf_module
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter
from .mixed_precision_linkage import link_hp_params
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd
66 changes: 52 additions & 14 deletions deepspeed/utils/z3_leaf_module.py
@@ -7,42 +7,80 @@
from typing import List, Type


def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type], flag: bool) -> None:
def z3_leaf_module(model: torch.nn.Module) -> bool:
"""Returns whether a module in `model` has been flagged as a 'leaf' module.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
Returns:
bool: Whether the module has been flagged as a 'leaf' module.
"""
return hasattr(model, '_z3_leaf') and model._z3_leaf


def z3_leaf_parameter(model: torch.nn.Parameter) -> bool:
"""Returns whether a parameter belongs to a leaf module.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Parameter): The parameter to which the leaf module flag will be applied.
Returns:
bool: Whether the parameter belongs to a leaf module.
"""
return hasattr(model, 'ds_z3_leaf_module')


def get_z3_leaf_modules(model: torch.nn.Module) -> List[torch.nn.Module]:
"""Returns a list of modules in `model` that have been flagged as 'leaf' modules.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
Returns:
List[torch.nn.Module]: A list of modules that have been flagged as 'leaf' modules.
"""
return [module for module in model.modules() if z3_leaf_module(module)]


def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type],
flag: bool) -> List[torch.nn.Module]:
assert all(isinstance(module_class, type) for module_class in leaf_module_classes), \
f'leaf_module_classes must be a list of types, got {leaf_module_classes}'

leaf_modules = []

def _set_z3_leaf_flag(model: torch.nn.Module):
nonlocal leaf_modules
if model.__class__ in leaf_module_classes:
model._z3_leaf = flag
leaf_modules.append(model)

model.apply(_set_z3_leaf_flag)

if len(leaf_modules) == 0:
raise ValueError(f'No modules of type {leaf_module_classes} found in model {model}')

return leaf_modules

def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:

def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> List[torch.nn.Module]:
"""Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module.
Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
Returns:
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
"""
_do_set_z3_leaf_modules(model, leaf_module_classes, True)
return _do_set_z3_leaf_modules(model, leaf_module_classes, True)


def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:
def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> List[torch.nn.Module]:
"""Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
Returns:
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
"""
_do_set_z3_leaf_modules(model, leaf_module_classes, False)


def z3_leaf_module(model: torch.nn.Module) -> bool:
"""Returns whether a module in `model` has been flagged as a 'leaf' module.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
"""
return hasattr(model, '_z3_leaf') and model._z3_leaf
return _do_set_z3_leaf_modules(model, leaf_module_classes, False)
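
A short sketch of the helpers above (`model` and `ExpertBlock` are placeholders):

from deepspeed.utils import set_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module

flagged = set_z3_leaf_modules(model, [ExpertBlock])  # now returns the flagged modules
assert all(z3_leaf_module(m) for m in get_z3_leaf_modules(model))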