From 6fe217604cbda9456bc5ce9608993622b1eb4b15 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Fri, 2 Feb 2024 13:44:08 -0800 Subject: [PATCH] Delay reduce-scatter for ZeRO3 leaf modules (#5008) ZeRO3 sets hooks on parameters to run reduce-scatter. This is often problematic for MoE models: data parallel processes may activate different sets of experts, and a parameter's hook fires only if its expert is actually activated in the forward pass. In that case, reduce-scatter is launched on only some of the processes, so the collective does not match across the data parallel group. This PR addresses the issue by delaying reduce-scatter for ZeRO3 leaf modules (see #4966). We no longer set reduce-scatter hooks on the parameters of leaf modules. Instead, we launch reduce-scatter on all parameters belonging to a leaf module when exiting the module during the backward pass. --------- Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/zero/parameter_offload.py | 82 ++++--------------- deepspeed/runtime/zero/stage3.py | 52 +++++++++++- deepspeed/runtime/zero/utils.py | 62 ++++++++++++++ deepspeed/utils/__init__.py | 2 +- deepspeed/utils/z3_leaf_module.py | 66 +++++++++++---- .../runtime/zero/test_zero_leaf_module.py | 69 ++++++++++++++-- 6 files changed, 240 insertions(+), 93 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 88dc41867d1f..56cc4af19840 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -8,88 +8,31 @@ from collections import OrderedDict from deepspeed.utils import z3_leaf_module from deepspeed.runtime.utils import see_memory_usage +from deepspeed.runtime.zero.utils import apply_to_tensors_only from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.partition_parameters import _init_external_params from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params -from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator FWD_MODULE_STACK = list() - -def is_builtin_type(obj): - # https://stackoverflow.com/a/17795199 - return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins" - - -def isinstance_namedtuple(obj: object) -> bool: - """ - Is this an instance of namedtuple/NamedTuple? - From: https://stackoverflow.com/a/62692640 - - Args: - obj (object): An object. - - Returns: - bool: True if namedtuple/NamedTuple else False. - """ - return isinstance(obj, tuple) and hasattr(obj, '_asdict') and hasattr(obj, '_fields') - - # ensure we only warn once, otherwise every iteration will trigger a warning warned = False -def _apply_to_tensors_only(module, functional, backward_function, outputs): - """ - Apply a torch.autograd.Function that calls a `backward_function` to every Tensor in `outputs`. - - Args: - module (torch.nn.Module): A torch module - functional (Type[torch.autograd.Function]): The function class to apply. - backward_function (Callable[[torch.nn.Module], None]): A backward_function to pass to - `functional.apply`. - outputs (Any): The output of `module`. +def _apply_backward_to_tensors_only(module, functional, backward_function, outputs): - Returns: - Any: The output of `module`.
- """ - if isinstance(outputs, (tuple, list)): - touched_outputs = [] - for output in outputs: - touched_output = _apply_to_tensors_only(module, functional, backward_function, output) - touched_outputs.append(touched_output) - - if isinstance_namedtuple(outputs): - # namedtuples require a slightly different syntax. - return outputs.__class__(*touched_outputs) - - return outputs.__class__(touched_outputs) - elif isinstance(outputs, dict): - # apply inplace to avoid recreating dict inherited objects - for key in outputs.keys(): - outputs[key] = _apply_to_tensors_only(module, functional, backward_function, outputs[key]) - return outputs - - elif isinstance(outputs, torch.Tensor): + def apply_to_tensor_fn(tensor): # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter - touched_outputs = functional.apply(module, backward_function, outputs) + touched_outputs = functional.apply(module, backward_function, tensor) # restore zero param attributes if those get stripped by `backward_function` - if not is_zero_param(touched_outputs) and is_zero_param(outputs): - touched_outputs.ds_param_alias = outputs + if not is_zero_param(touched_outputs) and is_zero_param(tensor): + touched_outputs.ds_param_alias = tensor return touched_outputs - else: - if not is_builtin_type(outputs): - global warned - if not warned and dist.get_rank() == 0: - logger.warning( - f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. " - "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and " - "output tensors and therefore may not get triggered properly.") - warned = True - return outputs + + return apply_to_tensors_only(apply_to_tensor_fn, outputs) #for each tensor in outputs run the forward_function and register backward_function as hook @@ -384,7 +327,10 @@ def _register_hooks_recursively(self, module, count=[0]): #print(f"{module.__class__} : {module.id}") - if not z3_leaf_module(module): + if z3_leaf_module(module): + for param in module.parameters(): + param.ds_z3_leaf_module = module + else: for child in module.children(): count[0] = count[0] + 1 self._register_hooks_recursively(child, count=count) @@ -448,7 +394,7 @@ def _run_before_backward_function(sub_module): sub_module.applied_pre_backward_ref_cnt -= 1 #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") - return _apply_to_tensors_only(module, PreBackwardFunction, _run_before_backward_function, output) + return _apply_backward_to_tensors_only(module, PreBackwardFunction, _run_before_backward_function, output) #This is an alternate to doing _post_backward_module_hook #it uses tensor.register_hook instead of using torch.autograd.Function @@ -478,7 +424,7 @@ def _run_after_backward_function(sub_module): if sub_module.ds_grads_remaining == 0: self.post_sub_module_backward_function(sub_module) - return _apply_to_tensors_only(module, PostBackwardFunction, _run_after_backward_function, inputs) + return _apply_backward_to_tensors_only(module, PostBackwardFunction, _run_after_backward_function, inputs) # Pre forward hook self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook)) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 5a5b03bc64eb..1e128ef527af 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -20,6 +20,7 @@ from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config 
import OffloadDeviceEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload +from deepspeed.runtime.zero.utils import apply_to_tensors_only from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper @@ -27,7 +28,8 @@ from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER from deepspeed.accelerator import get_accelerator -import time +from deepspeed.utils import z3_leaf_parameter + # Toggle this to true to enable correctness test # with gradient partitioning and without pg_correctness_test = False @@ -377,6 +379,7 @@ def __init__( #creates backward hooks for gradient partitioning ###Calls all gather param self._grad_acc_hooks = [] + self._leaf_module_hooks = [] self.create_reduce_and_remove_grad_hooks() #exit(0) @@ -399,6 +402,8 @@ def destroy(self): self.parameter_offload.destroy() for hook in self._grad_acc_hooks: hook.remove() + for hook in self._leaf_module_hooks: + hook.remove() print_rank_0("Removed grad acc hooks", force=False) del self.__ipg_bucket_flat_buffer @@ -1112,6 +1117,7 @@ def overlapping_partition_gradients_reduce_epilogue(self): def create_reduce_and_remove_grad_hooks(self): print_rank_0(f'[Begin] Create gradient reduction hooks') self.grad_accs = [] + self.leaf_parameters = defaultdict(list) for i, param_group in enumerate(self.fp16_groups): for param in param_group: if param.requires_grad: @@ -1134,10 +1140,52 @@ def reduce_partition_and_remove_grads(*notneeded): self.grad_accs.append(grad_acc) #print(f"param grad fn {param.expand_as(param).grad_fn}") - wrapper(param) + if z3_leaf_parameter(param): + self.leaf_parameters[param.ds_z3_leaf_module].append(param) + else: + wrapper(param) # Partition the parameter after creating the hook param.partition() + + # We delay reduce-scatter for all gradients in the leaf modules until the backward pass of the leaf module is done + for leaf_module, leaf_parameters in self.leaf_parameters.items(): + + def wrapper(params): + + def forward_pre_hook(module, input): + """Pre-forward hook to set backward hook on input tensors to the leaf module""" + module._leaf_module_inputs_remaining = 0 + + @instrument_w_nvtx + def reduce_leaf_module_grads(grad): + module._leaf_module_inputs_remaining -= 1 + # Make sure everything is done in the leaf module + if module._leaf_module_inputs_remaining == 0: + for param in params: + if param.grad is None: + param.grad = torch.zeros_like(param) + self.reduce_ready_partitions_and_remove_grads(param) + + def set_module_bwd_hook(tensor): + if tensor.requires_grad: + module._leaf_module_inputs_remaining += 1 + tensor.register_hook(reduce_leaf_module_grads) + return tensor + + output = apply_to_tensors_only(set_module_bwd_hook, input) + + if module._leaf_module_inputs_remaining == 0: + raise RuntimeError( + "A module cannot be set as a leaf module when it does not have any input tensors that require gradients" + ) + + return output + + return forward_pre_hook + + self._leaf_module_hooks.append(leaf_module.register_forward_pre_hook(wrapper(leaf_parameters))) + print_rank_0(f'[End] Create gradient reduction hooks') def get_param_id(self, param): diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 0bf1ca4a894d..78eaaba59ebb 100755 --- 
a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -87,3 +87,65 @@ def assert_ints_same_as_other_ranks(ints: List[int]) -> None: if ints != rank0_ints: raise RuntimeError(f"disagreement between rank0 and rank{dist.get_rank()}: " f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}") + + +def is_builtin_type(obj): + # https://stackoverflow.com/a/17795199 + return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins" + + +def isinstance_namedtuple(obj: object) -> bool: + """ + Is this an instance of namedtuple/NamedTuple? + From: https://stackoverflow.com/a/62692640 + + Args: + obj (object): An object. + + Returns: + bool: True if namedtuple/NamedTuple else False. + """ + return isinstance(obj, tuple) and hasattr(obj, '_asdict') and hasattr(obj, '_fields') + + +def apply_to_tensors_only(function, value, warning_msg_fn=None): + """ + Apply `function` to every Tensor in `value`. + + Args: + function (callable): The function to apply to each tensor found in `value`. + value (Any): A tensor, or a (possibly nested) tuple, list, dict or namedtuple containing tensors. + warning_msg_fn (callable, optional): Builds the warning message logged when `value` contains + an object whose embedded tensors cannot be detected. + + Returns: + Any: `value` with `function` applied to every Tensor in it. + """ + if isinstance(value, (tuple, list)): + touched_outputs = [] + for elem in value: + touched_output = apply_to_tensors_only(function, elem) + touched_outputs.append(touched_output) + + if isinstance_namedtuple(value): + # namedtuples require a slightly different syntax. + return value.__class__(*touched_outputs) + + return value.__class__(touched_outputs) + elif isinstance(value, dict): + # apply inplace to avoid recreating dict inherited objects + for key in value.keys(): + value[key] = apply_to_tensors_only(function, value[key]) + return value + + elif isinstance(value, torch.Tensor): + # this also applies to torch.Tensor's subclasses like torch.nn.parameter.Parameter + return function(value) + else: + if not is_builtin_type(value): + global warned + if warning_msg_fn and not warned and dist.get_rank() == 0: + logger.warning(warning_msg_fn(value)) + warned = True + return value diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index 294aee53bc63..155503fa44a0 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -16,7 +16,7 @@ from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state -from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, z3_leaf_module +from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter from .mixed_precision_linkage import link_hp_params from deepspeed.runtime.dataloader import RepeatingLoader from .numa import get_numactl_cmd diff --git a/deepspeed/utils/z3_leaf_module.py b/deepspeed/utils/z3_leaf_module.py index 57521843a2ea..47d9ff698f1f 100644 --- a/deepspeed/utils/z3_leaf_module.py +++ b/deepspeed/utils/z3_leaf_module.py @@ -7,42 +7,80 @@ from typing import List, Type -def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type], flag: bool) -> None: +def z3_leaf_module(model: torch.nn.Module) -> bool: + """Returns whether a module in `model` has been flagged as a 'leaf' module. + See `set_z3_leaf_modules` for more details.
+ Args: + model (torch.nn.Module): The module to check for the 'leaf' flag. + Returns: + bool: Whether the module has been flagged as a 'leaf' module. + """ + return hasattr(model, '_z3_leaf') and model._z3_leaf + + +def z3_leaf_parameter(model: torch.nn.Parameter) -> bool: + """Returns whether a parameter belongs to a leaf module. + See `set_z3_leaf_modules` for more details. + Args: + model (torch.nn.Parameter): The parameter to check. + Returns: + bool: Whether the parameter belongs to a leaf module. + """ + return hasattr(model, 'ds_z3_leaf_module') + + +def get_z3_leaf_modules(model: torch.nn.Module) -> List[torch.nn.Module]: + """Returns a list of modules in `model` that have been flagged as 'leaf' modules. + See `set_z3_leaf_modules` for more details. + Args: + model (torch.nn.Module): The model to search for flagged 'leaf' modules. + Returns: + List[torch.nn.Module]: A list of modules that have been flagged as 'leaf' modules. + """ + return [module for module in model.modules() if z3_leaf_module(module)] + + +def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type], + flag: bool) -> List[torch.nn.Module]: assert all(isinstance(module_class, type) for module_class in leaf_module_classes), \ f'leaf_module_classes must be a list of types, got {leaf_module_classes}' + leaf_modules = [] + def _set_z3_leaf_flag(model: torch.nn.Module): + nonlocal leaf_modules if model.__class__ in leaf_module_classes: model._z3_leaf = flag + leaf_modules.append(model) model.apply(_set_z3_leaf_flag) + if len(leaf_modules) == 0: + raise ValueError(f'No modules of type {leaf_module_classes} found in model {model}') + + return leaf_modules -def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None: + +def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> List[torch.nn.Module]: """Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`. This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module. Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks. Args: model (torch.nn.Module): The model to which the leaf module flag will be applied. leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules. + Returns: + List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`. """ - _do_set_z3_leaf_modules(model, leaf_module_classes, True) + return _do_set_z3_leaf_modules(model, leaf_module_classes, True) -def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None: +def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> List[torch.nn.Module]: """Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`. See `set_z3_leaf_modules` for more details.
Args: model (torch.nn.Module): The model to which the leaf module flag will be applied. leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules. + Returns: + List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`. """ - _do_set_z3_leaf_modules(model, leaf_module_classes, False) - - -def z3_leaf_module(model: torch.nn.Module) -> bool: - """Returns whether a module in `model` has been flagged as a 'leaf' module. - See `set_z3_leaf_modules` for more details. - Args: - model (torch.nn.Module): The model to which the leaf module flag will be applied. - """ - return hasattr(model, '_z3_leaf') and model._z3_leaf + return _do_set_z3_leaf_modules(model, leaf_module_classes, False) diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index 8f10fe36c4d0..0855acec57e3 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -10,13 +10,13 @@ from unit.simple_model import random_dataloader import deepspeed -from deepspeed.utils import set_z3_leaf_modules, z3_leaf_module +from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module -class MyModel(torch.nn.Module): +class ChooseModuleByCounter(torch.nn.Module): def __init__(self, hidden_dim): - super(MyModel, self).__init__() + super(ChooseModuleByCounter, self).__init__() self.linears = torch.nn.ModuleList( [torch.nn.Linear(hidden_dim, hidden_dim, bias=False), torch.nn.Linear(hidden_dim, hidden_dim, bias=False)]) @@ -34,7 +34,25 @@ def forward(self, x, y): return x, loss -def run_model(model, config_dict, hidden_dim, dtype): +class ChooseModuleByRankModel(torch.nn.Module): + + def __init__(self, hidden_dim): + super(ChooseModuleByRankModel, self).__init__() + self.linears = torch.nn.ModuleList( + [torch.nn.Linear(hidden_dim, hidden_dim, bias=False), + torch.nn.Linear(hidden_dim, hidden_dim, bias=False)]) + self.act = torch.nn.ReLU() + self.cel = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + # Each rank runs only one of the linear layers + x = self.linears[dist.get_rank() % len(self.linears)](x) + x = self.act(x) + loss = self.cel(x, y) + return x, loss + + +def run_model(model, config_dict, hidden_dim, dtype, requires_grad): model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) data_loader = random_dataloader(model=model, total_samples=10, @@ -43,6 +61,7 @@ def run_model(model, config_dict, hidden_dim, dtype): dtype=dtype) dist.barrier() for batch in data_loader: + batch[0].requires_grad = requires_grad loss = model(batch[0], batch[1]) loss = loss[1] model.backward(loss) @@ -57,7 +76,7 @@ class TestSetZ3LeafModule(DistributedTest): world_size = 2 reuse_dist_env = True - def test_set_z3_leaf_modules(self): + def _test_set_z3_leaf_modules(self, cls, requires_grad): hidden_dim = 128 # `stage3_max_reuse_distance` is set to 0 to cause an error if the module is not set as a leaf module @@ -81,10 +100,44 @@ def test_set_z3_leaf_modules(self): } } - model = MyModel(hidden_dim) + model = cls(hidden_dim) assert not z3_leaf_module(model) - set_z3_leaf_modules(model, [MyModel]) + set_z3_leaf_modules(model, [cls]) assert z3_leaf_module(model) - run_model(model, config_dict, hidden_dim, torch.float16) + run_model(model, config_dict, hidden_dim, torch.float16, requires_grad) + + def test_choose_module_by_counter(self): + 
self._test_set_z3_leaf_modules(ChooseModuleByCounter, True) + + def test_choose_module_by_rank(self): + self._test_set_z3_leaf_modules(ChooseModuleByRankModel, True) + + def test_no_grad_input_error(self): + try: + self._test_set_z3_leaf_modules(ChooseModuleByCounter, False) + raise AssertionError( + "Expected RuntimeError: inputs with requires_grad=False are not supported for a leaf module") + except RuntimeError: + pass + + def test_set_unset_leaf_modules(self): + hidden_dim = 128 + model = ChooseModuleByCounter(hidden_dim) + assert len(set_z3_leaf_modules(model, [torch.nn.ModuleList])) == 1, \ + "Expected only one module to be set as a leaf module" + assert len(get_z3_leaf_modules(model)) == 1, "Expected exactly one leaf module" + + assert len(unset_z3_leaf_modules(model, [torch.nn.ModuleList])) == 1, \ + "Expected only one module to be unset as a leaf module" + assert len(get_z3_leaf_modules(model)) == 0, "Expected no leaf modules" + + def test_set_no_match_class(self): + hidden_dim = 128 + model = ChooseModuleByCounter(hidden_dim) + try: + set_z3_leaf_modules(model, [torch.nn.Conv2d]) + raise AssertionError("Expected a ValueError because no module matches the given leaf module classes") + except ValueError: + pass
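
Usage sketch (illustrative, not part of the patch): the snippet below shows how the API exercised by the tests above would be used end to end on an MoE-style model. `MyMoEBlock`, `MyModel`, and the ds_config values are hypothetical placeholders; only `set_z3_leaf_modules`, `get_z3_leaf_modules`, and `deepspeed.initialize` come from DeepSpeed itself, and the script is assumed to run under the deepspeed launcher.

    # Illustrative sketch: flag an MoE block as a ZeRO3 leaf module so that the
    # delayed reduce-scatter path introduced by this patch is used.
    import torch
    import deepspeed
    from deepspeed.utils import set_z3_leaf_modules, get_z3_leaf_modules


    class MyMoEBlock(torch.nn.Module):
        # Hypothetical expert-routing block: each forward pass activates only one
        # expert, so different data-parallel ranks may touch different parameters.
        def __init__(self, hidden_dim, num_experts=4):
            super().__init__()
            self.experts = torch.nn.ModuleList(
                torch.nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_experts))

        def forward(self, x, step):
            return self.experts[step % len(self.experts)](x)


    class MyModel(torch.nn.Module):
        def __init__(self, hidden_dim):
            super().__init__()
            self.proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.moe = MyMoEBlock(hidden_dim)

        def forward(self, x, step):
            # The leaf module receives an activation that requires grad, which the
            # forward_pre_hook added in stage3.py relies on to register its backward hook.
            return self.moe(self.proj(x), step)


    model = MyModel(hidden_dim=128)

    # Flag the MoE block before deepspeed.initialize: ZeRO3 then skips per-parameter
    # reduce-scatter hooks inside it and reduces all of its gradients when the block's
    # backward pass finishes, so every rank joins the same collectives.
    set_z3_leaf_modules(model, [MyMoEBlock])
    assert len(get_z3_leaf_modules(model)) == 1

    ds_config = {  # illustrative values only
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
        "zero_optimization": {"stage": 3},
    }
    engine, _, _, _ = deepspeed.initialize(model=model,
                                           model_parameters=model.parameters(),
                                           config=ds_config)

With this setup, the forward pre-hook from stage3.py counts the grad-requiring inputs of `MyMoEBlock` and launches reduce-scatter for all of its parameters only once the last of those input gradients has been produced, matching the behavior described in the commit message.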