Training multiple models #7018

Open · wants to merge 13 commits into base: master
19 changes: 11 additions & 8 deletions deepspeed/runtime/bf16_optimizer.py
@@ -44,7 +44,7 @@ def __init__(self,
timers=None,
grad_acc_dtype=None,
graph_harvesting=False,
immediate_grad_update=False,
immediate_grad_update=True,
has_moe_layers=False):
super().__init__()
see_memory_usage('begin bf16_optimizer', force=True)
@@ -313,7 +313,7 @@ def step(self, closure=None):

self.clear_hp_grads()

def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
def backward(self, loss, retain_graph=False, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
"""Perform a backward pass and copy the low-precision gradients to the
high-precision copy.

@@ -323,7 +323,7 @@ def backward(self, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg
The low-precision grads are deallocated during this procedure.
"""
self.clear_lp_grads()
loss.backward(**bwd_kwargs)
loss.backward(retain_graph=retain_graph, **bwd_kwargs)

if update_hp_grads:
self.update_hp_grads(clear_lp_grads=clear_lp_grads)
@@ -425,9 +425,6 @@ def update_lp_params(self):
fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
bf16_partitions[partition_id].data.copy_(fp32_partition.data)
# print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
# if i == 0:
# print_rank_0(f'{fp32_partition[:10]=}', force=True)

all_gather_dp_groups(groups_flat=self.bf16_groups_flat,
partitioned_param_groups=self.bf16_partitioned_groups,
@@ -442,10 +439,12 @@ def clear_hp_grads(self):
for i, group in enumerate(self.fp32_groups_gradients):
self.fp32_groups_has_gradients[i] = [False] * len(group)

def clear_lp_grads(self):
def clear_lp_grads(self, set_to_none=False):

# using zero_() fixed memory address for graph replay
set_to_none = False if self.graph_harvesting else True
if self.graph_harvesting:
assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None"

zero_grads_list = []
for group in self.bf16_groups:
for param in group:
@@ -458,6 +457,10 @@ def clear_lp_grads(self):
if not set_to_none and len(zero_grads_list) > 0:
torch._foreach_zero_(zero_grads_list)

def zero_grad(self, set_to_none=True):
self.clear_lp_grads(set_to_none)
self.clear_hp_grads()

def state_dict(self):
state_dict = {}
state_dict[CLIP_GRAD] = self.clip_grad
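Side note (illustration, not part of the diff): the `set_to_none` plumbing added to `clear_lp_grads()` / `zero_grad()` preserves the in-place zeroing path that CUDA-graph replay relies on, as the `graph_harvesting` assertion above enforces. A minimal plain-PyTorch sketch of why the in-place path matters:

```python
import torch

# Standalone illustration only -- plain PyTorch tensors, not the BF16_Optimizer.
p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.ones(4)

addr_before = p.grad.data_ptr()
torch._foreach_zero_([p.grad])           # what clear_lp_grads(set_to_none=False) does
assert p.grad.data_ptr() == addr_before  # same storage: a captured CUDA graph can be replayed

p.grad = None                            # what set_to_none=True does: the storage is released
```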
2 changes: 1 addition & 1 deletion deepspeed/runtime/constants.py
@@ -128,7 +128,7 @@

# BFLOAT16 optimizer immediate gradient update
BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update"
BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False
BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = True

#########################################
# FP16 support
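Side note (illustration, not part of the diff): the surrounding BFLOAT16_* constants map to keys of the bf16 config section, so flipping the default to True means `immediate_grad_update` only needs to appear in a config in order to opt out. A hypothetical config fragment, assuming the key lives under `"bf16"`:

```python
# Hypothetical DeepSpeed config dict; with this PR "immediate_grad_update"
# defaults to True, so it is only spelled out here to opt out of the new default.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "bf16": {
        "enabled": True,
        "immediate_grad_update": False,
    },
}
```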
186 changes: 107 additions & 79 deletions deepspeed/runtime/engine.py
@@ -272,6 +272,10 @@ def __init__(self,
# Configure distributed model
self._configure_distributed_model(model)

self.module_forward_pre_hook = self._create_module_forward_pre_hook()
self.module_forward_post_hook = self._create_module_forward_post_hook()
self.module_backward_pre_hook = self._create_module_backward_pre_hook()

# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
self.param_names = {param: name for name, param in model.named_parameters()}

@@ -1870,7 +1874,6 @@ def deepspeed_io(self,
GLOBAL_RANK: self.global_rank,
DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
}

return DeepSpeedDataLoader(dataset=dataset,
batch_size=batch_size,
pin_memory=pin_memory,
@@ -1917,17 +1920,30 @@ def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None):

return scaled_loss

@instrument_w_nvtx
def forward(self, *inputs, **kwargs):
r"""Execute forward propagation
Arguments:
*inputs: Variable length input list
**kwargs: variable length keyword arguments
"""
def _create_module_backward_pre_hook(self):

if self.autotuning_profile_model_info():
ma = get_ma_status()
else:
def _module_backward_hook(module, grad_output):
if hasattr(self.optimizer, 'backward_prologue'):
self.optimizer.backward_prologue()

return self.module.register_full_backward_pre_hook(_module_backward_hook)

def _create_module_forward_pre_hook(self):

def _module_forward_pre_hook(module, inputs):
self._forward_prologue(inputs)

return self.module.register_forward_pre_hook(_module_forward_pre_hook)

def _create_module_forward_post_hook(self):

def _module_forward_post_hook(module, input, output):
self._forward_epilogue()

return self.module.register_forward_hook(_module_forward_post_hook)
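Side note (illustration, not part of the diff): the three registrations above use standard PyTorch module hooks, which is what lets the prologue/epilogue logic run without the engine having to wrap `module(*inputs)` itself. A minimal sketch of that wiring on a toy module:

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.fc(x).sum()

toy = Toy()
# These play the roles of _forward_prologue, _forward_epilogue and the
# optimizer's backward_prologue in the engine code above.
toy.register_forward_pre_hook(lambda module, args: print("forward prologue"))
toy.register_forward_hook(lambda module, args, output: print("forward epilogue"))
toy.register_full_backward_pre_hook(lambda module, grad_output: print("backward prologue"))

loss = toy(torch.randn(2, 4))
loss.backward()  # the backward pre-hook fires before gradients reach the module
```

Note that `register_full_backward_pre_hook` requires PyTorch 2.0 or newer.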

def _forward_prologue(self, inputs, kwargs=None):
if not self.autotuning_profile_model_info():
see_memory_usage("Engine before forward", force=self.memory_breakdown())

flops_profiler_active = (self.flops_profiler_enabled()
@@ -1950,54 +1966,71 @@ def forward(self, *inputs, **kwargs):
if flops_profiler_active:
self.flops_profiler.start_profile(ignore_list=None)

if self.module.training:
if self.progressive_layer_drop:
kwargs.update(self.progressive_layer_drop.get_state())
if kwargs is not None:
if self.module.training:
if self.progressive_layer_drop:
kwargs.update(self.progressive_layer_drop.get_state())

if self.__class__.__name__ != "PipelineEngine":
# TODO: The above if condition is a HACK since for PipelineEngine
# it's difficult to inject argument in forward pass.
if self.module.training and self.curriculum_enabled_legacy():
self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
if self.__class__.__name__ != "PipelineEngine":
# TODO: The above if condition is a HACK since for PipelineEngine
# it's difficult to inject argument in forward pass.
if self.module.training and self.curriculum_enabled_legacy():
self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})

if self.module.training and self.random_ltd_enabled():
self.random_ltd_scheduler.update_seq(self.global_steps)

if self.training_dataloader is None:
self.tput_timer.start()

self._start_timers(self.engine_timers.forward_timers)

if self.zero_optimization_partition_weights():
# Enable automated discovery of external parameters by indicating that
# we are in a forward pass.
for module in self.module.modules():
module._parameters._in_forward = True

self._start_timers(self.engine_timers.forward_timers)

if self.training_dataloader is None:
self.tput_timer.start()

if self.fp16_auto_cast():
inputs = self._cast_inputs_half(inputs)

loss = self.module(*inputs, **kwargs)

def _forward_epilogue(self):
if self.zero_optimization_partition_weights():
# Disable automated discovery of external parameters
for module in self.module.modules():
module._parameters._in_forward = False

self._stop_timers(self.engine_timers.forward_timers)

flops_profiler_active = (self.flops_profiler_enabled()
and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)

if flops_profiler_active:
self.flops_profiler.stop_profile()

if not self.autotuning_profile_model_info():
see_memory_usage("Engine after forward", force=self.memory_breakdown())

@instrument_w_nvtx
def forward(self, *inputs, **kwargs):
r"""Execute forward propagation
Arguments:
*inputs: Variable length input list
**kwargs: variable length keyword arguments
"""
if self.autotuning_profile_model_info():
ma = get_ma_status()

loss = self.module(*inputs, **kwargs)

if self.autotuning_profile_model_info():
activation_mem = get_ma_status() - ma
self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())
exit()
else:
see_memory_usage("Engine after forward", force=self.memory_breakdown())

return loss

def _cast_inputs_half(self, inputs):
@@ -2056,43 +2089,13 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
grads = None
self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)

@contextmanager
def no_sync(self):
r"""
Context manager to disable gradient reduction during backward pass.
This context manager has the following effects on other DeepSpeed features.
1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
2. It is illegal to call engine.step() within the context manager.
3. Tracking of gradient accumulation steps is disabled.
"""
assert not self.zero_optimization_partition_gradients(), \
f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"

assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"

self.inside_no_sync_ctxt = True
try:
yield
finally:
self.inside_no_sync_ctxt = False

@instrument_w_nvtx
def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True):
r"""Execute backward pass on the loss
Arguments:
loss: Torch tensor on which to execute backward propagation
retain_graph: bool, default: false
forward on user defined choice of retain_graph
"""

def _backward_prologue(self, loss):
see_memory_usage("Engine before backward", force=self.memory_breakdown())

if self.scale_wrt_gas is not None:
scale_wrt_gas = self.scale_wrt_gas

do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt

# scale loss w.r.t. gradient accumulation if reduction is not disabled
do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
loss = self._scale_loss_by_gas(loss.float())

@@ -2109,13 +2112,18 @@ def backward(self, release_loss=False, retain_graph=False, scale_wrt_gas=T
)]
self.monitor.write_events(self.summary_events)

self._start_timers(self.engine_timers.backward_timers)
return loss

assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
"must provide optimizer during init in order to use backward"
def _backward_epilogue(self):
self._start_timers(self.engine_timers.backward_reduce_timers)
if self.enable_backward_allreduce and not self.inside_no_sync_ctxt:
# Traditional code path that allreduces the module parameter grads
self.allreduce_gradients()
self._stop_timers(self.engine_timers.backward_reduce_timers)
see_memory_usage("Engine after backward", force=self.memory_breakdown())

def _do_optimizer_backward(self, loss, retain_graph):
self._start_timers(self.engine_timers.backward_inner_timers)

if self.zero_optimization():
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
self.optimizer.backward(loss, retain_graph=retain_graph)
@@ -2131,30 +2139,50 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=T
else:
self.optimizer.backward(loss, retain_graph=retain_graph)
elif self.bfloat16_enabled():
self.optimizer.backward(loss)
self.optimizer.backward(loss, retain_graph=retain_graph)
else:
if self.eigenvalue_enabled():
loss.backward(create_graph=True, retain_graph=True)
else:
loss.backward(retain_graph=retain_graph)

self._stop_timers(self.engine_timers.backward_inner_timers)

self._start_timers(self.engine_timers.backward_reduce_timers)

if do_gradient_reduction:
# Traditional code path that allreduces the module parameter grads
self.allreduce_gradients()
@contextmanager
def no_sync(self):
r"""
Context manager to disable gradient reduction during backward pass.
This context manager has the following effects on other DeepSpeed features.
1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
2. It is illegal to call engine.step() within the context manager.
3. Tracking of gradient accumulation steps is disabled.
"""
assert not self.zero_optimization_partition_gradients(), \
f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"

self._stop_timers(self.engine_timers.backward_reduce_timers)
assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"

self._stop_timers(self.engine_timers.backward_timers)
self.inside_no_sync_ctxt = True
try:
yield
finally:
self.inside_no_sync_ctxt = False
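Side note (illustration, not part of the diff; `engine` and `micro_batches` are assumed to exist): the documented `no_sync` contract is that gradient reduction is skipped inside the context and `engine.step()` is called outside it, e.g.:

```python
# Sketch under assumptions: `engine` is an initialized DeepSpeedEngine and
# `micro_batches` is a list of micro-batches for one accumulation window.
*accumulation_batches, last_batch = micro_batches
with engine.no_sync():
    for batch in accumulation_batches:
        loss = engine(batch)
        engine.backward(loss)   # no gradient reduction inside the context
loss = engine(last_batch)
engine.backward(loss)           # gradients are reduced on this final backward
engine.step()                   # legal only outside no_sync()
```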

if release_loss:
# loss.data = None
pass
@instrument_w_nvtx
def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
r"""Execute backward pass on the loss
Arguments:
loss: Torch tensor on which to execute backward propagation
retain_graph: bool, default: false
forward on user defined choice of retain_graph
"""
assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
"must provide optimizer during init in order to use backward"

see_memory_usage("Engine after backward", force=self.memory_breakdown())
self._start_timers(self.engine_timers.backward_timers)
loss = self._backward_prologue(loss)
self._do_optimizer_backward(loss, retain_graph)
self._backward_epilogue()
self._stop_timers(self.engine_timers.backward_timers)

return loss
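Side note (a sketch under assumptions, not the PR's documented recipe; all model and config names are hypothetical): one way the `retain_graph` plumbing could be exercised for this PR's multi-model theme is to backpropagate a shared computation graph through more than one engine before the graph is freed:

```python
import torch
import deepspeed

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)

# Assumes a single-GPU run launched via the deepspeed launcher.
ds_config = {
    "train_batch_size": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "bf16": {"enabled": True},
}

encoder_model, head_model = Tiny(), Tiny()
encoder, _, _, _ = deepspeed.initialize(model=encoder_model,
                                        model_parameters=encoder_model.parameters(),
                                        config=ds_config)
head, _, _, _ = deepspeed.initialize(model=head_model,
                                     model_parameters=head_model.parameters(),
                                     config=ds_config)

x = torch.randn(2, 4, dtype=torch.bfloat16).to(encoder.device)
shared = encoder(x)                      # one forward graph shared by both engines
loss = head(shared).float().mean()

head.backward(loss, retain_graph=True)   # keep the graph alive for the second engine
encoder.backward(loss)                   # reaches BF16_Optimizer.backward(retain_graph=...)
head.step()
encoder.step()
```

Each engine still steps only its own parameters; the point of the sketch is that the second `backward()` call no longer hits a freed-graph error, because `retain_graph` is now threaded through to `BF16_Optimizer.backward`.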
