Optimizing data parallel Fuse-Allreduce-Overlapping #48092

Merged: 58 commits, Nov 29, 2022

Changes from all commits
Commits (58)
d73a5eb
add depend
JZ-LIANG Nov 9, 2022
6414f04
add depend
JZ-LIANG Nov 9, 2022
9565b84
add depend
JZ-LIANG Nov 9, 2022
ca8696c
add depend
JZ-LIANG Nov 9, 2022
7234931
add depend
JZ-LIANG Nov 9, 2022
e711b4f
add depend
JZ-LIANG Nov 9, 2022
075eabf
add depend
JZ-LIANG Nov 9, 2022
11e394c
add depend
JZ-LIANG Nov 9, 2022
5c050ca
add depend
JZ-LIANG Nov 9, 2022
eb13147
add depend
JZ-LIANG Nov 9, 2022
c7ca20f
add depend
JZ-LIANG Nov 9, 2022
d29a95c
add depend
JZ-LIANG Nov 9, 2022
e96a93d
add depend
JZ-LIANG Nov 9, 2022
79b2e77
add depend
JZ-LIANG Nov 9, 2022
29a60ab
add origin amp files
JZ-LIANG Nov 10, 2022
b8f3f69
fp16 distinguish None & False
JZ-LIANG Nov 10, 2022
f07fd8a
engine log
JZ-LIANG Nov 14, 2022
e6b1995
engine log
JZ-LIANG Nov 14, 2022
8b4f299
engine log
JZ-LIANG Nov 14, 2022
4ea87d6
engine log
JZ-LIANG Nov 14, 2022
97d490f
engine log
JZ-LIANG Nov 14, 2022
187e8a9
Merge remote-tracking branch 'upstream/develop' into AutoParallel/new…
JZ-LIANG Nov 14, 2022
7149e09
engine log
JZ-LIANG Nov 14, 2022
d994cb8
Merge remote-tracking branch 'upstream/develop' into AutoParallel/new…
JZ-LIANG Nov 14, 2022
34d09b0
log
JZ-LIANG Nov 14, 2022
0a09420
log
JZ-LIANG Nov 14, 2022
0e6a1f6
profile
JZ-LIANG Nov 16, 2022
b6c097b
issued order comm first calc later
JZ-LIANG Nov 16, 2022
6963024
disable comm op seq dep
JZ-LIANG Nov 16, 2022
f0aab8c
dp add deps for graph exe
JZ-LIANG Nov 17, 2022
e80f3e3
dp add deps for graph exe
JZ-LIANG Nov 17, 2022
1a682f1
dp add deps for graph exe
JZ-LIANG Nov 17, 2022
8057858
bugfix in recompute
JZ-LIANG Nov 17, 2022
d0cb4a5
bugfix in recompute
JZ-LIANG Nov 18, 2022
ae44626
bugfix
JZ-LIANG Nov 18, 2022
5f967ce
bugfix
JZ-LIANG Nov 18, 2022
3d96184
bugfix
JZ-LIANG Nov 18, 2022
e0592e0
bugfix
JZ-LIANG Nov 18, 2022
48912f3
bugfix
JZ-LIANG Nov 18, 2022
8c99fb4
bugfix
JZ-LIANG Nov 18, 2022
809c27b
bugfix
JZ-LIANG Nov 18, 2022
aefff08
bugfix
JZ-LIANG Nov 18, 2022
22fccd9
bugfix
JZ-LIANG Nov 18, 2022
bd0483d
bugfix
JZ-LIANG Nov 18, 2022
eefe981
bugfix
JZ-LIANG Nov 18, 2022
14fdb51
bugfix
JZ-LIANG Nov 18, 2022
80c53e0
bugfix
JZ-LIANG Nov 18, 2022
4cec50d
add deps for clip
JZ-LIANG Nov 18, 2022
927079f
add deps for clip
JZ-LIANG Nov 18, 2022
d13c3cc
add deps for clip
JZ-LIANG Nov 18, 2022
53ca0e3
add dep for grad clip
JZ-LIANG Nov 21, 2022
5e8df91
add dep for grad clip
JZ-LIANG Nov 21, 2022
00d8e3b
local
JZ-LIANG Nov 22, 2022
6d5d25b
clean code
JZ-LIANG Nov 22, 2022
5e774ee
Merge remote-tracking branch 'upstream/develop' into AutoParallel/new…
JZ-LIANG Nov 22, 2022
aa846e2
dep ops in comm stream
JZ-LIANG Nov 28, 2022
57e0132
Merge remote-tracking branch 'upstream/develop' into AutoParallel/new…
JZ-LIANG Nov 28, 2022
b9b8755
unitest
JZ-LIANG Nov 28, 2022
70 changes: 60 additions & 10 deletions python/paddle/distributed/auto_parallel/utils.py
@@ -1410,6 +1410,9 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
def naive_set_dist_op_attr_for_program_by_mesh(
new_op, process_mesh, ctx, is_recompute=False
):
# hack to skip coalesce var for dist attr
if not is_recompute:
return
assert process_mesh is not None

new_op_dist_attr = OperatorDistributedAttribute()
@@ -2129,13 +2132,13 @@ def insert_dependencies_for_two_ops(
block,
idx,
prior_op,
posterior,
posterior_op,
dist_context,
is_recompute=False,
sync=False,
):
"""
dependency: prior_op should be run before posterior
dependency: prior_op should be run before posterior_op
"""

assert (
@@ -2144,15 +2147,15 @@
str(prior_op)
)
assert (
len(posterior.input_arg_names) >= 1
len(posterior_op.input_arg_names) >= 1
), "second op of dependency should at least have one input. [{}]".format(
str(posterior)
str(posterior_op)
)
prior_op_mesh = dist_context.get_op_dist_attr_for_program(
prior_op
).process_mesh
posterior_mesh = dist_context.get_op_dist_attr_for_program(
posterior
posterior_op
).process_mesh
assert (
prior_op_mesh == posterior_mesh
@@ -2171,25 +2174,72 @@ def _select_best_depend_var(vars):
[block.var(name) for name in prior_op.output_arg_names]
)
second_var = _select_best_depend_var(
[block.var(name) for name in posterior.input_arg_names]
[block.var(name) for name in posterior_op.input_arg_names]
)

return insert_dependencies_for_two_vars(
block,
idx,
first_var,
second_var,
dist_context,
OpRole.Backward,
prior_op_mesh,
is_recompute,
sync,
)


def insert_dependencies_for_two_vars(
block,
idx,
prior_var,
post_var,
dist_context,
oprole,
process_mesh=None,
is_recompute=False,
sync=False,
):
"""
dependency: op that generates prior_var should be run before op that generates post_var
"""
assert block.has_var(prior_var.name)
assert block.has_var(post_var.name)
if process_mesh is None:
process_mesh = dist_context.get_tensor_dist_attr_for_program(
post_var
).process_mesh
assert process_mesh is not None

depend_op = block._insert_op_without_sync(
idx,
type='nop',
inputs={
"X": first_var,
"X": prior_var,
},
outputs={"Out": second_var},
outputs={"Out": post_var},
)
# depend_op.desc.set_type("depend")
depend_op._set_attr(OP_ROLE_KEY, OpRole.Backward)
depend_op._set_attr(OP_ROLE_KEY, oprole)
# depend_op.desc.set_input("Dep", [first_var.name])
# self.desc.set_output(out_proto.name, out_arg_names)

naive_set_dist_op_attr_for_program_by_mesh(
depend_op, prior_op_mesh, dist_context, is_recompute
depend_op, process_mesh, dist_context, is_recompute
)

if sync:
block._sync_with_cpp()

return depend_op


def use_standalone_executor():
return os.environ.get('FLAGS_CONVERT_GRAPH_TO_PROGRAM', None) in [
1,
'1',
True,
'True',
'true',
]
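For orientation, the snippet below is a minimal usage sketch of the new insert_dependencies_for_two_vars helper; it is not part of this PR's diff. It assumes a Paddle build that already includes this patch, and the tiny program, the insertion index, and the OpRole import path are illustrative assumptions rather than anything mandated by the PR.

import paddle
from paddle.distributed.auto_parallel.dist_context import (
    get_default_distributed_context,
)
from paddle.distributed.auto_parallel.utils import (
    insert_dependencies_for_two_vars,
    use_standalone_executor,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    y = paddle.matmul(x, x, transpose_y=True)

block = main_prog.global_block()
dist_ctx = get_default_distributed_context()

# Insert a 'nop' op that reads x and writes y; this extra read/write edge is
# what the pass relies on to make the executor order x's producer before any
# op scheduled after the nop that touches y.
dep_op = insert_dependencies_for_two_vars(
    block,
    len(block.ops),        # insertion index, illustrative
    block.var(x.name),     # prior_var
    block.var(y.name),     # post_var
    dist_ctx,
    OpRole.Backward,
    process_mesh=[-1],     # same trick as the pass: skip real dist attr init
    is_recompute=False,
    sync=True,             # sync the Python block with the C++ desc
)

# Under the standalone executor the dependency op can be pinned to a side
# stream, mirroring what the data parallel pass does for gradient sync.
if use_standalone_executor():
    dep_op.dist_attr.execution_stream = "gradient_sync_stream"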
Changes to the data parallel optimization pass
@@ -27,8 +27,11 @@
find_higher_order_backward_op,
is_loss_grad_op,
is_optimize_op,
is_forward_op,
ring_id_to_process_group,
get_var_numel,
use_standalone_executor,
insert_dependencies_for_two_vars,
)

# add new optimizers supporting rescale_grad here
@@ -87,16 +90,20 @@ def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
self.global_rank = int(self.get_attr("global_rank"))
self.use_sharding = self.get_attr("use_sharding")
self.coalesce_prefix = 'coalesce_grad'
if use_standalone_executor():
self.gradient_sync_stream = "gradient_sync_stream"

with paddle.static.program_guard(main_program, startup_program):
self._analyze_program()

# TODO refactor here to first fuse then overlap
if self.is_data_parallel_applied():
self._prune_grad_scaling()
self._calc_comm_overlap()
grad_group = self._fuse_allreduce()

# self.summary(grad_group)
self._add_dependencies(grad_group)
self.summary(grad_group)

def _prune_grad_scaling(self):

@@ -284,7 +291,6 @@ def _comms_overlap_calc(self):
# InterpreterCore has a different logic for overlapping
# which is different from use_calc_stream
block = default_main_program().global_block()
ops = block.ops

# comm wait calc to finish
for idx, op in reversed(list(enumerate(block.ops))):
@@ -294,7 +300,6 @@

op._set_attr('use_calc_stream', False)
ring_id = op.attr("ring_id")

block._insert_op_without_sync(
idx,
type='c_wait_compute',
@@ -307,8 +312,10 @@

def _calc_wait_comms(self):

if use_standalone_executor():
return

block = default_main_program().global_block()
ops = block.ops

# NOTE the naive overlap implement in static hybird parallel only sync comm stream
# at the end of Backward phase, based on a strong constraint that
@@ -325,7 +332,7 @@
ring_id_to_un_sync_grad_map[group.id] = []

# analyze the where need to sync
for i, op in enumerate(ops):
for i, op in enumerate(block.ops):
if is_data_parallel_reduce_op(op):
ring_id = op.attr("ring_id")
grad_name = op.output_arg_names[0]
Expand Down Expand Up @@ -365,6 +372,7 @@ def _calc_wait_comms(self):
outputs={'Out': []},
attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
)
block._sync_with_cpp()

def _could_be_fuse(self):
# TODO support gradient fuse higher order gradient.
@@ -404,8 +412,6 @@ def _group_grads(self):
def collect_group(cur_group, grad_var, ring_id, i):
if len(cur_group.gradients) == 0:
cur_group = None
elif len(cur_group.gradients) == 1:
grouped_grad_names.remove(cur_group.gradients[0].name)
else:
cur_group.finalize()
grad_groups.append(cur_group)
@@ -451,9 +457,16 @@ def _update_program(self, grad_groups):

for i, group in enumerate(grad_groups[::-1]):

# skip unfused big tensor
if len(group.gradients) <= 1:
group.coalesce_var = group.gradients[0]
continue

# create coalecse tensor
group.coalesce_var = block.create_var(
name=unique_name.generate('coalecse_grad_{}'.format(i)),
name=unique_name.generate(
self.coalesce_prefix + '_{}'.format(i)
),
dtype=group.dtype,
persistable=False,
stop_gradient=True,
@@ -497,7 +510,7 @@ def _update_program(self, grad_groups):
), "Unexception: try to remove op {}".format(
str(block.ops[idx])
)
block._remove_op(idx)
block._remove_op(idx, False)

# insert coalecse op
concated_shapes = []
Expand Down Expand Up @@ -529,6 +542,141 @@ def _update_program(self, grad_groups):
block._sync_with_cpp()
# TODO update dist attr

def _add_dependencies(self, grad_groups):
# NOTE Currently, auto_parallel has to support two executors: the sequential executor (old exe) and the graph-based
# multi-stream executor (standalone exe). This function is only for the standalone exe. Refactor here
# in the future when only one executor remains.

if not use_standalone_executor() or len(grad_groups) == 0:
return
block = default_main_program().global_block()

# Build maps
vars_to_coalesce_map = {}
coalesce_to_vars_map = {}

for group in grad_groups:
grad_names = []
coalesce_name = group.coalesce_var.name
for grad in group.gradients:
vars_to_coalesce_map[grad.name] = coalesce_name
grad_names.append(grad.name)
coalesce_to_vars_map[coalesce_name] = grad_names

# analyze dependencies
# Record ONLY the last grad generated before the allreduce
# NOTE needs to be updated once multiple calc streams are allowed for backward computation
not_sync_coalesces = []
prior_allreduce_deps = {}
for idx, op in reversed(list(enumerate(block.ops))):
if is_forward_op(op):
break
if is_optimize_op(op):
continue

if is_data_parallel_reduce_op(op):
coalesce_var_name = op.output_arg_names[0]

# NOTE only add extra deps for the fused tensor; other tensors rely on
# the executor's data-flow analysis.
if self.coalesce_prefix in coalesce_var_name:
prior_allreduce_deps[coalesce_var_name] = [
idx,
None,
coalesce_var_name,
]
not_sync_coalesces.append(coalesce_var_name)
continue

for out_name in op.output_arg_names:
var_name = vars_to_coalesce_map.get(out_name, None)
if var_name in not_sync_coalesces:
prior_allreduce_deps[var_name][1] = out_name
not_sync_coalesces.remove(var_name)
assert (
len(not_sync_coalesces) == 0
), "Unexception: {} has NOT been add prior Dep before allreduce.".format(
not_sync_coalesces
)

# Record ONLY the first grad used after the allreduce
# NOTE needs to be updated once multiple calc streams are allowed for backward computation
not_sync_coalesces = []
post_allreduce_deps = {}
for idx, op in enumerate(block.ops):
if is_forward_op(op):
continue

if is_data_parallel_reduce_op(op):
coalesce_var_name = op.input_arg_names[0]
if self.coalesce_prefix in coalesce_var_name:
post_allreduce_deps[coalesce_var_name] = [
None,
coalesce_var_name,
None,
]
not_sync_coalesces.append(coalesce_var_name)
continue

for out_name in op.input_arg_names:
var_name = vars_to_coalesce_map.get(out_name, None)
if var_name in not_sync_coalesces:
post_allreduce_deps[var_name][0] = idx
post_allreduce_deps[var_name][2] = out_name
not_sync_coalesces.remove(var_name)

assert (
len(not_sync_coalesces) == 0
), "Unexception: {} has NOT been add post Dep after allreduce.".format(
not_sync_coalesces
)

# Update program IR: insert dependency ops
dep_var_pairs = []
for deps in [prior_allreduce_deps, post_allreduce_deps]:
for pair in deps.values():
dep_var_pairs.append(pair)

dep_var_pairs.sort(key=lambda x: x[0], reverse=True)
for idx, prior_name, post_name in dep_var_pairs:
prior_var = block.var(prior_name)
post_var = block.var(post_name)
depend_op = insert_dependencies_for_two_vars(
block,
idx,
prior_var,
post_var,
self.dist_context,
OpRole.Backward,
process_mesh=[
-1
],  # hack to avoid initializing the dist attr for the coalesce var
is_recompute=False,
sync=False,
)
depend_op.dist_attr.execution_stream = self.gradient_sync_stream
block._sync_with_cpp()

# remove naive synchronization & assign allreduce stream
def remove_cond(op):
if op.type != "c_wait_compute":
return False
if len(op.input_arg_names) != 0:
return False
if len(op.output_arg_names) != 0:
return False
return True

for idx, op in reversed(list(enumerate(block.ops))):
if is_data_parallel_reduce_op(op):
op._set_attr('use_calc_stream', True)
op.dist_attr.execution_stream = self.gradient_sync_stream

if remove_cond(op):
block._remove_op(idx, sync=False)

block._sync_with_cpp()

def summary(self, grad_groups=[]):
# TODO: add logger module
import logging
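As a closing note, the use_standalone_executor() switch added to utils.py is driven entirely by an environment flag; the check below is an illustrative sketch built from the values that function accepts, not something prescribed by the PR.

import os

# Assumption: the flag has to be set before the auto-parallel passes run so
# that they take the standalone-executor (graph) code paths shown above.
os.environ["FLAGS_CONVERT_GRAPH_TO_PROGRAM"] = "1"

from paddle.distributed.auto_parallel.utils import use_standalone_executor

# use_standalone_executor() returns True for '1', 'True' or 'true'
# (environment values are always strings).
assert use_standalone_executor()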