[Semi Auto] Refactor Completion Mechanism (Part1) (PaddlePaddle#57447)
* first commit

* framework

* matmul done

* elementwise done

* adapt done

* polish code

* revise logging

* revise log

* update doc

* enable LN unittest

* precommit

* bugfix reduce_sum

* bugfix assign

* bugfix for print program

* enable rule for dropout

* bugfix for dist op
JZ-LIANG authored and jiahy0825 committed Oct 16, 2023
1 parent ff2fc84 commit c0781b1
Showing 18 changed files with 834 additions and 84 deletions.
9 changes: 9 additions & 0 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -543,6 +543,15 @@ void BindAutoParallel(py::module *m) {
py::arg("memo"))
.def("__str__", &OperatorDistAttr::to_string);

m->def(
"contains_spmd_rule",
[](const std::string op_type) {
return phi::distributed::SpmdRuleFactory::Instance().ContainsSpmdRule(
op_type) ||
SPMDRuleMap::Instance().Has(op_type); // TODO(ljz): unify here
},
py::return_value_policy::reference);

m->def(
"get_spmd_rule",
[](const std::string op_type) {
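For illustration only (not part of the diff): a minimal sketch of probing the new binding from Python, assuming it is re-exported through paddle.base.core exactly as the completion.py changes further down import it.

# Editor's sketch, not committed code.
from paddle.base.core import contains_spmd_rule

for op_type in ["matmul_v2", "dropout", "some_unregistered_op"]:
    # True if either the new phi SpmdRuleFactory or the legacy SPMDRuleMap
    # has a rule registered for this op type.
    print(op_type, contains_spmd_rule(op_type))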
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
@@ -333,6 +333,10 @@ PD_REGISTER_SPMD_RULE(
trunc,
PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd),
PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse));
PD_REGISTER_SPMD_RULE(
dropout,
PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd),
PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse));

// elementwise binary rule
PD_REGISTER_SPMD_RULE(
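To make concrete what registering dropout under the elementwise-unary rule means, here is a hedged, standalone illustration of the propagation semantics (a hypothetical helper, not Paddle's API): the output simply inherits the input's dims_mapping.

# Hypothetical helper, for illustration only; mirrors the semantics of
# ElementwiseUnaryInferSpmd: the output's dims_mapping equals the input's.
# (-1 = replicated along that tensor axis, i >= 0 = sharded on mesh dim i.)
def elementwise_unary_infer_spmd(input_dims_mapping):
    return list(input_dims_mapping)

print(elementwise_unary_infer_spmd([0, -1, -1]))  # [0, -1, -1]: sharding passes through dropout unchanged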
166 changes: 123 additions & 43 deletions python/paddle/distributed/auto_parallel/static/completion.py
@@ -14,23 +14,41 @@

import copy
import logging
import os

from paddle.base.core import get_spmd_rule # noqa: F401
from paddle.base.core import ( # noqa: F401
contains_spmd_rule,
get_phi_spmd_rule,
get_spmd_rule,
)
from paddle.base.log_helper import get_logger
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.framework import core

from ..process_mesh import ProcessMesh, compute_compatible_process_mesh
from .dist_attribute import OperatorDistAttr, TensorDistAttr
from .dist_context import _node_id
from .operators import find_compatible_distributed_operator_impls
from .operators import (
find_compatible_distributed_operator_impls,
find_distributed_operator_impl_container,
)
from .process_group import get_world_process_group
from .utils import (
__no_shape_var_type__,
get_logger,
is_gradient_clip_op,
is_naive_data_parallel,
)

_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
__skip_dims_mapping_op__ = [
"create_py_reader",
"create_double_buffer_reader",
"while",
"read",
]


def compute_compatible_dim_mapping(dim_mapping_list):
"""Compute the compatible dim mapping given a list of dim mapping."""
@@ -105,12 +123,61 @@ def _validate_dims_mapping(dims_mapping, process_mesh):
return True


def _can_apply_infer_spmd_rule(dist_op):
enable = os.getenv("FLAGS_infer_spmd_enable", True)
if isinstance(enable, str):
enable = enable.lower()
enable = True if enable == 'true' else False
enable = bool(enable)

# TODO remove me. ops to be adapted: lookup_table_v2, reshape2, split, transpose2,
__adapted_ops__ = [
"matmul_v2",
"elementwise_div",
"gelu",
"fused_softmax_mask_upper_triangle",
"elementwise_add",
"elementwise_mul",
"assign",
"scale",
"dropout",
"reduce_sum",
"layer_norm",
]
op_type = dist_op.serial_op.type
return enable and contains_spmd_rule(op_type) and op_type in __adapted_ops__


def _update_op_dims_mapping_and_distoperatorimpl(
dist_op, original_op_dist_attr, changed
):
dist_op_container = find_distributed_operator_impl_container(dist_op)
_logger.debug(
"Update Op [{}] using DistOpContainer [{}].".format(
dist_op.serial_op.type, dist_op_container.type
)
)
updated = dist_op_container.update_dims_mapping(dist_op)
changed = updated or changed
# TODO(ljz) remove the below code once we introduce general reshard to replace specific distopimpls
reverted = dist_op_container.mapping_to_dist_operator_impl(
dist_op, original_op_dist_attr
)
_logger.debug(
"Op [{}] use dist op impl [{}] idx [{}].".format(
dist_op.serial_op.type,
dist_op.dist_attr.impl_type,
dist_op.dist_attr.impl_idx,
)
)
return changed and not (reverted)


class Completer:
def __init__(self, dist_context):
assert dist_context is not None
self._dist_context = dist_context
self._has_prepared = False
self._logger = get_logger(logging.INFO, "Completer")

def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
changed = False
@@ -205,22 +272,20 @@ def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):

def _update_op_node_dims_mapping(self, op_node, fwd=True):
changed = False
op_desc = op_node.op()

# step0: skip corner cases
if (not op_node.is_op()) or (op_node.op() is None):
return False
# Skip reader op
op_desc = op_node.op()
if (
op_desc.type() == "create_py_reader"
or op_desc.type() == "create_double_buffer_reader"
or op_desc.type() == "while"
or op_desc.type() == "read"
):
if op_desc.type() in __skip_dims_mapping_op__:
return False

dist_op = self._dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
original_op_dist_attr = copy.deepcopy(op_dist_attr)
# step 1: merge the dims mappings of in
# dist_op with corresponding tensors

# step 1: merge the dims mappings from tensor nodes to op nodes
if fwd:
node_list = op_node.inputs
else:
@@ -277,38 +342,53 @@ def _update_op_node_dims_mapping(self, op_node, fwd=True):
)
changed = True

# step 2: infer distributed attributes in dist_op
# Find the most compatible implementations from the distributed operator
op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=True
)
if op_dist_impls is not None:
not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr)
backup_changed = changed
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if (
op_dist_impl.is_auto_compatible(dist_op)
and dist_op.validate_dist_attr()
):
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
not_compatible = False
break
else:
dist_op.dist_attr = backup_op_dist_attr
changed = backup_changed
if not_compatible:
# step 2: Infer & Update dims mapping of op node using SPMD Rule.
if _can_apply_infer_spmd_rule(dist_op):
_logger.debug(
"Op [{}] update dims mapping using New InferSPMD Rule.".format(
dist_op.serial_op.type
)
)
return _update_op_dims_mapping_and_distoperatorimpl(
dist_op, original_op_dist_attr, changed
)
else:
_logger.debug(
"Op [{}] update dims mapping using Original DistOp Rule.".format(
dist_op.serial_op.type
)
)
# update_op_dims_mapping_v1()
op_dist_impls = find_compatible_distributed_operator_impls(
dist_op, fwd=True
)
if op_dist_impls is not None:
not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr)
backup_changed = changed
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if (
op_dist_impl.is_auto_compatible(dist_op)
and dist_op.validate_dist_attr()
):
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
not_compatible = False
break
else:
dist_op.dist_attr = backup_op_dist_attr
changed = backup_changed
if not_compatible:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
dist_op.dist_attr = original_op_dist_attr
changed = False
else:
dist_op.dist_attr = original_op_dist_attr
changed = False

return changed
return changed

def _update_dims_mapping_between_graphs(self):
changed = False
@@ -900,7 +980,7 @@ def complete_forward_annotation(self, serial_main_program=None):
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
else:
self._logger.info("Default distributed attributed will be set.")
_logger.info("Default distributed attributed will be set.")
self._dist_context.initialize(with_graph=False)
# A fast and special completion for data parallel
self._update_dist_attr_for_dp()
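As an editor's summary of the dispatch that _update_op_node_dims_mapping now performs: the new InferSPMD path is taken only when FLAGS_infer_spmd_enable is not switched off, the op has a registered SPMD rule, and the op type is in the Part1 allowlist; otherwise the legacy DistOp impl search runs. The sketch below restates that gate in isolation under those assumptions; it is not the committed code.

# Standalone restatement, for illustration only. The real check also consults
# the C++-backed contains_spmd_rule(op_type) before using the allowlist.
import os

_ADAPTED_OPS = {
    "matmul_v2", "elementwise_div", "gelu", "fused_softmax_mask_upper_triangle",
    "elementwise_add", "elementwise_mul", "assign", "scale", "dropout",
    "reduce_sum", "layer_norm",
}

def uses_new_infer_spmd(op_type):
    flag = str(os.getenv("FLAGS_infer_spmd_enable", "true")).strip().lower()
    return flag == "true" and op_type in _ADAPTED_OPS

print(uses_new_infer_spmd("dropout"))    # True unless FLAGS_infer_spmd_enable=false
print(uses_new_infer_spmd("reshape2"))   # False: not yet adapted in Part1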
4 changes: 3 additions & 1 deletion python/paddle/distributed/auto_parallel/static/dist_op.py
@@ -58,8 +58,10 @@ def dist_attr(self, dist_attr):
def get_serial_input(self, name):
if self._serial_op.type == "create_py_reader":
tensor = None
else:
elif self._serial_op.block._find_var_recursive(name) is not None:
tensor = self._serial_op.block._var_recursive(name)
else:
tensor = None
return tensor

def get_serial_output(self, name):
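The dist_op.py change above is a small defensive fix: get_serial_input now returns None when the variable name cannot be resolved in the block, instead of letting block._var_recursive fail on a missing name. A hedged stand-in (not the Paddle class) showing the resulting behaviour:

# Hypothetical stand-in, for illustration only; a dict lookup plays the role of
# block._find_var_recursive / block._var_recursive.
def get_serial_input(block_vars, name):
    return block_vars.get(name)  # None when the variable is absent

print(get_serial_input({"x": "var x"}, "x"))        # var x
print(get_serial_input({"x": "var x"}, "missing"))  # None instead of an error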
python/paddle/distributed/auto_parallel/static/operators/__init__.py
@@ -17,6 +17,8 @@
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import find_compatible_distributed_operator_impls
from .common import find_distributed_operator_impl_container

from . import dist_embedding
from . import dist_matmul
from . import dist_reshape
@@ -39,3 +41,4 @@
from . import dist_scale
from . import dist_dropout
from . import dist_flash_attn
from . import dist_layer_norm