feat(activation_checkpointing): add non_reentrant_checkpoint to support inputs require no grad #4118

Merged Aug 29, 2023 · 27 commits (feat/non-reentrant-checkpoint → master)

Commits
a20c79c
feat: add `non_reentrant_checkpoint`
hughpu Aug 7, 2023
8aeba5f
feat: add missing output postprocess and change the hook to record le…
hughpu Aug 7, 2023
ee04fa8
fix: make the multi_grad_hook registered after graph construction
hughpu Aug 7, 2023
51f833d
fix: backward compatibility for multi_tensor_hook
hughpu Aug 8, 2023
b29c1ef
fix: nonlocal reference error of deepspeed_saved_tensors
hughpu Aug 8, 2023
37e7c23
fix: reduce repeating hook registration
hughpu Aug 8, 2023
d7c5440
Merge branch 'microsoft:master' into feat/non-reentrant-checkpoint
hughpu Aug 9, 2023
e22c487
test: add test for `activation_checkpointing.checkpointing.non_reentr…
hughpu Aug 9, 2023
4d2a274
Pass correct node size for ZeRO++ (#4085)
cmikeh2 Aug 9, 2023
d4d070b
add deepspeed chat arxiv report (#4110)
conglongli Aug 9, 2023
aaf309e
style: change flake8 detected style missmatch
hughpu Aug 9, 2023
a910922
test: hack to clone the `test_activation_checkpointing` module for re…
hughpu Aug 9, 2023
fc919b1
doc: explain the introduction of `non_reentrant_checkpoint`
hughpu Aug 9, 2023
b6a0a44
doc: explain the test of `non_reentrant_checkpoint`
hughpu Aug 9, 2023
8ec86a4
Merge branch 'microsoft:master' into feat/non-reentrant-checkpoint
hughpu Aug 9, 2023
78c0d65
Merge branch 'master' into feat/non-reentrant-checkpoint
hughpu Aug 9, 2023
e4eff23
Merge branch 'master' into feat/non-reentrant-checkpoint
hughpu Aug 10, 2023
a6c7871
Merge branch 'master' into feat/non-reentrant-checkpoint
hughpu Aug 11, 2023
fbbb760
Merge branch 'master' into feat/non-reentrant-checkpoint
hughpu Aug 15, 2023
a338097
Merge branch 'master' into feat/non-reentrant-checkpoint
hughpu Aug 16, 2023
a00cff1
Merge branch 'master' into feat/non-reentrant-checkpoint
hughpu Aug 16, 2023
c17cc3d
Merge branch 'master' into feat/non-reentrant-checkpoint
hughpu Aug 17, 2023
a680399
Merge branch 'master' into feat/non-reentrant-checkpoint
hughpu Aug 18, 2023
13e766d
Merge branch 'master' into feat/non-reentrant-checkpoint
hughpu Aug 19, 2023
a46e326
Merge branch 'master' into feat/non-reentrant-checkpoint
hughpu Aug 21, 2023
b5c03f4
Merge branch 'master' into feat/non-reentrant-checkpoint
hughpu Aug 23, 2023
13a026d
Merge branch 'master' into feat/non-reentrant-checkpoint
tjruwase Aug 24, 2023
266 changes: 266 additions & 0 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -18,6 +18,7 @@
import torch
import contextlib
from deepspeed import comm as dist
import weakref

import mmap
from torch import _C
@@ -705,6 +706,271 @@ def backward(ctx, *grads):
return tuple(ret_list)


def non_reentrant_checkpoint(function, *args):
"""This function is union of `torch.utils.checkpoint._checkpoint_without_reentrant` and `CheckpointFunction` in this module

This function is aim to solve the back probagation error raised from all input requires no grad.
* has already been implemented in pytorch for a while, the solution is stable at most time except for jit module mode.
* can help to solve the issue which is hacked by `deepspeed.runtime.pipe.module.PipelineModule._is_checkpointable`

Main modifications compared to the implementation of torch:
1. adapt to the signature of `checkpoint` function in this module
2. solve the non-deterministic by random state management consistent with deepspeed `CheckpointFunction`
3. when there is partition or cpu checkpointing, gather them in the unpack_hook during back probagation
4. make all after backward blocks in the hook which will executed after all leaf nodes backward execution.
5. above 4. is inspired by `torch.autograd.graph.register_multi_grad_hook`, which is only implemented after 2.0.0
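
    Example (a minimal usage sketch; `layer` is a hypothetical module, and the
    distributed backend plus the checkpointing globals are assumed to be
    configured already; the input deliberately requires no grad):
        >>> layer = torch.nn.Linear(4, 4)
        >>> hidden = torch.randn(2, 4)  # requires_grad defaults to False
        >>> output = non_reentrant_checkpoint(layer, hidden)
        >>> output.sum().backward()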
"""
global mpu, timers, SYNCHRONIZE, PROFILE_TIME

deepspeed_saved_tensors = None
non_tensor_args = None
tensor_flags = None

def save_args_for_backward(*all_args):
"""keep this function to reduce the modification from original implementation"""
nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags
tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
deepspeed_saved_tensors = tensor_args
non_tensor_args = non_tensor_args
tensor_flags = tensor_flags

if SYNCHRONIZE:
get_accelerator().synchronize()

if timers is None and PROFILE_TIME:
timers = Timers()

if PROFILE_TIME:
timers(FORWARD_GLOBAL_TIMER).start()

global num_layers
global mp_rank, mp_size, mp_group
global contiguous_data_buffers, contiguous_size_buffers
global data_offsets, size_offsets
if mp_rank is None:
if mpu is not None:
if hasattr(mpu, 'get_tensor_model_parallel_rank'):
mp_rank = mpu.get_tensor_model_parallel_rank()
mp_size = mpu.get_tensor_model_parallel_world_size()
mp_group = mpu.get_tensor_model_parallel_group()
else:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()
else:
mp_rank = 0
mp_size = 1
mp_group = None

global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

if cuda_device is None:
see_memory_usage("First Forward Beginning", force=False)
if dist.get_rank() == 0:
logger.info(f"Activation Checkpointing Information")
logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
logger.info(
f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
logger.info(f"----Synchronization {SYNCHRONIZE}")
logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")

cuda_device = get_accelerator().current_device_name()
transport_stream = get_accelerator().Stream(device=cuda_device)

if PARTITION_ACTIVATIONS:
inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
elif CPU_CHECKPOINT:
inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)

# just in case something funky is happening such as reuse of inputs
inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)

    # Copy the rng states; they are restored before the recomputation forward so
    # that stochastic ops (e.g. dropout) replay identically.
fwd_cpu_rng_state = torch.get_rng_state()
fwd_cuda_rng_state = get_accelerator().get_rng_state()
fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

if PARTITION_ACTIVATIONS:
new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
save_args_for_backward(*new_args)
elif CPU_CHECKPOINT:
new_args = get_cpu_activations_for_backward(args, inputs)
save_args_for_backward(*new_args)
else:
save_args_for_backward(*args)

class Holder():
"""the place holder object used as activations to save memory"""
pass

    # weak references to the holders let us detect activation deletion before a
    # whole forward/backward pair has finished, so stale entries can be skipped
storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
weak_holder_list = []
leaf_tensors = []
backward_visited_leaf_nodes = 0
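
    # `leaf_tensors` and `backward_visited_leaf_nodes` emulate
    # `torch.autograd.graph.register_multi_grad_hook` (only available in
    # torch >= 2.0.0): the post-backward block runs once every recorded leaf
    # has received its gradient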

def checkpoint_pack(tensor_from_forward):
"""used to record the activation order in the `weak_holder_list`

the activation order in holder list is consistent between the first forward and recomputing forward.
* the jit compiled forward will break the order consistency *
"""
res = Holder()
weak_holder_list.append(weakref.ref(res))

        # if this is a leaf tensor, save it to trace backward progress;
        # leaf tensors are usually inputs or parameters rather than activations,
        # so keeping references to them adds no memory overhead
if tensor_from_forward.requires_grad and tensor_from_forward.is_leaf:
leaf_tensors.append(tensor_from_forward)
return res

def checkpoint_unpack(holder_from_backward):
"""retrieve the activations from recompute"""
nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags

        # on the first unpack of backpropagation, recompute the graph and save
        # all activations in the same order as `checkpoint_pack` recorded them
if len(storage) == 0:
unpack_counter = 0

def replay_pack(tensor_from_replay):
"""save recompute activations"""
nonlocal unpack_counter
unpack_counter += 1

if weak_holder_list[unpack_counter - 1]() is None:
return

detached_activations = tensor_from_replay.detach()
storage[weak_holder_list[unpack_counter - 1]()] = detached_activations

return

def replay_unpack(none_value):
"""recompute graph need not to backward"""
raise RuntimeError("You are calling backwards on a tensor that is never exposed.")

global timers
see_memory_usage("In backward", force=False)
# removing pointers to the contiguous buffer memory
# so that they can be garbage collected once the checkpoints
# have been used
if SYNCHRONIZE:
get_accelerator().synchronize()
if PROFILE_TIME:
timers('backward').start()

if CONTIGUOUS_CHECKPOINTING:
global data_offsets, size_offsets
global contiguous_data_buffers, contiguous_size_buffers

                for buffers in contiguous_data_buffers:
                    # rebinding the loop variable was a no-op; clear each inner
                    # buffer list in place so its tensors can be freed
                    del buffers[:]

# frees up all the pointers to the checkpoints except for the ones
# stored by save for backward
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
size_offsets = []

see_memory_usage("In backward checkpointing code", force=False)
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")

global cuda_device, transport_stream, PARTITION_ACTIVATIONS

            # gather inputs that were partitioned or offloaded to CPU before the first forward
if PARTITION_ACTIVATIONS:
# with get_accelerator().stream(transport_stream):
inputs = gather_partitioned_activations(deepspeed_saved_tensors,
device=cuda_device if CPU_CHECKPOINT else None)
detached_inputs = detach_variable(inputs)
elif CPU_CHECKPOINT:
inputs = move_to_device(deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
detached_inputs = detach_variable(inputs)
else:
inputs = deepspeed_saved_tensors
detached_inputs = detach_variable(inputs)
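            # detaching makes the recomputed graph local to this unpack call, so
            # gradients from the replayed forward do not leak into the original graph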

# Add non tensor input args
detached_inputs = merge_tensors(tensor_objects=detached_inputs,
non_tensor_objects=non_tensor_args,
tensor_flags=tensor_flags)

# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = get_accelerator().get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

            # Set the states to what they were before the forward pass.
torch.set_rng_state(fwd_cpu_rng_state)
_set_cuda_rng_state(fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(fwd_cuda_rng_state_tracker)

see_memory_usage("In backward checkpointing code before forward", force=False)
with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(replay_pack, replay_unpack):
_unused = function(*detached_inputs)

see_memory_usage("In backward checkpointing code after forward", force=False)
            # Set the states back to what they were at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

deepspeed_saved_tensors = None
non_tensor_args = None
tensor_flags = None

if holder_from_backward not in storage:
raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
" recomputation being triggered in between, this is not currently supported.")

return storage[holder_from_backward]

def after_backward_hook(_nonuse_grads):
"""the hook registered to all leaf tensors"""
nonlocal leaf_tensors, backward_visited_leaf_nodes
backward_visited_leaf_nodes += 1

if backward_visited_leaf_nodes == len(leaf_tensors):
see_memory_usage("After backward checkpointing code after backward", force=False)

if PROFILE_TIME:
timers('backward').stop()
timers.log(['backward'])
if SYNCHRONIZE:
get_accelerator().synchronize()

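    # run the real forward under the pack hook so each activation saved by
    # autograd is replaced with a lightweight Holder; real tensors are
    # recomputed lazily by `checkpoint_unpack` during backward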
with torch.autograd.graph.saved_tensors_hooks(checkpoint_pack, checkpoint_unpack):
outputs = function(*inputs_cuda)
for leaf_tensor in leaf_tensors:
leaf_tensor.register_hook(after_backward_hook)

see_memory_usage("After running forward on the layer", force=False)

if PROFILE_TIME:
timers(FORWARD_GLOBAL_TIMER).stop()
timers.log([FORWARD_GLOBAL_TIMER])
if SYNCHRONIZE:
get_accelerator().synchronize()

all_outputs = []
if torch.is_tensor(outputs):
all_outputs += [outputs]
else:
all_outputs += outputs

if len(all_outputs) == 1:
return all_outputs[0]
else:
return tuple(all_outputs)
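
# A minimal integration sketch (an illustration, not part of this file):
# `PipelineModule` accepts an `activation_checkpoint_func`, so the non-reentrant
# variant can be plugged in for stages whose inputs require no grad; `layers` is
# a hypothetical list of stage modules defined elsewhere:
#
#     from deepspeed.pipe import PipelineModule
#     from deepspeed.runtime.activation_checkpointing.checkpointing import (
#         non_reentrant_checkpoint)
#
#     model = PipelineModule(layers=layers,
#                            num_stages=2,
#                            activation_checkpoint_func=non_reentrant_checkpoint)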


def checkpoint(function, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint. """