diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 498da7e42fb6..80c78ccd6a6d 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -186,6 +186,14 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
         if self._config.pipeline['activation_checkpoint_interval'] > 0:
             self.module.activation_checkpoint_interval = self._config.pipeline['activation_checkpoint_interval']
+            # set use_reentrant default to True.
+            if self._config.pipeline.get('use_reentrant') is None:
+                self._config.pipeline['use_reentrant'] = True
+            if self._config.pipeline['use_reentrant'] is False:
+                # set activation_checkpoint_func to non_reentrant_checkpoint func.
+                self.module.activation_checkpoint_func = ds_checkpointing.non_reentrant_checkpoint
+                if self.grid.get_global_rank() == 0:
+                    logger.info(f'CONFIG: activation_checkpoint_func=non_reentrant_checkpoint')

         self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline

@@ -636,10 +644,7 @@ def _exec_forward_pass(self, buffer_id):
             inputs = inputs[0] if len(inputs) == 1 else inputs
             self.pipe_buffers['inputs'][buffer_id] = inputs

-        # Zero out the gradients each time we use the tensor because only the data in
-        # tensor changes across batches
-        self._zero_grads(inputs)
-
+        # inputs has no gradient because it is from a cloned tensor
         outputs = super().forward(inputs)

         # Reset activation checkpointing buffers.
@@ -777,7 +782,9 @@ def _exec_load_micro_batch(self, buffer_id):
             loaded = None
             if torch.is_tensor(batch[0]):
                 loaded = batch[0].clone().to(self.device).detach()
-                loaded.requires_grad = loaded.is_floating_point()
+                if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[
+                        'use_reentrant']:
+                    loaded.requires_grad = loaded.is_floating_point()
             else:
                 assert isinstance(batch[0], (tuple, list))
                 # Assume list or tuple
@@ -785,7 +792,9 @@
                 for x in batch[0]:
                     assert torch.is_tensor(x)
                     mine = x.clone().detach().to(self.device)
-                    mine.requires_grad = mine.is_floating_point()
+                    if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[
+                            'use_reentrant']:
+                        mine.requires_grad = mine.is_floating_point()
                     loaded.append(mine)
                 loaded = tuple(loaded)

@@ -1159,15 +1168,6 @@ def _exec_optimizer_step(self, lr_kwargs=None):
                 STEP_GLOBAL_TIMER,
             ])

-    def _zero_grads(self, inputs):
-        if isinstance(inputs, torch.Tensor):
-            if inputs.grad is not None:
-                inputs.grad.data.zero_()
-        else:
-            for t in inputs:
-                if t.grad is not None:
-                    t.grad.data.zero_()
-
     def _allocate_zeros(self, shape, **kwargs):
         """ Allocate a tensor of zeros on the engine's device.
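Note on the engine change above: the non-reentrant path is opt-in and is only consulted when activation checkpointing is enabled. Below is a minimal sketch of a user-side DeepSpeed config that exercises it; the pipeline keys mirror the test config further down, and the surrounding values are illustrative only, not part of this patch.

    # Hypothetical config sketch; only the "pipeline" keys are relevant to this PR.
    ds_config = {
        "train_batch_size": 4,
        "pipeline": {
            "seed_layers": True,
            "activation_checkpoint_interval": 1,
            # When False, PipelineEngine swaps activation_checkpoint_func to
            # ds_checkpointing.non_reentrant_checkpoint; when omitted it defaults to True.
            "use_reentrant": False,
        },
    }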
diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 43d1713de8f4..bd6b5fe4bf91 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -204,7 +204,9 @@ def __init__(self,
         self._synchronize_tied_weights()

         self.activation_checkpoint_interval = activation_checkpoint_interval
+
         self.activation_checkpoint_func = activation_checkpoint_func
+        # if the configuration sets use_reentrant = False, self.activation_checkpoint_func will be set to ``checkpointing.non_reentrant_checkpoint``

     def _build(self):
         specs = self._layer_specs
@@ -618,13 +620,14 @@ def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
         self._synchronize_tied_weights()

     def _is_checkpointable(self, funcs):
-        # This is an unfortunate hack related to torch and deepspeed activation checkpoint implementations.
-        # Some layers like torch.nn.Embedding will not receive grads if checkpointed, which breaks things.
-        # I presume it's related to the discrete inputs that cannot require_grad? Need to revisit.
-        if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
-            return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
+
+        if self.activation_checkpoint_func is not checkpointing.non_reentrant_checkpoint:
+            # This hook excludes the embedding layer,
+            # because only non_reentrant_checkpoint can accept inputs with requires_grad=False;
+            # otherwise, the backward pass of the embedding layer won't receive gradients.
+            if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
+                return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
         if self.checkpointable_layers is not None:
             return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)
-
         params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
         return any(len(list(p)) > 0 for p in params)
diff --git a/tests/unit/runtime/pipe/test_pipe.py b/tests/unit/runtime/pipe/test_pipe.py
index dae791c8f860..88e26290b650 100644
--- a/tests/unit/runtime/pipe/test_pipe.py
+++ b/tests/unit/runtime/pipe/test_pipe.py
@@ -16,6 +16,31 @@

 PipeTopo = PipeDataParallelTopology

+config_dict = {
+    "train_batch_size": 4,
+    "grandient_accumulation_steps": 1,
+    "steps_per_print": 20,
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 0.001,
+            "betas": [0.9, 0.999],
+            "eps": 1e-8,
+            "weight_decay": 3e-7
+        }
+    },
+    "zero_optimization": {
+        "stage": 0
+    },
+    "fp16": {
+        "enabled": False
+    },
+    "pipeline": {
+        "seed_layers": True,
+        "activation_checkpoint_interval": 1
+    }
+}
+

 def rel_diff(A, B):
     return abs(A - B) / abs(A)
@@ -38,34 +63,8 @@ def rel_diff(A, B):
 class TestPipeCifar10(DistributedTest):
     world_size = 4

-    def test(self, topo_config):
+    def test_pipe_base(self, topo_config):
         skip_on_arch(min_arch=7)
-
-        config_dict = {
-            "train_batch_size": 4,
-            "grandient_accumulation_steps": 1,
-            "steps_per_print": 20,
-            "optimizer": {
-                "type": "Adam",
-                "params": {
-                    "lr": 0.001,
-                    "betas": [0.9, 0.999],
-                    "eps": 1e-8,
-                    "weight_decay": 3e-7
-                }
-            },
-            "zero_optimization": {
-                "stage": 0
-            },
-            "fp16": {
-                "enabled": False
-            },
-            "pipeline": {
-                "seed_layers": True,
-                "activation_checkpoint_interval": 1
-            }
-        }
-
         topo = PipeTopo(**topo_config)
         steps = 100  # must be >=100

@@ -103,3 +102,56 @@ def test(self, topo_config):
         test = test_losses[-lastX:]
         test_avg = sum(test) / len(test)
         assert rel_diff(base_avg, test_avg) < 0.05  # Originally 0.03, but seeing instability with AMD results
+
+    # def _check_model_params_equal(self, model1, model2):
+    #     for p1, p2 in zip(model1.parameters(), model2.parameters()):
+    #         if p1.data.ne(p2.data).sum() > 0:
+    #             assert False, f"model params not equal"
+
+    def test_pipe_use_reentrant(self, topo_config):
+        skip_on_arch(min_arch=7)
+
+        topo = PipeTopo(**topo_config)
+        steps = 100  # must be >=100
+
+        # Allocate model for consistent initial weights.
+        init_net = AlexNetPipe()
+
+        # Train without setting use_reentrant (defaults to True).
+        base_net = copy.deepcopy(init_net)
+        base_model = PipelineModule(layers=base_net.to_layers(), topology=topo, loss_fn=nn.CrossEntropyLoss())
+        base_losses = train_cifar(base_model, config=config_dict, num_steps=steps, fp16=config_dict['fp16']['enabled'])
+
+        # Train with use_reentrant=False, which uses ``non_reentrant_checkpoint``.
+        test_config_dict = copy.deepcopy(config_dict)
+        test_config_dict['pipeline']['use_reentrant'] = False
+        test_net = copy.deepcopy(init_net)
+        test_model = PipelineModule(layers=test_net.to_layers(), topology=topo, loss_fn=nn.CrossEntropyLoss())
+        test_losses = train_cifar(test_model,
+                                  config=test_config_dict,
+                                  num_steps=steps,
+                                  fp16=config_dict['fp16']['enabled'])
+
+        abs_diffs = [l0 - l1 for l0, l1 in zip(base_losses, test_losses)]
+        rel_diffs = [rel_diff(l0, l1) for l0, l1 in zip(base_losses, test_losses)]
+        if dist.get_rank() == 0:
+            print(f'abs min={min(abs_diffs)} max={max(abs_diffs)} avg={sum(abs_diffs)/len(abs_diffs)}')
+            print(f'rel min={min(rel_diffs)} max={max(rel_diffs)} avg={sum(rel_diffs)/len(rel_diffs)}')
+            print(f'first: base={base_losses[0]} test={test_losses[0]} abs={abs_diffs[0]} rel={rel_diffs[0]}')
+
+            for lastX in [1, 10, 100]:
+                base_avg = sum(base_losses[-lastX:]) / lastX
+                test_avg = sum(test_losses[-lastX:]) / lastX
+                print(
+                    f'last-{lastX}: base={base_avg} test={test_avg} abs={base_avg - test_avg} rel={rel_diff(base_avg, test_avg)}'
+                )
+        lastX = 100
+        base = base_losses[-lastX:]
+        base_avg = sum(base) / len(base)
+        test = test_losses[-lastX:]
+        test_avg = sum(test) / len(test)
+        assert rel_diff(base_avg, test_avg) < 0.05
+
+        # The following check could pass on a newer container, e.g. nvcr.io/nvidia/pytorch:23.07-py3 (torch 2.1.0, CUDA 12.1).
+        # Check whether the models have the same weights after training.
+        # self._check_model_params_equal(base_model, test_model)
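For context on why _exec_load_micro_batch no longer forces requires_grad and why _is_checkpointable can keep the embedding layer when the non-reentrant function is active: reentrant-style checkpointing detaches the checkpointed segment from autograd when none of its inputs require grad (e.g. integer token ids), so parameters inside the segment never receive gradients; the non-reentrant variant does not have this limitation. Below is a minimal standalone sketch using plain torch.utils.checkpoint rather than DeepSpeed's non_reentrant_checkpoint wrapper; the expected results are noted in comments and assume a recent PyTorch that accepts the use_reentrant argument.

    import torch
    from torch.utils.checkpoint import checkpoint

    emb = torch.nn.Embedding(10, 4)
    head = torch.nn.Linear(4, 1)
    tokens = torch.randint(0, 10, (2, 3))  # integer inputs cannot require grad

    # Reentrant checkpointing: with no differentiable inputs, the checkpointed
    # segment is cut out of the autograd graph, so emb.weight gets no gradient
    # (PyTorch also warns that none of the inputs have requires_grad=True).
    out = checkpoint(emb, tokens, use_reentrant=True)
    head(out).sum().backward()
    print(emb.weight.grad)  # expected: None

    # Non-reentrant checkpointing: gradients still flow to emb.weight even
    # though the input itself does not require grad.
    out = checkpoint(emb, tokens, use_reentrant=False)
    head(out).sum().backward()
    print(emb.weight.grad is not None)  # expected: True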