
Commit

Use non_reentrant_checkpoint to fix "requires_grad of input must be true" for activation checkpoint layers in pipeline training (#4224)

* feat: add `non_reentrant_checkpoint`

* feat: add missing output postprocess and change the hook to record leaf forward tensor refs

* fix: register the multi_grad_hook after graph construction

* fix: backward compatibility for multi_tensor_hook

* fix: nonlocal reference error of deepspeed_saved_tensors

* fix: avoid repeated hook registration

* test: add test for `activation_checkpointing.checkpointing.non_reentrant_checkpoint`

* Pass correct node size for ZeRO++ (#4085)

* Pass correct node size

* formatting

---------

Co-authored-by: Connor Holmes <development@cmikeh2.me>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>

* add deepspeed chat arxiv report (#4110)

* add deepspeed chat arxiv report

* add zeroquant v2 and fp

* add selective enhancement

* add ignore for 'Youn' in spell checker

---------

Co-authored-by: yaozhewei <zheweiy@berkeley.edu>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>

* style: fix flake8-detected style mismatch

* test: hack to clone the `test_activation_checkpointing` module for reuse and add regression tests

* doc: explain the introduction of `non_reentrant_checkpoint`

* doc: explain the test of `non_reentrant_checkpoint`

* apply non_reentrant_checkpoint in pipeline parallel training

* unit tests pass

* fix ci

* reduce check level for ci

---------

Co-authored-by: hughpu <hughpu@hotmail.com>
Co-authored-by: Hugh Pu <31498041+hughpu@users.noreply.github.com>
Co-authored-by: Connor Holmes <connorholmes@microsoft.com>
Co-authored-by: Connor Holmes <development@cmikeh2.me>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Conglong Li <conglong.li@gmail.com>
Co-authored-by: yaozhewei <zheweiy@berkeley.edu>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
10 people authored Sep 6, 2023
1 parent 16d8953 commit 60a3e89
Showing 3 changed files with 103 additions and 48 deletions.
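
For orientation before the diffs: a minimal sketch of how this change is meant to be used. The `use_reentrant` key under `pipeline` is the switch read by the engine changes below; the other config values, the tiny placeholder model, and the `deepspeed.initialize` call are illustrative only, and a distributed launcher is still assumed.

# Sketch only: opting into the non-reentrant activation checkpoint path.
import torch.nn as nn
import deepspeed
from deepspeed.pipe import PipelineModule

ds_config = {
    "train_batch_size": 4,
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
    "pipeline": {
        "activation_checkpoint_interval": 1,
        "use_reentrant": False,  # False selects ds_checkpointing.non_reentrant_checkpoint
    },
}

# Placeholder model; any PipelineModule works here.
pipe_module = PipelineModule(layers=[nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)],
                             num_stages=1,
                             loss_fn=nn.CrossEntropyLoss())

engine, _, _, _ = deepspeed.initialize(model=pipe_module,
                                       model_parameters=pipe_module.parameters(),
                                       config=ds_config)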
30 changes: 15 additions & 15 deletions deepspeed/runtime/pipe/engine.py
@@ -186,6 +186,14 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):

if self._config.pipeline['activation_checkpoint_interval'] > 0:
self.module.activation_checkpoint_interval = self._config.pipeline['activation_checkpoint_interval']
# set use_reentrant default to True.
if self._config.pipeline.get('use_reentrant') is None:
self._config.pipeline['use_reentrant'] = True
if self._config.pipeline['use_reentrant'] is False:
# set activation_checkpoint_func to non_reentrant_checkpoint func.
self.module.activation_checkpoint_func = ds_checkpointing.non_reentrant_checkpoint
if self.grid.get_global_rank() == 0:
logger.info(f'CONFIG: activation_checkpoint_func=non_reentrant_checkpoint')

self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline
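
The behavioural difference between the two checkpoint functions is easiest to see through PyTorch's own `torch.utils.checkpoint`, of which DeepSpeed's reentrant/non-reentrant pair is the analogue; the snippet below is illustrative only and not part of this diff.

# Illustration: why the reentrant path needs an input with requires_grad=True.
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)  # requires_grad=False, like a freshly loaded micro-batch

# Reentrant checkpointing: with no input requiring grad, the checkpointed output is
# detached from autograd, so the layer would receive no gradients (PyTorch warns here).
y_reentrant = checkpoint(layer, x, use_reentrant=True)

# Non-reentrant checkpointing: gradients still flow to the layer's parameters,
# so the engine no longer has to force requires_grad=True on pipeline inputs.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
assert layer.weight.grad is not None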

@@ -636,10 +644,7 @@ def _exec_forward_pass(self, buffer_id):
inputs = inputs[0] if len(inputs) == 1 else inputs
self.pipe_buffers['inputs'][buffer_id] = inputs

# Zero out the gradients each time we use the tensor because only the data in
# tensor changes across batches
self._zero_grads(inputs)

# inputs has no gradient because it is from a cloned tensor
outputs = super().forward(inputs)

# Reset activation checkpointing buffers.
@@ -777,15 +782,19 @@ def _exec_load_micro_batch(self, buffer_id):
loaded = None
if torch.is_tensor(batch[0]):
loaded = batch[0].clone().to(self.device).detach()
loaded.requires_grad = loaded.is_floating_point()
if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[
'use_reentrant']:
loaded.requires_grad = loaded.is_floating_point()
else:
assert isinstance(batch[0], (tuple, list))
# Assume list or tuple
loaded = []
for x in batch[0]:
assert torch.is_tensor(x)
mine = x.clone().detach().to(self.device)
mine.requires_grad = mine.is_floating_point()
if self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline[
'use_reentrant']:
mine.requires_grad = mine.is_floating_point()
loaded.append(mine)
loaded = tuple(loaded)

@@ -1159,15 +1168,6 @@ def _exec_optimizer_step(self, lr_kwargs=None):
STEP_GLOBAL_TIMER,
])

def _zero_grads(self, inputs):
if isinstance(inputs, torch.Tensor):
if inputs.grad is not None:
inputs.grad.data.zero_()
else:
for t in inputs:
if t.grad is not None:
t.grad.data.zero_()

def _allocate_zeros(self, shape, **kwargs):
""" Allocate a tensor of zeros on the engine's device.
15 changes: 9 additions & 6 deletions deepspeed/runtime/pipe/module.py
@@ -204,7 +204,9 @@ def __init__(self,
self._synchronize_tied_weights()

self.activation_checkpoint_interval = activation_checkpoint_interval

self.activation_checkpoint_func = activation_checkpoint_func
# if configuration use_reentrant = False, self.activation_checkpoint_func will be set to ``checkpointing.non_reentrant_checkpoint``

def _build(self):
specs = self._layer_specs
@@ -618,13 +620,14 @@ def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
self._synchronize_tied_weights()

def _is_checkpointable(self, funcs):
# This is an unfortunate hack related to torch and deepspeed activation checkpoint implementations.
# Some layers like torch.nn.Embedding will not receive grads if checkpointed, which breaks things.
# I presume it's related to the discrete inputs that cannot require_grad? Need to revisit.
if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)

if self.activation_checkpoint_func is not checkpointing.non_reentrant_checkpoint:
# This hook excludes the embedding layer
# because only non_reentrant_checkpoint can accept inputs with requires_grad=False
# otherwise, the backward of the embedding layer won't receive gradients.
if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
if self.checkpointable_layers is not None:
return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)

params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
return any(len(list(p)) > 0 for p in params)
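
For context, roughly how `activation_checkpoint_func` is consumed inside `PipelineModule.forward`; this is a simplified paraphrase rather than verbatim DeepSpeed code, and `exec_range` stands in for the internal closure that runs a span of layers.

# Simplified paraphrase of the checkpointing loop in PipelineModule.forward (not verbatim).
def checkpointed_forward(module, x, exec_range):
    interval = module.activation_checkpoint_interval
    num_funcs = len(module.forward_funcs)
    for start in range(0, num_funcs, interval):
        end = min(start + interval, num_funcs)
        funcs = module.forward_funcs[start:end]
        if module._is_checkpointable(funcs):
            # Either the default reentrant checkpoint or, after this commit,
            # ds_checkpointing.non_reentrant_checkpoint when use_reentrant is False.
            x = module.activation_checkpoint_func(exec_range(start, end), x)
        else:
            x = exec_range(start, end)(x)
    return x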
106 changes: 79 additions & 27 deletions tests/unit/runtime/pipe/test_pipe.py
@@ -16,6 +16,31 @@

PipeTopo = PipeDataParallelTopology

config_dict = {
"train_batch_size": 4,
"grandient_accumulation_steps": 1,
"steps_per_print": 20,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.001,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"zero_optimization": {
"stage": 0
},
"fp16": {
"enabled": False
},
"pipeline": {
"seed_layers": True,
"activation_checkpoint_interval": 1
}
}


def rel_diff(A, B):
return abs(A - B) / abs(A)
@@ -38,34 +63,8 @@ def rel_diff(A, B):
class TestPipeCifar10(DistributedTest):
world_size = 4

def test(self, topo_config):
def test_pipe_base(self, topo_config):
skip_on_arch(min_arch=7)

config_dict = {
"train_batch_size": 4,
"grandient_accumulation_steps": 1,
"steps_per_print": 20,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.001,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"zero_optimization": {
"stage": 0
},
"fp16": {
"enabled": False
},
"pipeline": {
"seed_layers": True,
"activation_checkpoint_interval": 1
}
}

topo = PipeTopo(**topo_config)
steps = 100 # must be >=100

@@ -103,3 +102,56 @@ def test(self, topo_config):
test = test_losses[-lastX:]
test_avg = sum(test) / len(test)
assert rel_diff(base_avg, test_avg) < 0.05 # Originally 0.03, but seeing instability with AMD results

# def _check_model_params_equal(self, model1, model2):
# for p1, p2 in zip(model1.parameters(), model2.parameters()):
# if p1.data.ne(p2.data).sum() > 0:
# assert False, f"model params not equal"

def test_pipe_use_reentrant(self, topo_config):
skip_on_arch(min_arch=7)

topo = PipeTopo(**topo_config)
steps = 100 # must be >=100

# Allocate model for consistent initial weights.
init_net = AlexNetPipe()

# Train with not set use_reentrant, default: True
base_net = copy.deepcopy(init_net)
base_model = PipelineModule(layers=base_net.to_layers(), topology=topo, loss_fn=nn.CrossEntropyLoss())
base_losses = train_cifar(base_model, config=config_dict, num_steps=steps, fp16=config_dict['fp16']['enabled'])

# Train with set use_reentrant=False, this will use ``non_reentrant_checkpoint``
test_config_dict = copy.deepcopy(config_dict)
test_config_dict['pipeline']['use_reentrant'] = False
test_net = copy.deepcopy(init_net)
test_model = PipelineModule(layers=test_net.to_layers(), topology=topo, loss_fn=nn.CrossEntropyLoss())
test_losses = train_cifar(test_model,
config=test_config_dict,
num_steps=steps,
fp16=config_dict['fp16']['enabled'])

abs_diffs = [l0 - l1 for l0, l1 in zip(base_losses, test_losses)]
rel_diffs = [rel_diff(l0, l1) for l0, l1 in zip(base_losses, test_losses)]
if dist.get_rank() == 0:
print(f'abs min={min(abs_diffs)} max={max(abs_diffs)} avg={sum(abs_diffs)/len(abs_diffs)}')
print(f'rel min={min(rel_diffs)} max={max(rel_diffs)} avg={sum(rel_diffs)/len(rel_diffs)}')
print(f'first: base={base_losses[0]} test={test_losses[0]} abs={abs_diffs[0]} rel={rel_diffs[0]}')

for lastX in [1, 10, 100]:
base_avg = sum(base_losses[-lastX:]) / lastX
test_avg = sum(test_losses[-lastX:]) / lastX
print(
f'last-{lastX}: base={base_avg} test={test_avg} abs={base_avg - test_avg} rel={rel_diff(base_avg, test_avg)}'
)
lastX = 100
base = base_losses[-lastX:]
base_avg = sum(base) / len(base)
test = test_losses[-lastX:]
test_avg = sum(test) / len(test)
assert rel_diff(base_avg, test_avg) < 0.05

# the following check could passed on higher version docker: nvcr.io/nvidia/pytorch:23.07-py3(torch2.1.0 cuda12.1)
# Check if models have same weights after training
# self._check_model_params_equal(base_model, test_model)
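
As a convenience, the new regression test can be targeted by name; a hypothetical invocation (the test class sets `world_size = 4`, so a multi-GPU environment is assumed):

# Hypothetical runner for the new use_reentrant regression test (multi-GPU required).
import pytest
pytest.main(["tests/unit/runtime/pipe/test_pipe.py", "-k", "test_pipe_use_reentrant"])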
