Commit: graph_capture
inkcherry committed Sep 13, 2023
1 parent 581e44d commit 4abab21
Showing 2 changed files with 111 additions and 28 deletions.
59 changes: 41 additions & 18 deletions deepspeed/runtime/bf16_optimizer.py
@@ -16,13 +16,14 @@
from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
-                                     is_model_parallel_parameter, see_memory_usage)
+                                     is_model_parallel_parameter, see_memory_usage, graph_warp)

from deepspeed.utils import link_hp_params, fragment_address
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
PARAM_SLICE_MAPPINGS)
import time

setattr(sys.modules[__name__], 'fragment_address', fragment_address)

@@ -37,7 +38,8 @@ def __init__(self,
norm_type=2,
allgather_bucket_size=5000000000,
dp_process_group=None,
-                 timers=None):
+                 timers=None,
+                 use_graph_for_utils=False):
super().__init__()
see_memory_usage('begin bf16_optimizer', force=True)
self.timers = timers
@@ -77,12 +79,15 @@ def __init__(self,

self.step_count = 0
self.group_paddings = []

self.use_graph_for_utils = use_graph_for_utils
if self.using_real_optimizer:
self._setup_for_real_optimizer()

see_memory_usage('end bf16_optimizer', force=True)

def print_rank0(self, msg):
if dist.get_rank() == 0:
print(msg)
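For context, a hypothetical wiring sketch of how the new `use_graph_for_utils` knob would be passed when this optimizer is constructed, assuming the class is DeepSpeed's `BF16_Optimizer`; `base_optimizer`, `param_names`, `mpu`, `dp_group`, and `timers` are stand-ins, not values taken from this diff:

```python
# Hypothetical sketch only; in practice DeepSpeed builds this optimizer internally
# from the engine config, so every argument except use_graph_for_utils is a stand-in.
bf16_optimizer = BF16_Optimizer(base_optimizer,
                                param_names,
                                mpu=mpu,
                                norm_type=2,
                                dp_process_group=dp_group,
                                timers=timers,
                                use_graph_for_utils=True)
```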

def _setup_for_real_optimizer(self):
dp_world_size = dist.get_world_size(group=self.dp_process_group)
self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))]
@@ -237,15 +242,17 @@ def step(self, closure=None):

all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),
mpu=self.mpu,
-                                                     norm_type=self.norm_type)
+                                                     norm_type=self.norm_type,
+                                                     use_graph=self.use_graph_for_utils)
self._global_grad_norm = all_groups_norm

assert all_groups_norm > 0.
if self.clip_grad > 0.:
clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True),
max_norm=self.clip_grad,
global_norm=all_groups_norm,
-                                        mpu=self.mpu)
+                                        mpu=self.mpu,
+                                        use_graph=self.use_graph_for_utils)

self.optimizer.step()

@@ -271,22 +278,36 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg

@torch.no_grad()
def update_hp_grads(self, clear_lp_grads=False):

+        def _update_hp_grads_func(clear_lp_grads=False):
+            for i, group in enumerate(self.bf16_groups):
+                for j, lp in enumerate(group):
+                    if lp.grad is None:
+                        continue
+                    hp_grad = self.fp32_groups_gradients[i][j]
+                    assert hp_grad is not None, \
+                        f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]'
+                    hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
+                    lp._hp_grad = hp_grad
+                    self.fp32_groups_has_gradients[i][j] = True
+                    # clear gradients
+                    if clear_lp_grads:
+                        lp.grad.zero_()
+
+        start = time.time()
+        if self.use_graph_for_utils:
+            graph_warp(False, _update_hp_grads_func, clear_lp_grads)
+        else:
+            _update_hp_grads_func(clear_lp_grads)
+        #cpu op
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if lp.grad is None:
                    continue
-
-                hp_grad = self.fp32_groups_gradients[i][j]
-                assert hp_grad is not None, \
-                    f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]'
-
-                hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
-                lp._hp_grad = hp_grad
                self.fp32_groups_has_gradients[i][j] = True
-
-                # clear gradients
-                if clear_lp_grads:
-                    lp.grad = None
+        end = time.time()
+        duration = end - start
+        self.print_rank0(f"update_hp_grads, use_graph:{self.use_graph_for_utils}, duration:{duration}")

@torch.no_grad()
def get_grads_for_reduction(self):
@@ -337,7 +358,9 @@ def clear_hp_grads(self):
def clear_lp_grads(self):
for group in self.bf16_groups:
for param in group:
-                param.grad = None
+                if param.grad is not None:
+                    # Using zero_() keeps a fixed memory address for graph replay
+                    param.grad.zero_()

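The switch from `param.grad = None` to an in-place `zero_()` matters because a captured graph replays its kernels against the storage addresses recorded at capture time. A minimal illustrative contrast (not DeepSpeed code; a CUDA device is assumed):

```python
import torch

grad = torch.ones(4, device='cuda')

# Safe for graph replay: same storage, contents cleared in place.
grad.zero_()

# Breaks the fixed-address assumption: the name is rebound to a brand-new tensor,
# while any captured kernels keep writing to the old buffer.
grad = torch.zeros(4, device='cuda')
```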
def state_dict(self):
state_dict = {}
80 changes: 70 additions & 10 deletions deepspeed/runtime/utils.py
@@ -32,11 +32,17 @@

from deepspeed.module_inject.policy import transpose
from torch.nn import functional as F
import time

torch_memory_reserved = get_accelerator().memory_reserved
torch_max_memory_reserved = get_accelerator().max_memory_reserved


def print_rank0(msg):
if dist.get_rank() == 0:
print(msg)


class DummyOptim():
"""
Dummy optimizer presents model parameters as a param group, this is
@@ -47,6 +53,25 @@ def __init__(self, params):
self.param_groups = []
self.param_groups.append({'params': params})


graph_cache = {}


def graph_warp(replay_first_step, func, *args, **kwargs):
    # `func` should contain only GPU operations.
    # The memory addresses of the data `func` reads and writes must remain constant,
    # because graph replay reuses the pointers recorded at capture time.
    if func.__name__ not in graph_cache:
        # First call: warm up on a side stream, then capture the kernels into a CUDA graph.
        cuda_stream = get_accelerator().Stream()
        cuda_stream.wait_stream(get_accelerator().current_stream())
        with get_accelerator().stream(cuda_stream):
            func(*args, **kwargs)
        get_accelerator().current_stream().wait_stream(cuda_stream)
        # TODO: Apply get_accelerator interface for torch.cuda.CUDAGraph and torch.cuda.graph #ignore-cuda
        graph_cache[func.__name__] = torch.cuda.CUDAGraph() #ignore-cuda
        with torch.cuda.graph(graph_cache[func.__name__]): #ignore-cuda
            func(*args, **kwargs)
        if replay_first_step:
            graph_cache[func.__name__].replay()
    else:
        graph_cache[func.__name__].replay()


def noop_decorator(func):
return func
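For reference, a minimal standalone sketch of the warm-up/capture/replay pattern that the `graph_warp` helper above builds on, written against plain `torch.cuda` APIs; the tensor and function names are illustrative and a CUDA device is assumed:

```python
import torch

static_a = torch.randn(1024, device='cuda')
static_b = torch.randn(1024, device='cuda')
static_out = torch.empty(1024, device='cuda')

def gpu_only_step():
    # Only GPU ops, writing into pre-allocated tensors at fixed addresses.
    static_out.copy_(static_a * static_b + static_a)

# Warm up on a side stream so capture does not record unrelated work.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    gpu_only_step()
torch.cuda.current_stream().wait_stream(side_stream)

# Capture once, then replay the recorded kernels on later iterations.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    gpu_only_step()

static_a.copy_(torch.randn(1024, device='cuda'))  # refresh inputs in place
graph.replay()
```

`graph_warp` caches one such graph per function name, which is why callers must reuse the same pre-allocated buffers on every call.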
@@ -864,7 +889,7 @@ def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, ep
return global_grad_norm


-def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None):
+def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False):
"""Get norm of an iterable of tensors.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
@@ -878,7 +903,7 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None):
Returns:
Total norm of the tensors (viewed as a single vector).
"""

+    start = time.time()
assert isinstance(input_tensors, Iterable), f'expected Iterable type not {type(input_tensors)}'
assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors'

@@ -890,8 +915,27 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None):
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
-        total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors])
-        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
+        if use_graph:
+            if 'norm_tensors_compute_buffer' not in graph_cache:
+                graph_cache['norm_tensors_compute_buffer'] = [t.data.float().norm(norm_type) for t in input_tensors]
+            compute_buffer = graph_cache['norm_tensors_compute_buffer']
+
+            def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
+                for i, t in enumerate(tensor_list):
+                    _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type)
+                    if i != 0:
+                        _compute_buffer[0].data.add_(_compute_buffer[i].data)
+
+            graph_warp(False, _norm_tensors, input_tensors, compute_buffer, norm_type)
+
+            total_norm = compute_buffer[0]
+        else:
+            total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors])
+
+        end = time.time()
+        duration = end - start
+        print_rank0(f"norm in get_global_norm_of_tensors, use_graph:{use_graph}, duration:{duration}")
+        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach()
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
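A worked instance of the reduction above for `norm_type=2` (pure CPU, illustrative values): per-tensor norms are raised to the p-th power, summed, and the global norm is the p-th root of that sum. The `use_graph` branch produces the same sum but accumulates it into `compute_buffer[0]` on the device, so no `.item()` synchronization happens inside the captured region.

```python
import torch

input_tensors = [torch.tensor([3.0, 4.0]), torch.tensor([12.0])]
norm_type = 2
total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors])  # 25.0 + 144.0
print(total_norm**(1. / norm_type))  # 13.0
```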
@@ -902,7 +946,7 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None):
return total_norm


-def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6):
+def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6, use_graph=False):
"""Clip list of tensors by global norm.
Args:
input_tensors: List of tensors to be clipped
@@ -913,14 +957,30 @@ def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, m
float: the global norm
"""
if global_norm is None:
-        global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu)
-
+        global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu, use_graph=use_graph)
+    start = time.time()
    clip_coef = max_norm / (global_norm + eps)

    if clip_coef < 1:
-        for t in input_tensors:
-            t.detach().mul_(clip_coef)
+        if use_graph:
+
+            def clip_tensors(_tensor_list, _clip_coef_tensor):
+                for t in _tensor_list:
+                    t.detach().mul_(_clip_coef_tensor)
+
+            if 'clip_coef_tensor' not in graph_cache:
+                # Alloc memory
+                graph_cache['clip_coef_tensor'] = torch.tensor(clip_coef,
+                                                               dtype=torch.float32).to(get_accelerator().device_name())
+            clip_coef_tensor = graph_cache['clip_coef_tensor']
+            clip_coef_tensor.copy_(torch.tensor(clip_coef, dtype=torch.float32))
+            graph_warp(False, clip_tensors, input_tensors, clip_coef_tensor)
+
+        else:
+            for t in input_tensors:
+                t.detach().mul_(clip_coef)
+    end = time.time()
+    duration = end - start
+    print_rank0(f"mul_ in clip_tensors_by_global_norm_, use_graph:{use_graph}, duration:{duration}")
return global_norm
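The cached `clip_coef_tensor` above exists because a captured multiply can only read from a fixed device buffer; each call therefore copies the freshly computed Python float into that buffer before replay. A standalone sketch of the same staging idea (illustrative names, plain `torch.cuda` APIs, CUDA device assumed):

```python
import torch

grads = [torch.full((4,), 3.0, device='cuda'), torch.full((2,), 4.0, device='cuda')]
clip_coef_tensor = torch.tensor(1.0, dtype=torch.float32, device='cuda')

def clip_tensors(tensor_list, coef):
    for t in tensor_list:
        t.detach().mul_(coef)

# Warm up on a side stream, then capture the scaling kernels once.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    clip_tensors(grads, clip_coef_tensor)
torch.cuda.current_stream().wait_stream(side_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    clip_tensors(grads, clip_coef_tensor)

clip_coef_tensor.copy_(torch.tensor(0.5))  # stage this step's coefficient
graph.replay()                             # same kernels, new scale factor
```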

