diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
index d1ecd9288878..ab554297159b 100644
--- a/deepspeed/runtime/zero/parameter_offload.py
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -6,6 +6,7 @@
 import sys
 import torch
 from collections import OrderedDict
+from deepspeed.utils import z3_leaf_module
 from deepspeed.runtime.utils import see_memory_usage
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
 from deepspeed.runtime.zero.partition_parameters import _init_external_params
@@ -383,9 +384,10 @@ def _register_hooks_recursively(self, module, count=[0]):
 
         #print(f"{module.__class__} : {module.id}")
 
-        for child in module.children():
-            count[0] = count[0] + 1
-            self._register_hooks_recursively(child, count=count)
+        if not z3_leaf_module(module):
+            for child in module.children():
+                count[0] = count[0] + 1
+                self._register_hooks_recursively(child, count=count)
 
         @instrument_w_nvtx
         def _pre_forward_module_hook(module, *args):
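For orientation, a minimal standalone sketch of the behaviour the hunk above introduces. The toy model and the count_hook_sites helper are hypothetical illustrations, not DeepSpeed's actual hook machinery; they only mirror the recursion cut-off being added: a flagged module still gets its own hooks, but registration no longer descends into its children.

import torch


def count_hook_sites(module):
    # Mirrors the `if not z3_leaf_module(module)` guard above: a hook is
    # registered on `module` itself, but we only recurse into children when
    # the module has not been flagged as a ZeRO-3 leaf.
    sites = 1
    if not getattr(module, '_z3_leaf', False):
        for child in module.children():
            sites += count_hook_sites(child)
    return sites


experts = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
model = torch.nn.Sequential(experts)  # toy stand-in for a model with an expert block

print(count_hook_sites(model))  # 6: Sequential + ModuleList + 4 Linear modules
experts._z3_leaf = True         # what set_z3_leaf_modules(model, [torch.nn.ModuleList]) would set
print(count_hook_sites(model))  # 2: recursion stops at the flagged ModuleList
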
diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index 9bcf5a91bc95..299138e84712 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -9,6 +9,7 @@
 from typing import Deque, Set
 
 from deepspeed import comm as dist
+from deepspeed.utils import z3_leaf_module
 from deepspeed.utils.logging import logger
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
 from deepspeed.runtime.zero.partition_parameters import *
@@ -188,7 +189,7 @@ def record_parameters(self, sub_module: Module) -> None:
             raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")
 
         step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
-        for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id):
+        for param in sorted(set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))), key=lambda p: p.ds_id):
             self.__param_order.append(__class__.__ParamInTrace(param=param, step_id_last_used_at=step_id))
 
     def construct_parameter_trace_from_module_trace(self):
@@ -261,14 +262,14 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
         """
         if logger.isEnabledFor(logging.DEBUG):
             debug_rank0(
-                f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
+                f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
                 + str({
                     "avail": f"{self.__n_available_params:.1e}",
                     "queue_sz": f"{len(self.__param_queue or [])}",
                     "inflight": [p.ds_id for p in self.__inflight_param_registry],
                 }))
 
-        params_to_fetch = frozenset(iter_params(current_submodule))
+        params_to_fetch = frozenset(iter_params(current_submodule, recurse=z3_leaf_module(current_submodule)))
         fetch_numel = sum(
             [p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
         if fetch_numel > 0:
@@ -390,8 +391,8 @@ def release_sub_module(self, submodule: Module, backward: bool) -> None:
         """release the parameters of a sub module, assuming they meet conditions to be released."""
         params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
-            p.ds_id for p in iter_params(submodule)))
-        for param in iter_params(submodule):
+            p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
+        for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
             param.ds_active_sub_modules.discard(submodule.id)
             if param.ds_id in params_to_release and not param.is_external_param:
                 self.__release_param(param, backward)
 
@@ -473,7 +474,9 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set
         if not self.is_complete_trace():
             raise RuntimeError("expected trace to be complete")
 
-        params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) if not p.ds_persist)
+        params_to_release = set(
+            p.ds_id for p in iter_params(submodule_to_release, recurse=z3_leaf_module(submodule_to_release))
+            if not p.ds_persist)
 
         # Problem: When prefetcher scans the param trace, it skips AVAILABLE params.
         # This creates issues if those params are released before the skipped uses:
@@ -482,7 +485,7 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set
         # diverges from the trace.
         # Solution: Don't release params whose reuse was skipped by prefetch. This is
         # possible because we detect such skips during prefetch and mark those params.
-        for param in iter_params(submodule_to_release):
+        for param in iter_params(submodule_to_release, recurse=z3_leaf_module(submodule_to_release)):
             if self.__most_recent_step_id_param_fetched_for[param] > step_id:
                 params_to_release.discard(param.ds_id)
 
@@ -493,7 +496,7 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set
         for module in self.__submodule_order[step_id:]:
             if params_traversed >= self.__max_reuse_dist_in_numel:
                 break
-            for param in iter_params(module):
+            for param in iter_params(module, recurse=z3_leaf_module(submodule_to_release)):
                 params_to_release.discard(param.ds_id)
                 params_traversed += param.ds_numel
 
diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py
index 6237d7239682..b6c371cd5a1f 100644
--- a/deepspeed/utils/__init__.py
+++ b/deepspeed/utils/__init__.py
@@ -16,6 +16,7 @@
 from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
 from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
 from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state
+from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, z3_leaf_module
 from .mixed_precision_linkage import link_hp_params
 from deepspeed.runtime.dataloader import RepeatingLoader
 from .numa import get_numactl_cmd
diff --git a/deepspeed/utils/z3_leaf_module.py b/deepspeed/utils/z3_leaf_module.py
new file mode 100644
index 000000000000..57521843a2ea
--- /dev/null
+++ b/deepspeed/utils/z3_leaf_module.py
@@ -0,0 +1,48 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+from typing import List, Type
+
+
+def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type], flag: bool) -> None:
+    assert all(isinstance(module_class, type) for module_class in leaf_module_classes), \
+        f'leaf_module_classes must be a list of types, got {leaf_module_classes}'
+
+    def _set_z3_leaf_flag(model: torch.nn.Module):
+        if model.__class__ in leaf_module_classes:
+            model._z3_leaf = flag
+
+    model.apply(_set_z3_leaf_flag)
+
+
+def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:
+    """Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
+    This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module.
+    Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
+    Args:
+        model (torch.nn.Module): The model to which the leaf module flag will be applied.
+        leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
+    """
+    _do_set_z3_leaf_modules(model, leaf_module_classes, True)
+
+
+def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:
+    """Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
+    See `set_z3_leaf_modules` for more details.
+    Args:
+        model (torch.nn.Module): The model to which the leaf module flag will be applied.
+        leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
+    """
+    _do_set_z3_leaf_modules(model, leaf_module_classes, False)
+
+
+def z3_leaf_module(model: torch.nn.Module) -> bool:
+    """Returns whether a module in `model` has been flagged as a 'leaf' module.
+    See `set_z3_leaf_modules` for more details.
+    Args:
+        model (torch.nn.Module): The model to which the leaf module flag will be applied.
+    """
+    return hasattr(model, '_z3_leaf') and model._z3_leaf
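A short usage sketch of the API introduced above (re-exported from deepspeed.utils by the __init__.py change). `ExpertBlock` is a hypothetical container class standing in for, e.g., an MoE expert module; only the set/unset/query functions come from this diff.

import torch
from deepspeed.utils import set_z3_leaf_modules, unset_z3_leaf_modules, z3_leaf_module


class ExpertBlock(torch.nn.Module):  # hypothetical MoE-style container

    def __init__(self):
        super().__init__()
        self.experts = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)])


model = torch.nn.Sequential(ExpertBlock(), torch.nn.Linear(16, 16))

set_z3_leaf_modules(model, [ExpertBlock])    # flag every ExpertBlock instance in the tree
assert z3_leaf_module(model[0])              # flag is set on matching modules...
assert not z3_leaf_module(model[1])          # ...and left unset everywhere else

unset_z3_leaf_modules(model, [ExpertBlock])  # ZeRO-3 resumes recursing into ExpertBlock
assert not z3_leaf_module(model[0])

Flagging is done on the model before it is passed to deepspeed.initialize, as the test below does.
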
diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py
new file mode 100644
index 000000000000..8f10fe36c4d0
--- /dev/null
+++ b/tests/unit/runtime/zero/test_zero_leaf_module.py
@@ -0,0 +1,90 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import deepspeed.comm as dist
+import torch
+
+from unit.common import DistributedTest
+from unit.simple_model import random_dataloader
+
+import deepspeed
+from deepspeed.utils import set_z3_leaf_modules, z3_leaf_module
+
+
+class MyModel(torch.nn.Module):
+
+    def __init__(self, hidden_dim):
+        super(MyModel, self).__init__()
+        self.linears = torch.nn.ModuleList(
+            [torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
+             torch.nn.Linear(hidden_dim, hidden_dim, bias=False)])
+        self.act = torch.nn.ReLU()
+        self.cel = torch.nn.CrossEntropyLoss()
+        self.counter = 0
+
+    def forward(self, x, y):
+        # This fails without setting this module as a leaf module.
+        # See the docstring of `set_z3_leaf_modules()`.
+        x = self.linears[self.counter % len(self.linears)](x)
+        x = self.act(x)
+        loss = self.cel(x, y)
+        self.counter += 1
+        return x, loss
+
+
+def run_model(model, config_dict, hidden_dim, dtype):
+    model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+    data_loader = random_dataloader(model=model,
+                                    total_samples=10,
+                                    hidden_dim=hidden_dim,
+                                    device=model.device,
+                                    dtype=dtype)
+    dist.barrier()
+    for batch in data_loader:
+        loss = model(batch[0], batch[1])
+        loss = loss[1]
+        model.backward(loss)
+        model.step()
+
+    # Needed in ZeRO 3; not calling destroy() can leak memory.
+    model.destroy()
+
+
+class TestSetZ3LeafModule(DistributedTest):
+    # Multiple GPUs are needed to test for possible hangs.
+    world_size = 2
+    reuse_dist_env = True
+
+    def test_set_z3_leaf_modules(self):
+        hidden_dim = 128
+
+        # `stage3_max_reuse_distance` is set to 0 to cause an error if the module is not set as a leaf module.
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-6
+                }
+            },
+            "fp16": {
+                "enabled": True
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "stage3_prefetch_bucket_size": hidden_dim**2,
+                "stage3_param_persistence_threshold": 0,
+                "stage3_max_reuse_distance": 0,
+            }
+        }
+
+        model = MyModel(hidden_dim)
+
+        assert not z3_leaf_module(model)
+        set_z3_leaf_modules(model, [MyModel])
+        assert z3_leaf_module(model)
+
+        run_model(model, config_dict, hidden_dim, torch.float16)
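
On the recurse=z3_leaf_module(...) arguments added in partitioned_param_coordinator.py above: for a flagged leaf module, parameter iteration gathers the parameters of all descendants in one step, so the whole subtree is fetched and released as a unit. The effect is analogous to torch.nn.Module.parameters(recurse=...); the snippet below is a plain-PyTorch illustration of that recurse semantics, not DeepSpeed's iter_params.

import torch

block = torch.nn.Sequential(
    torch.nn.Linear(4, 4, bias=False),
    torch.nn.Linear(4, 4, bias=False),
)

# recurse=False: only parameters registered directly on `block` (none here).
assert len(list(block.parameters(recurse=False))) == 0
# recurse=True: parameters of every descendant module, i.e. both Linear weights.
assert len(list(block.parameters(recurse=True))) == 2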