Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API to set a module as a leaf node when recursively setting Z3 hooks #4966

Merged
merged 9 commits into from
Jan 19, 2024
3 changes: 3 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from deepspeed.runtime import lr_schedules
from deepspeed.utils import groups
from deepspeed.utils import logger, log_dist, instrument_w_nvtx
from deepspeed.utils import set_z3_leaf_modules, get_default_z3_leaf_module_classes
from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
STEP_MICRO_TIMER, \
Expand Down Expand Up @@ -1533,6 +1534,8 @@ def _configure_zero_optimizer(self, optimizer):

elif zero_stage == ZeroStageEnum.weights:
assert not self.has_moe_layers, "MoE not supported with Stage 3"
set_z3_leaf_modules(self.module, get_default_z3_leaf_module_classes())

if isinstance(optimizer, DummyOptim):
log_dist("Creating ZeRO Offload", ranks=[0])
zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
Expand Down
8 changes: 5 additions & 3 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import torch
from collections import OrderedDict
from deepspeed.utils import z3_leaf_module
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import _init_external_params
Expand Down Expand Up @@ -383,9 +384,10 @@ def _register_hooks_recursively(self, module, count=[0]):

#print(f"{module.__class__} : {module.id}")

for child in module.children():
count[0] = count[0] + 1
self._register_hooks_recursively(child, count=count)
if not z3_leaf_module(module):
for child in module.children():
count[0] = count[0] + 1
self._register_hooks_recursively(child, count=count)

@instrument_w_nvtx
def _pre_forward_module_hook(module, *args):
Expand Down
19 changes: 11 additions & 8 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Deque, Set

from deepspeed import comm as dist
from deepspeed.utils import z3_leaf_module
from deepspeed.utils.logging import logger
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import *
Expand Down Expand Up @@ -188,7 +189,7 @@ def record_parameters(self, sub_module: Module) -> None:
raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")

step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id):
for param in sorted(set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))), key=lambda p: p.ds_id):
self.__param_order.append(__class__.__ParamInTrace(param=param, step_id_last_used_at=step_id))

