[Auto Parallel] Add c_concat pass for reshard #47809

Merged: 2 commits, Nov 10, 2022
176 changes: 141 additions & 35 deletions python/paddle/distributed/auto_parallel/reshard.py
@@ -92,6 +92,42 @@ def __repr__(self):
return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}."


class AllGatherConcatOpDesc:
"""
Describe the c_concat op in the reshard phase.

Args:
group (list): Process group.
shape (list): The tensor shape.
is_bool (bool): Whether the data to be concatenated is bool. Default: False.
"""

def __init__(self, group, shape, is_bool=False):
self._group = group
self._desc = "c_concat"
self._shape = shape
self._is_bool = is_bool

@property
def is_bool(self):
return self._is_bool

@property
def group(self):
return self._group

@property
def desc(self):
return self._desc

@property
def shape(self):
return self._shape

def __repr__(self):
return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}."


class SendOpDesc:
"""
Describe the send op in the reshard phase.
@@ -640,6 +676,46 @@ def insert_allgather_op(block, idx, tensor, ranks, op_role):
tensor_list.extend(split_out)
return tensor_list, idx_offset

@staticmethod
def insert_c_concat_op(block, idx, tensor, ranks, op_role):
"""Insert c_concat op into block at the given index."""
group = new_process_group(ranks)
idx_offset = 0

# insert c_concat op
op_type = 'c_concat'
# to avoid name conflict with framework
helper = LayerHelper(op_type + "@RESHARD", **locals())
with paddle.static.program_guard(block.program):
c_concat_out = block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])
),
dtype=tensor.dtype,
shape=None,
lod_level=tensor.lod_level,
type=tensor.type,
persistable=False,
stop_gradient=False,
)
cur_rank = paddle.distributed.get_rank()
c_concat_op = block._insert_op(
idx + idx_offset,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [c_concat_out]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
'nranks': group.nranks,
'op_role': op_role,
'rank': group.ranks.index(cur_rank) if cur_rank in ranks else 0,
},
)
c_concat_op._set_attr('op_namescope', "/auto_parallel/reshard")
return c_concat_out
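
Semantically, the inserted c_concat behaves like an allgather along the tensor's last axis: every rank in the group contributes its local shard and receives the full tensor. A rough single-process reference of that behavior in NumPy (illustrative only; the real op runs collectively across the ranks of the process group):

import numpy as np

def c_concat_reference(shards):
    # Reference semantics: concatenate the per-rank shards along the last axis.
    return np.concatenate(shards, axis=-1)

# Two ranks each hold a [4, 2] shard of a [4, 4] tensor sharded on dim 1.
full = np.arange(16, dtype=np.float32).reshape(4, 4)
shards = [full[:, :2], full[:, 2:]]
assert np.array_equal(c_concat_reference(shards), full)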

@staticmethod
def concat_partitions_with_op(
partition_tensor_list, tensor, partition_index, block, idx, op_role
@@ -1535,7 +1611,7 @@ def find_op_desc_seq(self, dist_tensor, dist_attr, serial=False):
)
)

# In the same process group, it will use allgather and slice op.
else:
# NOTE: It just supports even partition scene.
partition_index_list = []
@@ -1599,21 +1675,37 @@ def find_op_desc_seq(self, dist_tensor, dist_attr, serial=False)
if not serial
else dist_tensor.local_sizes(rank=process)
)
op_desc_seq[process] = (
[
AllGatherOpDesc(
group=group,
shape=allgather_shape,
is_bool=(source_tensor.dtype == paddle.bool),
),
ConcatOpDesc(
partition_index_list=all_partition_index_list
),
slice_op_desc,
# c_concat pass: target fully replicated, source sharded only on its last dim
if (
target_dims_mapping.count(-1)
== len(target_dims_mapping)
and source_dims_mapping[:-1].count(-1)
== len(source_dims_mapping[:-1])
and source_dims_mapping[-1] != -1
):
op_desc_seq[process] = [
AllGatherConcatOpDesc(
group=group, shape=allgather_shape
)
]
if len(group) > 1
else [slice_op_desc]
)
else:
op_desc_seq[process] = (
[
AllGatherOpDesc(
group=group,
shape=allgather_shape,
is_bool=(
source_tensor.dtype == paddle.bool
),
),
ConcatOpDesc(
partition_index_list=all_partition_index_list
),
slice_op_desc,
]
if len(group) > 1
else [slice_op_desc]
)

return op_desc_seq
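
In plain terms, the new fast path is taken only when the target distribution is fully replicated and the source is sharded on nothing but its last dimension; every other case keeps the allgather + concat + slice sequence. A standalone predicate that mirrors the branch condition above (the function name is illustrative, not part of the module):

def use_c_concat_pass(source_dims_mapping, target_dims_mapping):
    # Target fully replicated and source sharded only on its last dimension.
    return (
        target_dims_mapping.count(-1) == len(target_dims_mapping)
        and source_dims_mapping[:-1].count(-1) == len(source_dims_mapping[:-1])
        and source_dims_mapping[-1] != -1
    )

assert use_c_concat_pass([-1, 0], [-1, -1])      # sharded on last dim -> c_concat
assert not use_c_concat_pass([0, -1], [-1, -1])  # sharded on dim 0 -> fallback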

@@ -1850,27 +1942,41 @@ def parse_op_desc(
)
idx = idx_list[0]

elif isinstance(op_desc, SliceOpDesc):
assert (
len(partition_tensor_list) == 1 or not partition_tensor_list
)
to_slice_tensor = (
partition_tensor_list[0][0]
if len(partition_tensor_list) == 1
else source_tensor
)
new_name = unique_name.generate(var_name + "@RESHARD")
target_tensor = Inserter.insert_slice_op(
block,
idx,
to_slice_tensor,
starts=op_desc.starts,
ends=op_desc.ends,
axes=op_desc.axes,
new_var_name=new_name,
op_role=reshard_op.attr('op_role'),
)
elif isinstance(op_desc, SliceOpDesc) or isinstance(
op_desc, AllGatherConcatOpDesc
):
target_tensor = None
if isinstance(op_desc, SliceOpDesc):
assert (
len(partition_tensor_list) == 1
or not partition_tensor_list
)
to_slice_tensor = (
partition_tensor_list[0][0]
if len(partition_tensor_list) == 1
else source_tensor
)
new_name = unique_name.generate(var_name + "@RESHARD")
target_tensor = Inserter.insert_slice_op(
block,
idx,
to_slice_tensor,
starts=op_desc.starts,
ends=op_desc.ends,
axes=op_desc.axes,
new_var_name=new_name,
op_role=reshard_op.attr('op_role'),
)
else:
target_tensor = Inserter.insert_c_concat_op(
block,
idx,
source_tensor,
op_desc.group,
reshard_op.attr('op_role'),
)

assert target_tensor is not None
process_mesh = dist_attr[0]
dims_mapping = dist_attr[1]

@@ -304,6 +304,60 @@ def test_allgather(self):
# the x should not be sliced
self.assertTrue(check_allgather(partitioned_main_prog))

def test_c_concat(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
with static.program_guard(train_program, startup_program):
x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
x = auto.shard_tensor(x, process_mesh, [None, "x"])
w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
w = auto.shard_tensor(w, process_mesh, [None, None])

y = paddle.distributed.shard_op(
paddle.matmul, process_mesh, [[None, None], [None, None]]
)(x, w)

rank_id = 0
dist_context = DistributedContext()
dist_strategy = fleet.DistributedStrategy()
partitioner = Partitioner(dist_context, rank_id)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program
)
dist_context.block_state.parse_forward_blocks(complete_train_program)
(
partitioned_main_prog,
partitioned_startup_prog,
partitioned_params_grads,
) = partitioner.partition(complete_train_program, startup_program, [])

# test estimator
cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2)
cost_estimator = CostEstimator(train_program, cluster)
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(
dist_context
)
# test cache
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(
dist_context
)
assert global_cost.time >= 0
assert max_memory > 0

resharder = Resharder(
partitioned_main_prog,
partitioned_startup_prog,
rank_id,
dist_context,
partitioned_params_grads,
)
resharder.reshard()
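
The new test drives resharder.reshard() and the cost estimator but does not assert on the resulting program; if a structural check in the spirit of check_allgather above is wanted, a helper along these lines could scan the partitioned program for the inserted op (check_c_concat is hypothetical and not part of the test file):

def check_c_concat(prog):
    # Hypothetical helper: True if the reshard pass inserted a c_concat op.
    return any(op.type == "c_concat" for op in prog.global_block().ops)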


if __name__ == "__main__":
unittest.main()