Fix requires_grad of input must be true for activation checkpoint layer in pipeline train. #4128

Closed
wants to merge 11 commits
15 changes: 0 additions & 15 deletions deepspeed/runtime/pipe/engine.py
@@ -636,10 +636,6 @@ def _exec_forward_pass(self, buffer_id):
inputs = inputs[0] if len(inputs) == 1 else inputs
self.pipe_buffers['inputs'][buffer_id] = inputs

# Zero out the gradients each time we use the tensor because only the data in
# tensor changes across batches
self._zero_grads(inputs)
Contributor
Can you explain why we can delete this? We could also delete the comment before this line.

Contributor Author
@inkcherry Aug 17, 2023
if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
    inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
else:
    inputs = self.pipe_buffers['inputs'][buffer_id].clone()
Here inputs has become a non-leaf tensor via the clone() op, so its gradient is never populated and there is no need to call _zero_grads on it again.

And when self.is_pipe_partitioned and not self.is_first_stage() is False, meaning no new leaf tensor named inputs is created, accessing the gradient of the current (non-leaf) inputs in _zero_grads will trigger a warning.
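For illustration, a minimal standalone PyTorch sketch (not DeepSpeed code) of the point above: a tensor produced by clone() is a non-leaf, its .grad is never populated by backward(), and reading it only raises the non-leaf .grad warning.

```python
import torch

# Hypothetical check, independent of the PR: gradients accumulate on leaf tensors only.
x = torch.ones(2, requires_grad=True)   # leaf tensor
y = x.clone()                           # non-leaf: produced by an op inside the graph
(y * 3).sum().backward()

print(x.grad)   # tensor([3., 3.]) accumulated on the leaf
print(y.grad)   # None, and accessing it emits the non-leaf .grad UserWarning
```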


outputs = super().forward(inputs)

# Reset activation checkpointing buffers.
@@ -777,15 +773,13 @@ def _exec_load_micro_batch(self, buffer_id):
loaded = None
if torch.is_tensor(batch[0]):
loaded = batch[0].clone().to(self.device).detach()
loaded.requires_grad = loaded.is_floating_point()
else:
assert isinstance(batch[0], (tuple, list))
# Assume list or tuple
loaded = []
for x in batch[0]:
assert torch.is_tensor(x)
mine = x.clone().detach().to(self.device)
mine.requires_grad = mine.is_floating_point()
loaded.append(mine)
loaded = tuple(loaded)

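As a side note (an inference from this hunk together with the module.py changes below, not a statement from the PR): the loaded micro-batch is now just a detached clone, so floating-point inputs on the first stage no longer require grad, and the ModuleWrapper / dummy-tensor path added in module.py compensates when that stage is activation-checkpointed. A minimal sketch of what the stage now receives (hypothetical batch and device):

```python
import torch

batch = torch.randn(8, 16)                   # stand-in micro-batch
device = torch.device('cpu')                 # stands in for self.device
loaded = batch.clone().to(device).detach()   # what _exec_load_micro_batch now produces
print(loaded.requires_grad)                  # False; the explicit requires_grad override was removed
```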
@@ -1158,15 +1152,6 @@ def _exec_optimizer_step(self, lr_kwargs=None):
STEP_GLOBAL_TIMER,
])

def _zero_grads(self, inputs):
    if isinstance(inputs, torch.Tensor):
        if inputs.grad is not None:
            inputs.grad.data.zero_()
    else:
        for t in inputs:
            if t.grad is not None:
                t.grad.data.zero_()

def _allocate_zeros(self, shape, **kwargs):
""" Allocate a tensor of zeros on the engine's device.

71 changes: 57 additions & 14 deletions deepspeed/runtime/pipe/module.py
@@ -21,6 +21,13 @@
from deepspeed.runtime.state_dict_factory import SDLoaderFactory
from deepspeed.accelerator import get_accelerator
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
from enum import Enum


class CkptLayer_Enum(Enum):
    not_ckpt_layer = 0
    normal_ckpt_layer = 1
    wrap_ckpt_layer = 2


class PipelineError(Exception):
@@ -83,6 +90,21 @@ def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_att
self.tied_weight_attr = tied_weight_attr


class ModuleWrapper(nn.Module):
    """An input with requires_grad=False prevents an activation-checkpointed layer from receiving grads,
    and the requires_grad of the checkpointed layer's output will also be False; wrapping the first layer
    with this class, so it also receives a dummy tensor that requires grad, works around the problem.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x, dummy_arg=None):
        assert dummy_arg is not None
        x = self.module(x)
        return x


class PipelineModule(nn.Module):
"""Modules to be parallelized with pipeline parallelism.

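For context, an illustrative sketch of the problem ModuleWrapper works around, written against torch.utils.checkpoint rather than DeepSpeed's activation_checkpoint_func (assumed to behave analogously here): when no checkpointed input requires grad, the output does not require grad and the layer's parameters receive no gradients; threading a dummy tensor with requires_grad=True through the call restores the backward path.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(4, 4)
x = torch.randn(2, 4)                            # pipeline-style input: requires_grad=False

# Without any grad-requiring input, the checkpointed output does not require grad.
out = checkpoint(layer, x, use_reentrant=True)
print(out.requires_grad)                          # False -> no gradients reach `layer`

# Workaround mirrored by ModuleWrapper: pass a dummy tensor that requires grad.
dummy = torch.ones(1, requires_grad=True)

def wrapped(inp, dummy_arg):                      # analogous to ModuleWrapper.forward
    assert dummy_arg is not None
    return layer(inp)

out = checkpoint(wrapped, x, dummy, use_reentrant=True)
out.sum().backward()
print(layer.weight.grad is not None)              # True -> gradients flow into the layer
```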
@@ -191,7 +213,8 @@ def __init__(self,
self.fwd_map = {}
self.tied_modules = nn.ModuleDict()
self.tied_weight_attrs = {}

self.is_wrapped_ckptlayer = False
self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)
# Offset the random seed by the stage ID.
#newseed = get_accelerator().initial_seed() + self._grid.get_stage_id()
#ds_utils.set_random_seed(newseed)
@@ -320,15 +343,22 @@ def forward(self, forward_input):
# will see a different offset.
self.micro_offset += 1

def exec_range_func(start, end):
def exec_range_func(start, end, wrap_layer=False):
''' Helper function to be used with checkpoint()
Adapted from torch.utils.checkpoint:checkpoint_sequential()
'''
local_micro_offset = self.micro_offset + 1

def exec_func(*inputs):
# Single tensor inputs need to be unwrapped
if len(inputs) == 1:
if wrap_layer:
if len(inputs) == 2:
dummy_tensor = inputs[1]
inputs = inputs[0]
else:
dummy_tensor = inputs[-1]
inputs = inputs[:-1]
elif len(inputs) == 1:
inputs = inputs[0]
for idx, layer in enumerate(self.forward_funcs[start:end]):
self.curr_layer = idx + self._local_start
@@ -338,8 +368,11 @@ def exec_func(*inputs):
self.seed_fn(new_seed)
else:
ds_utils.set_random_seed(new_seed)

inputs = layer(inputs)
if wrap_layer and idx == 0:
# the first checkpoint layer is wrapped by ModuleWrapper and gets a dummy_tensor with requires_grad=True
inputs = layer(inputs, dummy_tensor)
else:
inputs = layer(inputs)
return inputs

return exec_func
@@ -359,7 +392,14 @@ def exec_func(*inputs):
if not isinstance(x, tuple):
x = (x, )

if self._is_checkpointable(funcs):
if self._is_checkpointable(funcs) == CkptLayer_Enum.wrap_ckpt_layer:
if not self.is_wrapped_ckptlayer:
self.forward_funcs[start_idx] = ModuleWrapper(self.forward_funcs[start_idx])
self.is_wrapped_ckptlayer = True

x = self.activation_checkpoint_func(exec_range_func(start_idx, end_idx, wrap_layer=True), *x,
self.dummy_tensor.to(get_accelerator().current_device()))
elif self._is_checkpointable(funcs) == CkptLayer_Enum.normal_ckpt_layer:
x = self.activation_checkpoint_func(exec_range_func(start_idx, end_idx), *x)
else:
x = exec_range_func(start_idx, end_idx)(*x)
@@ -581,6 +621,9 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal

checkpoint_engine.makedirs(save_dir, exist_ok=True)
for idx, layer in enumerate(layer_list):
if isinstance(layer, ModuleWrapper):
# unwrap the layer before saving.
layer = layer.module
model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
if not hasattr(layer, 'state_dict'):
continue
@@ -618,13 +661,13 @@ def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
self._synchronize_tied_weights()

def _is_checkpointable(self, funcs):
# This is an unfortunate hack related to torch and deepspeed activation checkpoint implementations.
# Some layers like torch.nn.Embedding will not receive grads if checkpointed, which breaks things.
# I presume it's related to the discrete inputs that cannot require_grad? Need to revisit.
if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'):
return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs)
if self.checkpointable_layers is not None:
return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)

if isinstance(funcs[0], ModuleWrapper):
return CkptLayer_Enum.wrap_ckpt_layer
params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
return any(len(list(p)) > 0 for p in params)
if any(len(list(p)) > 0 for p in params):
if not self.is_wrapped_ckptlayer:
return CkptLayer_Enum.wrap_ckpt_layer

return CkptLayer_Enum.normal_ckpt_layer
return CkptLayer_Enum.not_ckpt_layer