[Semi-Auto] Update merging tensor dims_mapping part in update_op_node_dims_mapping function (PaddlePaddle#57008)

* merge redundant code about merging dims_mapping in update_op_node_dims_mapping function

* fix the bug of importing paddle.fluid
pkuzyc authored Sep 12, 2023
1 parent 5c428db commit f98a3c0
Showing 1 changed file with 84 additions and 139 deletions.
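
The bulk of the diff is a classic deduplication: before this commit, the forward pass looped over op_node.inputs and the backward pass looped over op_node.outputs with near-identical bodies, and the fix hoists every direction-dependent choice to the top of a single shared loop. Below is a minimal, self-contained sketch of that pattern; all names are illustrative stand-ins, not the real PaddlePaddle APIs.

def merge_dims_mappings(op, fwd=True):
    # Hoist the direction-dependent pieces once, then share the loop body.
    tensors = op["inputs"] if fwd else op["outputs"]
    mappings = op["input_mappings"] if fwd else op["output_mappings"]
    changed = False
    for name, tensor_mapping in tensors:
        op_mapping = mappings[name]
        # Toy merge rule: -1 means "replicated", so the concrete side wins.
        merged = [t if o == -1 else o for o, t in zip(op_mapping, tensor_mapping)]
        if merged != op_mapping:
            mappings[name] = merged
            changed = True
    return changed

op = {
    "inputs": [("x", [0, -1])],
    "outputs": [("y", [-1, 1])],
    "input_mappings": {"x": [-1, -1]},
    "output_mappings": {"y": [-1, -1]},
}
print(merge_dims_mappings(op, fwd=True))   # True; x's mapping becomes [0, -1]
print(merge_dims_mappings(op, fwd=False))  # True; y's mapping becomes [-1, 1]
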
223 changes: 84 additions & 139 deletions python/paddle/distributed/auto_parallel/static/completion.py
@@ -219,150 +219,95 @@ def _update_op_node_dims_mapping(self, op_node, fwd=True):
The merged implementation (the 84 added lines; the two near-duplicate branches that handled op_node.inputs and op_node.outputs separately account for the 139 deletions):

dist_op = self._dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
original_op_dist_attr = copy.deepcopy(op_dist_attr)

# step 1: merge the dims mappings of dist_op with the corresponding tensors
if fwd:
    node_list = op_node.inputs
else:
    node_list = op_node.outputs
for tensor_node in node_list:
    if not tensor_node.is_var() or tensor_node.var() is None:
        continue
    if tensor_node.var().type() == core.VarDesc.VarType.READER:
        continue

    tensor_desc = tensor_node.var()
    if fwd:
        annotated = op_dist_attr.is_annotated_input_dims_mapping(
            tensor_desc.name()
        )
    else:
        annotated = op_dist_attr.is_annotated_output_dims_mapping(
            tensor_desc.name()
        )
    if annotated:
        continue

    tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
        tensor_node
    )
    if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
        tensor_dims_mapping = tensor_dist_attr.dims_mapping
        if fwd:
            op_dims_mapping = op_dist_attr.get_input_dims_mapping(
                tensor_desc.name()
            )
        else:
            op_dims_mapping = op_dist_attr.get_output_dims_mapping(
                tensor_desc.name()
            )

        compatible_dims_mapping = compute_compatible_dims_mapping(
            [op_dims_mapping, tensor_dims_mapping]
        )
        if not _validate_dims_mapping(
            compatible_dims_mapping, op_dist_attr.process_mesh
        ):
            continue
        if (compatible_dims_mapping is not None) and (
            compatible_dims_mapping != op_dims_mapping
        ):
            if fwd:
                op_dist_attr.set_input_dims_mapping(
                    tensor_desc.name(), compatible_dims_mapping
                )
            else:
                op_dist_attr.set_output_dims_mapping(
                    tensor_desc.name(), compatible_dims_mapping
                )
            changed = True

# step 2: infer distributed attributes in dist_op
# Find the most compatible implementations from the distributed operator
op_dist_impls = find_compatible_distributed_operator_impls(dist_op, fwd=True)
if op_dist_impls is not None:
    not_compatible = True
    backup_op_dist_attr = copy.deepcopy(op_dist_attr)
    backup_changed = changed
    for op_dist_impl in op_dist_impls:
        dim_changed = op_dist_impl.update_dims_mapping(dist_op)
        if dim_changed:
            changed = True
        if (
            op_dist_impl.is_auto_compatible(dist_op)
            and dist_op.validate_dist_attr()
        ):
            op_dist_attr.impl_type = op_dist_impl.type
            op_dist_attr.impl_idx = op_dist_impl.idx
            not_compatible = False
            break
        else:
            dist_op.dist_attr = backup_op_dist_attr
            changed = backup_changed
    if not_compatible:
        dist_op.dist_attr = original_op_dist_attr
        changed = False
else:
    dist_op.dist_attr = original_op_dist_attr
    changed = False

return changed
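
For context on step 1: compute_compatible_dims_mapping merges a list of dims mappings element-wise. Judging from its use here, -1 (replicated) yields to a concrete mesh axis, while two different concrete axes cannot be reconciled. A toy re-implementation under that assumption; this is a sketch, not the actual Paddle helper.

def compatible_dim(dims):
    # -1 means "replicated"; any concrete mesh axis wins over -1,
    # but two different concrete axes cannot be reconciled.
    merged = -1
    for d in dims:
        if d == -1:
            continue
        if merged != -1 and merged != d:
            return None  # conflict
        merged = d
    return merged

def compatible_dims_mapping(mappings):
    if any(m is None for m in mappings):
        return None
    if len({len(m) for m in mappings}) != 1:
        return None  # tensor ranks must match
    merged = [compatible_dim(dims) for dims in zip(*mappings)]
    return None if None in merged else merged

print(compatible_dims_mapping([[-1, 0], [-1, -1]]))  # [-1, 0]
print(compatible_dims_mapping([[1, 0], [0, 0]]))     # None (axis conflict)
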

def _update_dims_mapping_between_graphs(self):
@@ -1924,10 +1869,10 @@ def _init_global_mesh_for_program(self):
With this commit, elementwise implementations keep their own impl_type instead of being remapped to "default" (the old special case is left commented out):

for op_dist_impl in op_dist_impls:
    dim_changed = op_dist_impl.update_dims_mapping(dist_op)
    if op_dist_impl.is_auto_compatible(dist_op):
        # if op_dist_impl.type == "elementwise":
        #     dist_op.dist_attr.impl_type = "default"
        # else:
        dist_op.dist_attr.impl_type = op_dist_impl.type
        # op_dist_attr.impl_type = op_dist_impl.type
        dist_op.dist_attr.impl_idx = op_dist_impl.idx
        break
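
Both hunks revolve around the same selection idiom: try each candidate distributed-operator implementation in order, keep the first compatible one, and roll back any mutated attributes on failure, falling back to the original attributes if nothing fits. A stripped-down sketch of that backup-and-rollback control flow, using plain dicts as stand-ins for the real DistributedOperator objects:

import copy

def pick_impl(dist_op, impls, original_attr):
    backup = copy.deepcopy(dist_op["attr"])
    for impl in impls:
        impl["update"](dist_op)                      # may mutate dist_op["attr"]
        if impl["compatible"](dist_op):
            dist_op["attr"]["impl_type"] = impl["type"]
            return True                              # first fit wins
        dist_op["attr"] = copy.deepcopy(backup)      # undo this attempt
    dist_op["attr"] = copy.deepcopy(original_attr)   # nothing fit: full restore
    return False

impls = [
    {"type": "matmul_v2",
     "update": lambda d: d["attr"].update(tried=True),
     "compatible": lambda d: d["attr"]["mesh"] == [0, 1]},
]
dist_op = {"attr": {"mesh": [0, 1]}}
print(pick_impl(dist_op, impls, dist_op["attr"]))  # True
print(dist_op["attr"]["impl_type"])                # matmul_v2
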
