diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index 6b63efbb23f7..13e2842eb750 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -44,7 +44,7 @@ def __init__(self,
                  timers=None,
                  grad_acc_dtype=None,
                  graph_harvesting=False,
-                 immediate_grad_update=False,
+                 immediate_grad_update=True,
                  has_moe_layers=False):
         super().__init__()
         see_memory_usage('begin bf16_optimizer', force=True)
@@ -313,7 +313,7 @@ def step(self, closure=None):
 
         self.clear_hp_grads()
 
-    def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
+    def backward(self, loss, retain_graph=False, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
         """Perform a backward pass and copy the low-precision gradients to the
         high-precision copy.
 
@@ -323,7 +323,7 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg
         The low-precision grads are deallocated during this procedure.
         """
         self.clear_lp_grads()
-        loss.backward(**bwd_kwargs)
+        loss.backward(retain_graph=retain_graph, **bwd_kwargs)
 
         if update_hp_grads:
             self.update_hp_grads(clear_lp_grads=clear_lp_grads)
@@ -425,9 +425,6 @@ def update_lp_params(self):
                 fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             bf16_partitions[partition_id].data.copy_(fp32_partition.data)
-            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
-            # if i == 0:
-            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)
 
         all_gather_dp_groups(groups_flat=self.bf16_groups_flat,
                              partitioned_param_groups=self.bf16_partitioned_groups,
@@ -442,10 +439,12 @@ def clear_hp_grads(self):
         for i, group in enumerate(self.fp32_groups_gradients):
             self.fp32_groups_has_gradients[i] = [False] * len(group)
 
-    def clear_lp_grads(self):
+    def clear_lp_grads(self, set_to_none=False):
 
         # using zero_() fixed memory address for graph replay
-        set_to_none = False if self.graph_harvesting else True
+        if self.graph_harvesting:
+            assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None"
+
         zero_grads_list = []
         for group in self.bf16_groups:
             for param in group:
@@ -458,6 +457,10 @@ def clear_lp_grads(self):
         if not set_to_none and len(zero_grads_list) > 0:
             torch._foreach_zero_(zero_grads_list)
 
+    def zero_grad(self, set_to_none=True):
+        self.clear_lp_grads(set_to_none)
+        self.clear_hp_grads()
+
     def state_dict(self):
         state_dict = {}
         state_dict[CLIP_GRAD] = self.clip_grad
diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py
index 55cfa8f59c91..2878bb11f872 100755
--- a/deepspeed/runtime/constants.py
+++ b/deepspeed/runtime/constants.py
@@ -128,7 +128,7 @@
 
 # BFLOAT16 optimizer immediate gradient update
 BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update"
-BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False
+BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = True
 
 #########################################
 # FP16 support
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 986b68dc1bb1..9c500999a0ee 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -272,6 +272,9 @@ def __init__(self,
         # Configure distributed model
         self._configure_distributed_model(model)
 
+        self.module_forward_pre_hook = self._create_module_forward_pre_hook()
+        self.module_forward_post_hook = self._create_module_forward_post_hook()
+
         # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
         self.param_names = {param: name for name, param in model.named_parameters()}
 
@@ -1870,7 +1873,6 @@ def deepspeed_io(self,
             GLOBAL_RANK: self.global_rank,
             DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
         }
-
         return DeepSpeedDataLoader(dataset=dataset,
                                    batch_size=batch_size,
                                    pin_memory=pin_memory,
@@ -1917,17 +1919,22 @@ def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None):
 
         return scaled_loss
 
-    @instrument_w_nvtx
-    def forward(self, *inputs, **kwargs):
-        r"""Execute forward propagation
-        Arguments:
-            *inputs: Variable length input list
-            **kwargs: variable length keyword arguments
-        """
+    def _create_module_forward_pre_hook(self):
 
-        if self.autotuning_profile_model_info():
-            ma = get_ma_status()
-        else:
+        def _module_forward_pre_hook(module, inputs):
+            self._forward_prologue(inputs)
+
+        return self.module.register_forward_pre_hook(_module_forward_pre_hook)
+
+    def _create_module_forward_post_hook(self):
+
+        def _module_forward_post_hook(module, input, output):
+            self._forward_epilogue()
+
+        return self.module.register_forward_hook(_module_forward_post_hook)
+
+    def _forward_prologue(self, inputs, kwargs=None):
+        if not self.autotuning_profile_model_info():
             see_memory_usage("Engine before forward", force=self.memory_breakdown())
 
         flops_profiler_active = (self.flops_profiler_enabled()
@@ -1950,37 +1957,37 @@ def forward(self, *inputs, **kwargs):
         if flops_profiler_active:
             self.flops_profiler.start_profile(ignore_list=None)
 
-        if self.module.training:
-            if self.progressive_layer_drop:
-                kwargs.update(self.progressive_layer_drop.get_state())
+        if kwargs is not None:
+            if self.module.training:
+                if self.progressive_layer_drop:
+                    kwargs.update(self.progressive_layer_drop.get_state())
 
-        if self.__class__.__name__ != "PipelineEngine":
-            # TODO: The above if condition is a HACK since for PipelineEngine
-            # it's difficult to inject argument in forward pass.
-            if self.module.training and self.curriculum_enabled_legacy():
-                self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
-                if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
-                    kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
+            if self.__class__.__name__ != "PipelineEngine":
+                # TODO: The above if condition is a HACK since for PipelineEngine
+                # it's difficult to inject argument in forward pass.
+                if self.module.training and self.curriculum_enabled_legacy():
+                    self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
+                    if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
+                        kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
 
         if self.module.training and self.random_ltd_enabled():
             self.random_ltd_scheduler.update_seq(self.global_steps)
 
+        if self.training_dataloader is None:
+            self.tput_timer.start()
+
+        self._start_timers(self.engine_timers.forward_timers)
+
         if self.zero_optimization_partition_weights():
             # Enable automated discovery of external parameters by indicating that
             # we are in a forward pass.
             for module in self.module.modules():
                 module._parameters._in_forward = True
 
-        self._start_timers(self.engine_timers.forward_timers)
-
-        if self.training_dataloader is None:
-            self.tput_timer.start()
-
         if self.fp16_auto_cast():
             inputs = self._cast_inputs_half(inputs)
 
-        loss = self.module(*inputs, **kwargs)
-
+    def _forward_epilogue(self):
         if self.zero_optimization_partition_weights():
             # Disable automated discovery of external parameters
             for module in self.module.modules():
@@ -1988,16 +1995,33 @@ def forward(self, *inputs, **kwargs):
 
         self._stop_timers(self.engine_timers.forward_timers)
 
+        flops_profiler_active = (self.flops_profiler_enabled()
+                                 and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)
+
         if flops_profiler_active:
             self.flops_profiler.stop_profile()
 
+        if not self.autotuning_profile_model_info():
+            see_memory_usage("Engine after forward", force=self.memory_breakdown())
+
+    @instrument_w_nvtx
+    def forward(self, *inputs, **kwargs):
+        r"""Execute forward propagation
+        Arguments:
+            *inputs: Variable length input list
+            **kwargs: variable length keyword arguments
+        """
+        if self.autotuning_profile_model_info():
+            ma = get_ma_status()
+
+        loss = self.module(*inputs, **kwargs)
+
         if self.autotuning_profile_model_info():
             activation_mem = get_ma_status() - ma
             self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
             print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())
             exit()
-        else:
-            see_memory_usage("Engine after forward", force=self.memory_breakdown())
+
         return loss
 
     def _cast_inputs_half(self, inputs):
@@ -2056,43 +2080,13 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
                 grads = None
             self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)
 
-    @contextmanager
-    def no_sync(self):
-        r"""
-        Context manager to disable gradient reduction during backward pass.
-        This context manager has the following effects on other DeepSpeed features.
-        1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
-        2. It is illegal to call engine.step() within the context manager.
-        3. Tracking of gradient accumulation steps is disabled.
-        """
-        assert not self.zero_optimization_partition_gradients(), \
-            f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"
-
-        assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"
-
-        self.inside_no_sync_ctxt = True
-        try:
-            yield
-        finally:
-            self.inside_no_sync_ctxt = False
-
-    @instrument_w_nvtx
-    def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True):
-        r"""Execute backward pass on the loss
-        Arguments:
-            loss: Torch tensor on which to execute backward propagation
-            retain_graph: bool, default: false
-                forward on user defined choice of retain_graph
-        """
-
+    def _backward_prologue(self, loss, scale_wrt_gas=True):
         see_memory_usage("Engine before backward", force=self.memory_breakdown())
-
         if self.scale_wrt_gas is not None:
             scale_wrt_gas = self.scale_wrt_gas
 
-        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
-        # scale loss w.r.t. gradient accumulation if reduction is not disabled
+        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
         if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
             loss = self._scale_loss_by_gas(loss.float())
 
@@ -2109,13 +2103,18 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=T
                     )]
                     self.monitor.write_events(self.summary_events)
 
-        self._start_timers(self.engine_timers.backward_timers)
+        return loss
 
-        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
-            "must provide optimizer during init in order to use backward"
+    def _backward_epilogue(self):
+        self._start_timers(self.engine_timers.backward_reduce_timers)
+        if self.enable_backward_allreduce and not self.inside_no_sync_ctxt:
+            # Traditional code path that allreduces the module parameter grads
+            self.allreduce_gradients()
+        self._stop_timers(self.engine_timers.backward_reduce_timers)
+        see_memory_usage("Engine after backward", force=self.memory_breakdown())
 
+    def _do_optimizer_backward(self, loss, retain_graph):
         self._start_timers(self.engine_timers.backward_inner_timers)
-
         if self.zero_optimization():
             self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
             self.optimizer.backward(loss, retain_graph=retain_graph)
@@ -2131,30 +2130,50 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=T
             else:
                 self.optimizer.backward(loss, retain_graph=retain_graph)
         elif self.bfloat16_enabled():
-            self.optimizer.backward(loss)
+            self.optimizer.backward(loss, retain_graph=retain_graph)
         else:
             if self.eigenvalue_enabled():
                 loss.backward(create_graph=True, retain_graph=True)
             else:
                 loss.backward(retain_graph=retain_graph)
 
-        self._stop_timers(self.engine_timers.backward_inner_timers)
-        self._start_timers(self.engine_timers.backward_reduce_timers)
-
-        if do_gradient_reduction:
-            # Traditional code path that allreduces the module parameter grads
-            self.allreduce_gradients()
+    @contextmanager
+    def no_sync(self):
+        r"""
+        Context manager to disable gradient reduction during backward pass.
+        This context manager has the following effects on other DeepSpeed features.
+        1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
+        2. It is illegal to call engine.step() within the context manager.
+        3. Tracking of gradient accumulation steps is disabled.
+        """
+        assert not self.zero_optimization_partition_gradients(), \
+            f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"
 
-        self._stop_timers(self.engine_timers.backward_reduce_timers)
+        assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"
 
-        self._stop_timers(self.engine_timers.backward_timers)
+        self.inside_no_sync_ctxt = True
+        try:
+            yield
+        finally:
+            self.inside_no_sync_ctxt = False
 
-        if release_loss:
-            # loss.data = None
-            pass
+    @instrument_w_nvtx
+    def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
+        r"""Execute backward pass on the loss
+        Arguments:
+            loss: Torch tensor on which to execute backward propagation
+            retain_graph: bool, default: false
+                forward on user defined choice of retain_graph
+        """
+        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
+            "must provide optimizer during init in order to use backward"
 
-        see_memory_usage("Engine after backward", force=self.memory_breakdown())
+        self._start_timers(self.engine_timers.backward_timers)
+        loss = self._backward_prologue(loss, scale_wrt_gas)
+        self._do_optimizer_backward(loss, retain_graph)
+        self._backward_epilogue()
+        self._stop_timers(self.engine_timers.backward_timers)
 
         return loss
 
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 2bece09bffc4..109bf9a56f48 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -522,11 +522,16 @@ def __init__(self,
 
         # resets the data structure value for the next backward propagation
         self.reset_partition_gradient_structures()
 
-        # creates backward hooks for gradient partitioning
+        # creates backward hooks for the following special handling of gradients
+        # 1. upcasting for fp32 gradient accumulation
+        # 2. gradient partitioning
+        # 3. overlapping backward and reduction
         self._grad_acc_hooks = []
-        if self.partition_gradients or self.overlap_comm:
-            self.create_reduce_and_remove_grad_hooks()
+        if self.partition_gradients or self.overlap_comm or self.use_grad_accum_attribute:
+            self.create_gradient_handling_hooks()
+
+        self.ready_for_gradients = False
 
         self.custom_loss_scaler = False
         self.external_loss_scale = None
@@ -678,6 +683,7 @@ def _release_ipg_buffers(self):
         self.ipg_buffer = None
         self.grads_in_partition = None
         self.grads_in_partition_offset = 0
+        self.ready_for_gradients = False
 
     def initialize_optimizer_states(self):
 
@@ -874,16 +880,18 @@ def increment_value(dictionary, key):
     def overlapping_partition_gradients_reduce_epilogue(self):
         self.independent_gradient_partition_epilogue()
 
+    def _fill_param_grad_accum_attribute(self, param):
+        if param.grad is not None:
+            if param.grad_accum is None:
+                param.grad_accum = param.grad.to(self.gradient_accumulation_dtype)
+            else:
+                param.grad_accum.add_(param.grad.to(self.gradient_accumulation_dtype).view(param.grad_accum.shape))
+            param.grad = None
+
     def fill_grad_accum_attribute(self):
         for group in self.bit16_groups:
             for param in group:
-                if param.grad is not None:
-                    if param.grad_accum is None:
-                        param.grad_accum = param.grad.to(self.gradient_accumulation_dtype)
-                    else:
-                        param.grad_accum.add_(
-                            param.grad.to(self.gradient_accumulation_dtype).view(param.grad_accum.shape))
-                    param.grad = None
+                self._fill_param_grad_accum_attribute(param)
 
     def get_gradient_for_reduction(self, param):
         if self.use_grad_accum_attribute:
@@ -901,7 +909,7 @@ def clear_grad_attribute(self, param):
         else:
             param.grad = None
 
-    def create_reduce_and_remove_grad_hooks(self):
+    def create_gradient_handling_hooks(self):
         self.grad_accs = []
         for i, param_group in enumerate(self.bit16_groups):
             for param in param_group:
@@ -911,10 +919,10 @@ def wrapper(param, i):
                         param_tmp = param.expand_as(param)
                         grad_acc = param_tmp.grad_fn.next_functions[0][0]
 
-                        def reduce_partition_and_remove_grads(*notneeded):
-                            self.reduce_ready_partitions_and_remove_grads(param, i)
+                        def grad_handling_hook(*notneeded):
+                            self.process_gradients(param, i)
 
-                        self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
+                        self._grad_acc_hooks.append(grad_acc.register_hook(grad_handling_hook))
                         self.grad_accs.append(grad_acc)
 
                     wrapper(param, i)
@@ -1294,7 +1302,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
             if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                 param_id = self.get_param_id(p)
                 # as some model have trainable parameters but skipped in training,
-                # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
+                # their backward hooks in self.create_gradient_handling_hooks() will not run,
                 # so they have no norm_for_param_grads
                 if param_id in self.norm_for_param_grads:
                     param_norm = self.norm_for_param_grads[param_id]
@@ -1421,6 +1429,13 @@ def reduce_ipg_grads(self):
         self.elements_in_ipg_bucket = 0
         #####################################################################
 
+    def process_gradients(self, param, i):
+        self.backward_prologue()
+        if self.use_grad_accum_attribute:
+            self._fill_param_grad_accum_attribute(param)
+        if self.partition_gradients or self.overlap_comm:
+            self.reduce_ready_partitions_and_remove_grads(param, i)
+
     def reduce_ready_partitions_and_remove_grads(self, param, i):
         if self.partition_gradients or self.is_gradient_accumulation_boundary:
             self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
@@ -1949,9 +1964,7 @@ def update_lp_params(self):
                 zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             bit16_partitions[partition_id].data.copy_(fp32_partition.data)
-            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
-            # if i == 0:
-            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)
+
         all_gather_dp_groups(groups_flat=self.bit16_groups_flat,
                              partitioned_param_groups=self.parallel_partitioned_bit16_groups,
                              dp_process_group=self.real_dp_process_group,
@@ -2035,6 +2048,30 @@ def _has_inf_or_nan(x, j=None):
         inf_or_nan = nan.logical_or(inf)
         return inf_or_nan.float().max()
 
+    def backward_prologue(self):
+        if not self.ready_for_gradients:
+            self.micro_step_id += 1
+            if self.contiguous_gradients and self.ipg_buffer is None:
+                self.ipg_buffer = []
+                buf_0 = torch.empty(int(self.reduce_bucket_size),
+                                    dtype=self.dtype,
+                                    device=get_accelerator().current_device_name())
+                self.ipg_buffer.append(buf_0)
+
+                # Use double buffers to avoid data access conflict when overlap_comm is enabled.
+                if self.overlap_comm:
+                    buf_1 = torch.empty(int(self.reduce_bucket_size),
+                                        dtype=self.dtype,
+                                        device=get_accelerator().current_device_name())
+                    self.ipg_buffer.append(buf_1)
+                self.ipg_index = 0
+            self.ready_for_gradients = True
+
+    def backward_epilogue(self):
+        # Only for Stage 1, Mode 2
+        if self.use_grad_accum_attribute:
+            self.fill_grad_accum_attribute()
+
     def backward(self, loss, retain_graph=False):
         """
         :attr:`backward` performs the following steps:
@@ -2043,32 +2080,13 @@ def backward(self, loss, retain_graph=False):
         2. scaled_loss = fp32_loss*loss_scale
         3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
         """
-        self.micro_step_id += 1
-
-        if self.contiguous_gradients:
-            self.ipg_buffer = []
-            buf_0 = torch.empty(int(self.reduce_bucket_size),
-                                dtype=self.dtype,
-                                device=get_accelerator().current_device_name())
-            self.ipg_buffer.append(buf_0)
-
-            # Use double buffers to avoid data access conflict when overlap_comm is enabled.
-            if self.overlap_comm:
-                buf_1 = torch.empty(int(self.reduce_bucket_size),
-                                    dtype=self.dtype,
-                                    device=get_accelerator().current_device_name())
-                self.ipg_buffer.append(buf_1)
-            self.ipg_index = 0
-
+        self.backward_prologue()
         if self.custom_loss_scaler:
            scaled_loss = self.external_loss_scale * loss
-            scaled_loss.backward()
+            scaled_loss.backward(retain_graph=retain_graph)
        else:
             self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
-
-        # Only for Stage 1, Mode 2
-        if self.use_grad_accum_attribute:
-            self.fill_grad_accum_attribute()
+        self.backward_epilogue()
 
     def check_overflow(self, partition_gradients=True):
         self._check_overflow(partition_gradients)
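
Not part of the patch: a minimal usage sketch of the engine API this diff refactors (forward pre/post hooks driving `_forward_prologue`/`_forward_epilogue`, `backward()` split into prologue, optimizer backward, and epilogue, `retain_graph` threaded through, and `no_sync()` for skipping reduction). The toy model, config values, and loop below are assumptions for illustration, not code from the patch; they assume ZeRO stage 1 so `no_sync()` is legal per its docstring.

```python
import torch
import deepspeed

# Toy stand-ins (assumptions): a tiny model and a config with BF16, ZeRO stage 1,
# and gradient accumulation, matching the defaults touched by this change.
model = torch.nn.Linear(8, 8)
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 4,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)

for step in range(8):
    x = torch.randn(1, 8, device=engine.device, dtype=torch.bfloat16)
    loss = engine(x).float().pow(2).mean()   # module forward hooks fire _forward_prologue/_forward_epilogue

    if (step + 1) % engine.gradient_accumulation_steps() != 0:
        with engine.no_sync():               # suppress gradient allreduce on accumulation micro-steps
            engine.backward(loss)            # _backward_prologue -> _do_optimizer_backward -> _backward_epilogue
    else:
        engine.backward(loss)                # boundary micro-step: gradients are reduced in the epilogue
        engine.step()                        # step() must stay outside no_sync(), per the docstring
```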