def construct_parameter_trace_from_module_trace(self):
Expand Down Expand Up @@ -261,14 +262,14 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
"""
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
+ str({
"avail": f"{self.__n_available_params:.1e}",
"queue_sz": f"{len(self.__param_queue or [])}",
"inflight": [p.ds_id for p in self.__inflight_param_registry],
}))

params_to_fetch = frozenset(iter_params(current_submodule))
params_to_fetch = frozenset(iter_params(current_submodule, recurse=z3_leaf_module(current_submodule)))
fetch_numel = sum(
[p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
if fetch_numel > 0:
Expand Down Expand Up @@ -390,8 +391,8 @@ def release_sub_module(self, submodule: Module, backward: bool) -> None:
"""release the parameters of a sub module, assuming they meet conditions to
be released."""
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule)))
for param in iter_params(submodule):
p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
param.ds_active_sub_modules.discard(submodule.id)
if param.ds_id in params_to_release and not param.is_external_param:
self.__release_param(param, backward)
Expand Down Expand Up @@ -473,7 +474,9 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set
if not self.is_complete_trace():
raise RuntimeError("expected trace to be complete")

params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) if not p.ds_persist)
params_to_release = set(
p.ds_id for p in iter_params(submodule_to_release, recurse=z3_leaf_module(submodule_to_release))
if not p.ds_persist)

# Problem: When prefetcher scans the param trace, it skips AVAILABLE params.
# This creates issues if those params are released before the skipped uses:
Expand All @@ -482,7 +485,7 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set
# diverges from the trace.
# Solution: Don't release params whose reuse was skipped by prefetch. This is
# possible because we detect such skips during prefetch and mark those params.
for param in iter_params(submodule_to_release):
for param in iter_params(submodule_to_release, recurse=z3_leaf_module(submodule_to_release)):
if self.__most_recent_step_id_param_fetched_for[param] > step_id:
params_to_release.discard(param.ds_id)

Expand All @@ -493,7 +496,7 @@ def __params_to_release(self, submodule_to_release: Module, step_id: int) -> Set
for module in self.__submodule_order[step_id:]:
if params_traversed >= self.__max_reuse_dist_in_numel:
break
for param in iter_params(module):
for param in iter_params(module, recurse=z3_leaf_module(submodule_to_release)):
params_to_release.discard(param.ds_id)
params_traversed += param.ds_numel

Expand Down
1 change: 1 addition & 0 deletions deepspeed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, z3_leaf_module, get_default_z3_leaf_module_classes
from .mixed_precision_linkage import link_hp_params
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd
63 changes: 63 additions & 0 deletions deepspeed/utils/z3_leaf_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from typing import List, Type

default_z3_leaf_module_classes = []

try:
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
default_z3_leaf_module_classes.append(MixtralSparseMoeBlock)
except ImportError:
pass
tjruwase marked this conversation as resolved.
Show resolved Hide resolved


def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type], flag: bool) -> None:
assert all(isinstance(module_class, type) for module_class in leaf_module_classes), \
f'leaf_module_classes must be a list of types, got {leaf_module_classes}'

def _set_z3_leaf_flag(model: torch.nn.Module):
if model.__class__ in leaf_module_classes:
model._z3_leaf = flag

model.apply(_set_z3_leaf_flag)


def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:
"""Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module.
Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
"""
_do_set_z3_leaf_modules(model, leaf_module_classes, True)


def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:
"""Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
"""
_do_set_z3_leaf_modules(model, leaf_module_classes, False)


def z3_leaf_module(model: torch.nn.Module) -> bool:
"""Returns whether a module in `model` has been flagged as a 'leaf' module.
See `set_z3_leaf_modules` for more details.
Args:
model (torch.nn.Module): The model to which the leaf module flag will be applied.
"""
return hasattr(model, '_z3_leaf') and model._z3_leaf


def get_default_z3_leaf_module_classes() -> List[Type]:
"""Returns a list of module classes that are flagged as 'leaf' modules by default.
See `set_z3_leaf_modules` for more details.
"""
return default_z3_leaf_module_classes
90 changes: 90 additions & 0 deletions tests/unit/runtime/zero/test_zero_leaf_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed.comm as dist
import torch

from unit.common import DistributedTest
from unit.simple_model import random_dataloader

import deepspeed
from deepspeed.utils import set_z3_leaf_modules, z3_leaf_module


class MyModel(torch.nn.Module):

def __init__(self, hidden_dim):
super(MyModel, self).__init__()
self.linears = torch.nn.ModuleList(
[torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
torch.nn.Linear(hidden_dim, hidden_dim, bias=False)])
self.act = torch.nn.ReLU()
self.cel = torch.nn.CrossEntropyLoss()
self.counter = 0

def forward(self, x, y):
# This fails without setting this module as a leaf module.
# See the comment in `set_z3_leaf_modules()`.
x = self.linears[self.counter % len(self.linears)](x)
x = self.act(x)
loss = self.cel(x, y)
self.counter += 1
return x, loss


def run_model(model, config_dict, hidden_dim, dtype):
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
data_loader = random_dataloader(model=model,
total_samples=10,
hidden_dim=hidden_dim,
device=model.device,
dtype=dtype)
dist.barrier()
for batch in data_loader:
loss = model(batch[0], batch[1])
loss = loss[1]
model.backward(loss)
model.step()

# Needed in ZeRO 3. Not doing so can give memory leak
model.destroy()


class TestSetZ3LeafModule(DistributedTest):
# Need multiple gpus to test possible hanging
world_size = 2
reuse_dist_env = True

def test_set_z3_leaf_modules(self):
hidden_dim = 128

# `stage3_max_reuse_distance` is set to 0 to cause an error if the module is not set as a leaf module
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-6
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": 3,
"stage3_prefetch_bucket_size": hidden_dim**2,
"stage3_param_persistence_threshold": 0,
"stage3_max_reuse_distance": 0,
}
}

model = MyModel(hidden_dim)

assert not z3_leaf_module(model)
set_z3_leaf_modules(model, [MyModel])
assert z3_leaf_module(model)

run_model(model, config_dict, hidden_dim, torch.float16)