[Semi Auto] Refactor Completion Mechanism (Part1) #57447

Merged: 21 commits, Sep 26, 2023
9 changes: 9 additions & 0 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -543,6 +543,15 @@ void BindAutoParallel(py::module *m) {
py::arg("memo"))
.def("__str__", &OperatorDistAttr::to_string);

m->def(
"contains_spmd_rule",
[](const std::string op_type) {
return phi::distributed::SpmdRuleFactory::Instance().ContainsSpmdRule(
op_type) ||
SPMDRuleMap::Instance().Has(op_type); // TODO(ljz): unify here
},
py::return_value_policy::reference);

m->def(
"get_spmd_rule",
[](const std::string op_type) {
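The new contains_spmd_rule binding lets the Python-side completer ask whether an op type has any SPMD rule registered (in the new SpmdRuleFactory or the legacy SPMDRuleMap) before taking the new dispatch path. A minimal sketch of that check, assuming only the bindings added in this file; the adapted_ops allowlist is illustrative and mirrors the gate added in completion.py later in this PR:

# Sketch only: mirrors the gating logic added in completion.py in this PR.
import os

from paddle.base.core import contains_spmd_rule


def use_new_infer_spmd(op_type, adapted_ops):
    # FLAGS_infer_spmd_enable defaults to enabled when the env var is unset;
    # any value other than "true" (case-insensitive) disables the new path.
    enable = os.getenv("FLAGS_infer_spmd_enable", "true").lower() == "true"
    # Dispatch to the new InferSPMD path only if some rule is registered for
    # the op and the op has already been adapted.
    return enable and contains_spmd_rule(op_type) and op_type in adapted_ops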
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
@@ -331,6 +331,10 @@ PD_REGISTER_SPMD_RULE(
trunc,
PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd),
PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse));
PD_REGISTER_SPMD_RULE(
dropout,
PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd),
PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse));

// elementwise binary rule
PD_REGISTER_SPMD_RULE(
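With this registration, dropout resolves to the elementwise-unary SPMD rules. A quick sketch, assuming the pybind bindings from this PR, of checking that the rule is visible from Python:

# Sketch: confirm the dropout rule is discoverable through the new bindings.
from paddle.base.core import contains_spmd_rule, get_phi_spmd_rule

assert contains_spmd_rule("dropout")   # True once this registration is built
rule = get_phi_spmd_rule("dropout")    # the registered elementwise-unary rule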
166 changes: 123 additions & 43 deletions python/paddle/distributed/auto_parallel/static/completion.py
@@ -14,23 +14,41 @@

import copy
import logging
import os

from paddle.base.core import get_spmd_rule # noqa: F401
from paddle.base.core import ( # noqa: F401
contains_spmd_rule,
get_phi_spmd_rule,
get_spmd_rule,
)
from paddle.base.log_helper import get_logger
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.framework import core

from ..process_mesh import ProcessMesh, compute_compatible_process_mesh
from .dist_attribute import OperatorDistAttr, TensorDistAttr
from .dist_context import _node_id
from .operators import find_compatible_distributed_operator_impls
from .operators import (
find_compatible_distributed_operator_impls,
find_distributed_operator_impl_container,
)
from .process_group import get_world_process_group
from .utils import (
__no_shape_var_type__,
get_logger,
is_gradient_clip_op,
is_naive_data_parallel,
)

_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
__skip_dims_mapping_op__ = [
"create_py_reader",
"create_double_buffer_reader",
"while",
"read",
]


def compute_compatible_dim_mapping(dim_mapping_list):
"""Compute the compatible dim mapping given a list of dim mapping."""
@@ -105,12 +123,61 @@ def _validate_dims_mapping(dims_mapping, process_mesh):
return True


def _can_apply_infer_spmd_rule(dist_op):
enable = os.getenv("FLAGS_infer_spmd_enable", True)
if isinstance(enable, str):
enable = enable.lower()
enable = True if enable == 'true' else False
enable = bool(enable)

# TODO remove me. ops to be adapted: lookup_table_v2, reshape2, split, transpose2,
__adapted_ops__ = [
"matmul_v2",
"elementwise_div",
"gelu",
"fused_softmax_mask_upper_triangle",
"elementwise_add",
"elementwise_mul",
"assign",
"scale",
"dropout",
"reduce_sum",
"layer_norm",
]
op_type = dist_op.serial_op.type
return enable and contains_spmd_rule(op_type) and op_type in __adapted_ops__


def _update_op_dims_mapping_and_distoperatorimpl(
dist_op, original_op_dist_attr, changed
):
dist_op_container = find_distributed_operator_impl_container(dist_op)
_logger.debug(
"Update Op [{}] using DistOpContainer [{}].".format(
dist_op.serial_op.type, dist_op_container.type
)
)
updated = dist_op_container.update_dims_mapping(dist_op)
changed = updated or changed
# TODO(ljz) remove the below code once we introduce general reshard to replace specifc distopimpls
reverted = dist_op_container.mapping_to_dist_operator_impl(
dist_op, original_op_dist_attr
)
_logger.debug(
"Op [{}] use dist op impl [{}] idx [{}].".format(
dist_op.serial_op.type,
dist_op.dist_attr.impl_type,
dist_op.dist_attr.impl_idx,
)
)
return changed and not (reverted)


class Completer:
def __init__(self, dist_context):
assert dist_context is not None
self._dist_context = dist_context
self._has_prepared = False
self._logger = get_logger(logging.INFO, "Completer")

def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
changed = False
@@ -205,22 +272,20 @@ def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):

def _update_op_node_dims_mapping(self, op_node, fwd=True):
changed = False
op_desc = op_node.op()

# step0: skip corner cases
if (not op_node.is_op()) or (op_node.op() is None):
return False
# Skip reader op
op_desc = op_node.op()
if (
op_desc.type() == "create_py_reader"
or op_desc.type() == "create_double_buffer_reader"
or op_desc.type() == "while"
or op_desc.type() == "read"
):
if op_desc.type() in __skip_dims_mapping_op__:
return False

dist_op = self._dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
original_op_dist_attr = copy.deepcopy(op_dist_attr)
# step 1: merge the dims mappings of in
# dist_op with corresponding tensors

# step 1: merge the dims mappings from tensor nodes to op nodes
if fwd:
node_list = op_node.inputs
else:
@@ -277,38 +342,53 @@ def _update_op_node_dims_mapping(self, op_node, fwd=True):
)
changed = True

# step 2: infer distributed attributes in dist_op
# Find the most compatible implementations from the distributed operator
op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=True
)
if op_dist_impls is not None:
not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr)
backup_changed = changed
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if (
op_dist_impl.is_auto_compatible(dist_op)
and dist_op.validate_dist_attr()
):
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
not_compatible = False
break
else:
dist_op.dist_attr = backup_op_dist_attr
changed = backup_changed
if not_compatible:
# step 2: Infer & Update dims mapping of op node using SPMD Rule.
if _can_apply_infer_spmd_rule(dist_op):
_logger.debug(
"Op [{}] update dims mapping using New InferSPMD Rule.".format(
dist_op.serial_op.type
)
)
return _update_op_dims_mapping_and_distoperatorimpl(
dist_op, original_op_dist_attr, changed
)
else:
_logger.debug(
"Op [{}] update dims mapping using Original DistOp Rule.".format(
dist_op.serial_op.type
)
)
# update_op_dims_mapping_v1()
op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=True
)
if op_dist_impls is not None:
not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr)
backup_changed = changed
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if (
op_dist_impl.is_auto_compatible(dist_op)
and dist_op.validate_dist_attr()
):
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
not_compatible = False
break
else:
dist_op.dist_attr = backup_op_dist_attr
changed = backup_changed
if not_compatible:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
dist_op.dist_attr = original_op_dist_attr
changed = False

return changed
return changed

def _update_dims_mapping_between_graphs(self):
changed = False
@@ -900,7 +980,7 @@ def complete_forward_annotation(self, serial_main_program=None):
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
else:
self._logger.info("Default distributed attributed will be set.")
_logger.info("Default distributed attributed will be set.")
self._dist_context.initialize(with_graph=False)
# A fast and special completion for data parallel
self._update_dist_attr_for_dp()
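The new dispatch in _update_op_node_dims_mapping is gated by the FLAGS_infer_spmd_enable environment variable read in _can_apply_infer_spmd_rule. A hedged sketch of forcing the completer back onto the original DistOp rules, set before completion runs:

# Sketch: opt out of the new InferSPMD dispatch; every op then falls back to
# the find_compatible_distributed_operator_impls() branch above.
import os

os.environ["FLAGS_infer_spmd_enable"] = "false"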
4 changes: 3 additions & 1 deletion python/paddle/distributed/auto_parallel/static/dist_op.py
@@ -58,8 +58,10 @@ def dist_attr(self, dist_attr):
def get_serial_input(self, name):
if self._serial_op.type == "create_py_reader":
tensor = None
else:
elif self._serial_op.block._find_var_recursive(name) is not None:
tensor = self._serial_op.block._var_recursive(name)
else:
tensor = None
return tensor

def get_serial_output(self, name):
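With this change, get_serial_input returns None (rather than raising) when the named variable cannot be found in the op's block hierarchy, matching the existing create_py_reader special case. A small sketch of the caller-side pattern this enables; the input name is hypothetical:

# Sketch: inputs that are not materialized in the block now come back as None.
tensor = dist_op.get_serial_input("hypothetical_input_name")
if tensor is None:
    # Skip or defer handling for inputs without a serial tensor.
    pass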
python/paddle/distributed/auto_parallel/static/operators/__init__.py
@@ -17,6 +17,8 @@
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import find_compatible_distributed_operator_impls
from .common import find_distributed_operator_impl_container

from . import dist_embedding
from . import dist_matmul
from . import dist_reshape
@@ -39,3 +41,4 @@
from . import dist_scale
from . import dist_dropout
from . import dist_flash_attn
from . import dist_layer_norm
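find_distributed_operator_impl_container, now exported here, is the entry point of the refactored flow: the container it returns supplies the update_dims_mapping and mapping_to_dist_operator_impl calls used in _update_op_dims_mapping_and_distoperatorimpl above. A condensed sketch of that flow, restating the completion.py code rather than introducing a new API:

# Condensed restatement of _update_op_dims_mapping_and_distoperatorimpl.
from paddle.distributed.auto_parallel.static.operators import (
    find_distributed_operator_impl_container,
)


def update_with_container(dist_op, original_op_dist_attr, changed):
    container = find_distributed_operator_impl_container(dist_op)
    # Infer and merge dims mappings via the op's InferSPMD rule.
    changed = container.update_dims_mapping(dist_op) or changed
    # Temporary bridge: map the inferred attrs onto an existing DistOpImpl
    # until a general reshard replaces the specific impls.
    reverted = container.mapping_to_dist_operator_impl(
        dist_op, original_op_dist_attr
    )
    return changed and not reverted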