diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index f6596f3db31d5..62b595a13f960 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -543,6 +543,15 @@ void BindAutoParallel(py::module *m) { py::arg("memo")) .def("__str__", &OperatorDistAttr::to_string); + m->def( + "contains_spmd_rule", + [](const std::string op_type) { + return phi::distributed::SpmdRuleFactory::Instance().ContainsSpmdRule( + op_type) || + SPMDRuleMap::Instance().Has(op_type); // TODO(ljz): unify here + }, + py::return_value_policy::reference); + m->def( "get_spmd_rule", [](const std::string op_type) { diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index 32a45d4dd8b3c..5418b34e28b57 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -333,6 +333,10 @@ PD_REGISTER_SPMD_RULE( trunc, PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + dropout, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); // elementwise binary rule PD_REGISTER_SPMD_RULE( diff --git a/python/paddle/distributed/auto_parallel/static/completion.py b/python/paddle/distributed/auto_parallel/static/completion.py index baf183f55bbae..372144982327c 100644 --- a/python/paddle/distributed/auto_parallel/static/completion.py +++ b/python/paddle/distributed/auto_parallel/static/completion.py @@ -14,23 +14,41 @@ import copy import logging +import os -from paddle.base.core import get_spmd_rule # noqa: F401 +from paddle.base.core import ( # noqa: F401 + contains_spmd_rule, + get_phi_spmd_rule, + get_spmd_rule, +) +from paddle.base.log_helper import get_logger from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.framework import core from ..process_mesh import ProcessMesh, compute_compatible_process_mesh from .dist_attribute import OperatorDistAttr, TensorDistAttr from .dist_context import _node_id -from .operators import find_compatible_distributed_operator_impls +from .operators import ( + find_compatible_distributed_operator_impls, + find_distributed_operator_impl_container, +) from .process_group import get_world_process_group from .utils import ( __no_shape_var_type__, - get_logger, is_gradient_clip_op, is_naive_data_parallel, ) +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) +__skip_dims_mapping_op__ = [ + "create_py_reader", + "create_double_buffer_reader", + "while", + "read", +] + def compute_compatible_dim_mapping(dim_mapping_list): """Compute the compatible dim mapping given a list of dim mapping.""" @@ -105,12 +123,61 @@ def _validate_dims_mapping(dims_mapping, process_mesh): return True +def _can_apply_infer_spmd_rule(dist_op): + enable = os.getenv("FLAGS_infer_spmd_enable", True) + if isinstance(enable, str): + enable = enable.lower() + enable = True if enable == 'true' else False + enable = bool(enable) + + # TODO remove me. 
ops to be adapted: lookup_table_v2, reshape2, split, transpose2, + __adapted_ops__ = [ + "matmul_v2", + "elementwise_div", + "gelu", + "fused_softmax_mask_upper_triangle", + "elementwise_add", + "elementwise_mul", + "assign", + "scale", + "dropout", + "reduce_sum", + "layer_norm", + ] + op_type = dist_op.serial_op.type + return enable and contains_spmd_rule(op_type) and op_type in __adapted_ops__ + + +def _update_op_dims_mapping_and_distoperatorimpl( + dist_op, original_op_dist_attr, changed +): + dist_op_container = find_distributed_operator_impl_container(dist_op) + _logger.debug( + "Update Op [{}] using DistOpContainer [{}].".format( + dist_op.serial_op.type, dist_op_container.type + ) + ) + updated = dist_op_container.update_dims_mapping(dist_op) + changed = updated or changed + # TODO(ljz) remove the below code once we introduce general reshard to replace specifc distopimpls + reverted = dist_op_container.mapping_to_dist_operator_impl( + dist_op, original_op_dist_attr + ) + _logger.debug( + "Op [{}] use dist op impl [{}] idx [{}].".format( + dist_op.serial_op.type, + dist_op.dist_attr.impl_type, + dist_op.dist_attr.impl_idx, + ) + ) + return changed and not (reverted) + + class Completer: def __init__(self, dist_context): assert dist_context is not None self._dist_context = dist_context self._has_prepared = False - self._logger = get_logger(logging.INFO, "Completer") def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True): changed = False @@ -205,22 +272,20 @@ def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True): def _update_op_node_dims_mapping(self, op_node, fwd=True): changed = False + op_desc = op_node.op() + + # step0: skip corner cases if (not op_node.is_op()) or (op_node.op() is None): return False # Skip reader op - op_desc = op_node.op() - if ( - op_desc.type() == "create_py_reader" - or op_desc.type() == "create_double_buffer_reader" - or op_desc.type() == "while" - or op_desc.type() == "read" - ): + if op_desc.type() in __skip_dims_mapping_op__: return False + dist_op = self._dist_context.get_dist_op_for_graph(op_node) op_dist_attr = dist_op.dist_attr original_op_dist_attr = copy.deepcopy(op_dist_attr) - # step 1: merge the dims mappings of in - # dist_op with corresponding tensors + + # step 1: merge the dims mappings from tensor nodes to op nodes if fwd: node_list = op_node.inputs else: @@ -277,38 +342,53 @@ def _update_op_node_dims_mapping(self, op_node, fwd=True): ) changed = True - # step 2: infer distributed attributes in dist_op - # Find the most compatible implementations from the distributed operator - op_dist_impls = find_compatible_distributed_operator_impls( - dist_op, fwd=True - ) - if op_dist_impls is not None: - not_compatible = True - backup_op_dist_attr = copy.deepcopy(op_dist_attr) - backup_changed = changed - for op_dist_impl in op_dist_impls: - dim_changed = op_dist_impl.update_dims_mapping(dist_op) - if dim_changed: - changed = True - if ( - op_dist_impl.is_auto_compatible(dist_op) - and dist_op.validate_dist_attr() - ): - op_dist_attr.impl_type = op_dist_impl.type - op_dist_attr.impl_idx = op_dist_impl.idx - not_compatible = False - break - else: - dist_op.dist_attr = backup_op_dist_attr - changed = backup_changed - if not_compatible: + # step 2: Infer & Update dims mapping of op node using SPMD Rule. 
+ if _can_apply_infer_spmd_rule(dist_op): + _logger.debug( + "Op [{}] update dims mapping using New InferSPMD Rule.".format( + dist_op.serial_op.type + ) + ) + return _update_op_dims_mapping_and_distoperatorimpl( + dist_op, original_op_dist_attr, changed + ) + else: + _logger.debug( + "Op [{}] update dims mapping using Original DistOp Rule.".format( + dist_op.serial_op.type + ) + ) + # update_op_dims_mapping_v1() + op_dist_impls = find_compatible_distributed_operator_impls( + dist_op, fwd=True + ) + if op_dist_impls is not None: + not_compatible = True + backup_op_dist_attr = copy.deepcopy(op_dist_attr) + backup_changed = changed + for op_dist_impl in op_dist_impls: + dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if dim_changed: + changed = True + if ( + op_dist_impl.is_auto_compatible(dist_op) + and dist_op.validate_dist_attr() + ): + op_dist_attr.impl_type = op_dist_impl.type + op_dist_attr.impl_idx = op_dist_impl.idx + not_compatible = False + break + else: + dist_op.dist_attr = backup_op_dist_attr + changed = backup_changed + if not_compatible: + dist_op.dist_attr = original_op_dist_attr + changed = False + else: dist_op.dist_attr = original_op_dist_attr changed = False - else: - dist_op.dist_attr = original_op_dist_attr - changed = False - return changed + return changed def _update_dims_mapping_between_graphs(self): changed = False @@ -900,7 +980,7 @@ def complete_forward_annotation(self, serial_main_program=None): # Copy the corresponding distributed attribute from graph to serial_main_program self._dist_context.copy_dist_attr_from_graph_to_program() else: - self._logger.info("Default distributed attributed will be set.") + _logger.info("Default distributed attributed will be set.") self._dist_context.initialize(with_graph=False) # A fast and special completion for data parallel self._update_dist_attr_for_dp() diff --git a/python/paddle/distributed/auto_parallel/static/dist_op.py b/python/paddle/distributed/auto_parallel/static/dist_op.py index a728b55697bfa..d60457054245e 100644 --- a/python/paddle/distributed/auto_parallel/static/dist_op.py +++ b/python/paddle/distributed/auto_parallel/static/dist_op.py @@ -58,8 +58,10 @@ def dist_attr(self, dist_attr): def get_serial_input(self, name): if self._serial_op.type == "create_py_reader": tensor = None - else: + elif self._serial_op.block._find_var_recursive(name) is not None: tensor = self._serial_op.block._var_recursive(name) + else: + tensor = None return tensor def get_serial_output(self, name): diff --git a/python/paddle/distributed/auto_parallel/static/operators/__init__.py b/python/paddle/distributed/auto_parallel/static/operators/__init__.py index 8efb6cf068569..08482454007f9 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/static/operators/__init__.py @@ -17,6 +17,8 @@ from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl from .common import find_compatible_distributed_operator_impls +from .common import find_distributed_operator_impl_container + from . import dist_embedding from . import dist_matmul from . import dist_reshape @@ -39,3 +41,4 @@ from . import dist_scale from . import dist_dropout from . import dist_flash_attn +from . 
import dist_layer_norm diff --git a/python/paddle/distributed/auto_parallel/static/operators/common.py b/python/paddle/distributed/auto_parallel/static/operators/common.py index fbbb5dd12b789..b134599179260 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/common.py +++ b/python/paddle/distributed/auto_parallel/static/operators/common.py @@ -13,22 +13,36 @@ # limitations under the License import abc +import logging +from paddle.base.log_helper import get_logger from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from ..dist_attribute import OperatorDistAttr from ..process_group import new_process_group -from ..utils import _get_comm_group, _get_corresponding_rank, is_optimize_op +from ..utils import ( + _get_comm_group, + _get_corresponding_rank, + compute_compatible_dims_mapping, + is_optimize_op, +) + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) _g_distributed_operator_impl_containers = {} _g_elementwise_ops = [ + "assign", "elementwise", "gelu", - "dropout", + # "dropout", + "scale", + "relu", "cast", - "gather", - "concat", + # "gather", + # "concat", "fused_softmax_mask_upper_triangle", ] BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'} @@ -62,7 +76,7 @@ def is_elementwise_op(op_type): return False -class DistributedOperatorImplContainer: +class DistributedOperatorImplContainer(abc.ABC): def __init__(self, op_type): self._type = op_type self._impls = [] @@ -111,6 +125,19 @@ def get_compatible_impls(self, dist_op): compatible_impls.append(impl) return compatible_impls + # (NOTE) Currently, both DistributedOperatorImplContainer and DistributedOperatorImpl have an update_dims_mapping method. + # But this method is supposed to be maintained by DistributedOperatorImplContainer; we are gradually adding methods + # to DistributedOperatorImplContainer and removing them from DistributedOperatorImpl. + # @abc.abstractmethod + def update_dims_mapping(self, dist_op): + raise NotImplementedError("Please Implement this method in Subclass.") + + # (NOTE) Currently we only have a limited set of DistributedOperatorImpls for an op to cover its different parallel patterns. + # This function helps choose the correct DistributedOperatorImpl based on the result from InferSPMD.
+ # @abc.abstractmethod + def mapping_to_dist_operator_impl(self, dist_op, original_op_dist_attr): + raise NotImplementedError("Please Implement this method in Subclass.") + class DistributedOperatorImpl(abc.ABC): def __init__(self, name): @@ -144,14 +171,17 @@ def idx(self): def idx(self, impl_idx): self._idx = impl_idx + # to be deprecated @abc.abstractmethod def is_input_compatible(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") + # to be deprecated @abc.abstractmethod def is_output_compatible(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") + # to be deprecated @abc.abstractmethod def is_auto_compatible(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") @@ -166,6 +196,7 @@ def forward(dist_ctx, *args, **kwargs): raise NotImplementedError("Please Implement this method in Subclass.") def backward(dist_ctx, *grad_outputs, **kwargs): raise NotImplementedError("Please Implement this method in Subclass.") + # to be deprecated def update_dims_mapping(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") @@ -272,6 +303,35 @@ def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True): return best_compatible_impl +def find_distributed_operator_impl_container(dist_op): + """ + Return the matching container for a dist op. + If no specific container is found, the default container is returned. + """ + op_type = dist_op.serial_op.type + + # Op has a matching container + dist_op_impl_container = get_distributed_operator_impl_container(op_type) + if dist_op_impl_container is None: + # if the op is registered to the elementwise spmd rule and has NO specific container implemented + if is_elementwise_op(op_type): + dist_op_impl_container = get_distributed_operator_impl_container( + "elementwise" + ) + # default container for all remaining bottom-line cases + else: + dist_op_impl_container = get_distributed_operator_impl_container( + "default" + ) + + _logger.debug( + "Op [{}] Complete DistAttr using {}".format( + op_type, type(dist_op_impl_container).__name__ + ) + ) + return dist_op_impl_container + + def is_parameter_related(varname, block, dist_context=None): # TODO(zhaoyingli): maintain a dict in dist_context to record all variables which have been renamed if ".subprog_" in varname: @@ -539,3 +599,97 @@ def is_in_backward_phase(dist_ctx): # we use this FLAG to distinguish these two phases temporarily.
return dist_ctx.dist_op_context.in_backward_phase() + + +def merge_forward_backward_dims_mapping(fw_results, bw_results): + ninputs = len(fw_results[0]) + noutputs = len(fw_results[1]) + infered_input_dims_mappings = [] + infered_output_dims_mappings = [] + + for i in range(ninputs): + compatible_dims_mapping = compute_compatible_dims_mapping( + [fw_results[0][i].dims_mapping, bw_results[0][i].dims_mapping] + ) + infered_input_dims_mappings.append(compatible_dims_mapping) + + for i in range(noutputs): + compatible_dims_mapping = compute_compatible_dims_mapping( + [fw_results[1][i].dims_mapping, bw_results[1][i].dims_mapping] + ) + infered_output_dims_mappings.append(compatible_dims_mapping) + return infered_input_dims_mappings, infered_output_dims_mappings + + +def update_op_dims_mapping( + dist_op, + input_arg_names, + infered_input_dims_mappings, + output_arg_names, + infered_output_dims_mappings, +): + op_dist_attr = dist_op.dist_attr + changed = False + assert len(input_arg_names) == len( + infered_input_dims_mappings + ), "dims mapping is NOT Match, infered [{}], orignal: [{}]; dist op: [{}]".format( + len(infered_input_dims_mappings), len(input_arg_names), str(dist_op) + ) + assert len(output_arg_names) == len( + infered_output_dims_mappings + ), "dims mapping is NOT Match, infered [{}], orignal: [{}]; dist op: [{}]".format( + len(infered_output_dims_mappings), len(output_arg_names), str(dist_op) + ) + + for i in range(len(input_arg_names)): + original_dims_mapping = op_dist_attr.get_input_dims_mapping( + input_arg_names[i] + ) + infered_dims_mapping = infered_input_dims_mappings[i] + if (infered_dims_mapping is not None) and ( + original_dims_mapping != infered_dims_mapping + ): + _logger.debug( + "Changed: Op [{}], name [{}], Original [{}], Infered [{}]".format( + dist_op.serial_op.type, + input_arg_names[i], + original_dims_mapping, + infered_dims_mapping, + ) + ) + changed = True + op_dist_attr.set_input_dims_mapping( + input_arg_names[i], infered_dims_mapping + ) + + for i in range(len(output_arg_names)): + original_dims_mapping = op_dist_attr.get_output_dims_mapping( + output_arg_names[i] + ) + infered_dims_mapping = infered_output_dims_mappings[i] + if (infered_dims_mapping is not None) and ( + original_dims_mapping != infered_dims_mapping + ): + _logger.debug( + "Changed: Op [{}], name [{}], Original [{}], Infered [{}]".format( + dist_op.serial_op.type, + output_arg_names[i], + original_dims_mapping, + infered_dims_mapping, + ) + ) + changed = True + op_dist_attr.set_output_dims_mapping( + output_arg_names[i], infered_dims_mapping + ) + + return changed + + +def get_default_distributed_operator_impl(): + dist_op_default_impl_container = get_distributed_operator_impl_container( + "default" + ) + num_impls = len(dist_op_default_impl_container.impls) + assert num_impls == 1, f"Default dist op has [{num_impls}] impls" + return dist_op_default_impl_container.get_impl(0) diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_assign.py b/python/paddle/distributed/auto_parallel/static/operators/dist_assign.py index 13327ef511884..e4878b4707357 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_assign.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_assign.py @@ -13,12 +13,7 @@ # limitations under the License. 
from ..utils import compute_compatible_and_update_dim_mapping -from .common import ( - DistributedOperatorImpl, - DistributedOperatorImplContainer, - register_distributed_operator_impl, - register_distributed_operator_impl_container, -) +from .common import DistributedOperatorImpl, DistributedOperatorImplContainer from .dist_default import DistributedDefaultImpl0 @@ -27,7 +22,8 @@ def __init__(self, op_type): super().__init__(op_type) -register_distributed_operator_impl_container(DistributedAssign("assign")) +# TODO reomve assign dist op +# register_distributed_operator_impl_container(DistributedAssign("assign")) class DistributedAssignImpl(DistributedOperatorImpl): @@ -91,4 +87,4 @@ def backward(ctx, *args, **kwargs): DistributedDefaultImpl0.backward(ctx, *args, **kwargs) -register_distributed_operator_impl("assign", DistributedAssignImpl("assign")) +# register_distributed_operator_impl("assign", DistributedAssignImpl("assign")) diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py b/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py index ec3692f0385b5..ca8a2a0bcd80d 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py @@ -11,33 +11,90 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License - import logging import paddle +from paddle.base.log_helper import get_logger from paddle.framework import core from paddle.utils import unique_name -from ....utils.log_utils import get_logger - -_logger = get_logger(logging.INFO) from ...random import determinate_rng, is_enable_auto_rand_ctrl +from ..completion import get_phi_spmd_rule from ..utils import ( + get_dist_tensor_spec, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, ) from .common import ( DistributedOperatorImplContainer, + merge_forward_backward_dims_mapping, register_distributed_operator_impl, register_distributed_operator_impl_container, + update_op_dims_mapping, ) from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0 +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) + class DistributedDropout(DistributedOperatorImplContainer): def __init__(self, op_type): super().__init__(op_type) + @staticmethod + def update_dims_mapping(dist_op): + # step1: prepare inputs need for rule (order args as PHI definition and filter out unnecessary args) + op_desc = dist_op.serial_op.desc + + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + mask_name = op_desc.output('Mask')[0] + # seed_name = op_desc.input('Seed')[0] // seed is a scalar and leave it to be unsharded + + x_spec = get_dist_tensor_spec(dist_op, x_name) + output_spec = get_dist_tensor_spec(dist_op, out_name, False) + + # step2: infer spmd + rule = get_phi_spmd_rule("dropout") + # tensor order following order in PHI defition + fw_results = rule.infer_forward(x_spec) + bw_results = rule.infer_backward(x_spec, output_spec) + + # step3: merge fw & bw results + ( + infered_input_dims_mappings, + infered_output_dims_mappings, + ) = merge_forward_backward_dims_mapping(fw_results, bw_results) + + # step4: update dist_attr + # tensor order following order in PHI defition + changed = update_op_dims_mapping( + dist_op, + [x_name], + infered_input_dims_mappings, + [out_name], + 
infered_output_dims_mappings, + ) + + # step5: update the Mask output (dropout-specific; Seed stays unsharded) + if changed: + dist_op.dist_attr.set_output_dims_mapping( + mask_name, infered_output_dims_mappings[0] + ) + + return changed + + @staticmethod + def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr): + # all dropout ops use the Dropout with Random Control dist operator impl. + op_dist_attr = dist_op.dist_attr + op_dist_attr.impl_type = "dropout" + op_dist_attr.impl_idx = 0 + + return False + register_distributed_operator_impl_container(DistributedDropout("dropout")) diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_eltwise.py b/python/paddle/distributed/auto_parallel/static/operators/dist_eltwise.py index 5c11dfba08fe1..857eda7c79aad 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_eltwise.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_eltwise.py @@ -14,6 +14,7 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole +from ..completion import get_phi_spmd_rule from ..cost import ( _g_op_cost_factory, build_comp_costs_from_descs, @@ -23,14 +24,18 @@ from ..utils import ( compute_compatible_dim_mapping, compute_compatible_dims_mapping, + get_dist_tensor_spec, ) from .common import ( DistributedOperatorImpl, DistributedOperatorImplContainer, + get_default_distributed_operator_impl, is_elementwise_op, is_parameter_related, + merge_forward_backward_dims_mapping, register_distributed_operator_impl, register_distributed_operator_impl_container, + update_op_dims_mapping, ) from .dist_default import DistributedDefaultImpl0 @@ -39,6 +44,68 @@ class DistributedElementwise(DistributedOperatorImplContainer): def __init__(self, op_type): super().__init__(op_type) + @staticmethod + def update_dims_mapping(dist_op): + # step1: prepare inputs needed for the rule (order args as PHI definition and filter out unnecessary args) + op_desc = dist_op.serial_op.desc + assert ( + len(op_desc.input_arg_names()) >= 1 + ), "elementwise op [{}] has [{}] inputs".format( + op_desc.type, len(op_desc.input_arg_names()) + ) + input_arg_names = op_desc.input_arg_names() + assert ( + len(op_desc.output_arg_names()) == 1 + ), "elementwise op [{}] has [{}] outputs".format( + str(dist_op.serial_op), len(op_desc.output_arg_names()) + ) + output_arg_name = op_desc.output_arg_names()[0] + num_inputs = len(input_arg_names) + + # TODO (zhangyichen) replace dist tensor spec by dist tensor in future. + input_specs = [] + for i in range(num_inputs): + input_specs.append( + get_dist_tensor_spec(dist_op, input_arg_names[i]) + ) + output_spec = get_dist_tensor_spec(dist_op, output_arg_name, False) + + # step2: infer spmd + # TODO revise me + op_type = op_desc.type() + rule = get_phi_spmd_rule(op_type) + fw_results = rule.infer_forward(*input_specs) + bw_results = rule.infer_backward(*input_specs, output_spec) + + # step3: merge fw & bw results + ( + infered_input_dims_mappings, + infered_output_dims_mappings, + ) = merge_forward_backward_dims_mapping(fw_results, bw_results) + + # step4: update dist_attr + # tensor order following order in PHI definition + changed = update_op_dims_mapping( + dist_op, + input_arg_names, + infered_input_dims_mappings, + [output_arg_name], + infered_output_dims_mappings, + ) + + return changed + + # NOTE this function will be removed once we use local reshard to replace distopimpls + @staticmethod + def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr): + # all elementwise ops use the default dist operator impl.
+ op_dist_attr = dist_op.dist_attr + default_impl = get_default_distributed_operator_impl() + op_dist_attr.impl_type = default_impl.type + op_dist_attr.impl_idx = default_impl.idx + + return False + register_distributed_operator_impl_container( DistributedElementwise("elementwise") diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py b/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py index 482fb4f443c67..d83beb82cd12a 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py @@ -12,11 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License -import logging - -from ....utils.log_utils import get_logger - -_logger = get_logger(logging.INFO) from ...random import determinate_rng, is_enable_auto_rand_ctrl from .common import ( DistributedOperatorImplContainer, diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py index 7ddff7c9c2336..5f2186575c24e 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_fused_dropout_add.py @@ -11,16 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License - import logging import paddle +from paddle.base.log_helper import get_logger from paddle.framework import core from paddle.utils import unique_name -from ....utils.log_utils import get_logger - -_logger = get_logger(logging.INFO) from ...random import determinate_rng, is_enable_auto_rand_ctrl from ..utils import ( naive_set_dist_op_attr_for_program_by_mesh_and_mapping, @@ -33,6 +30,10 @@ ) from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0 +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) + class DistributedDropout(DistributedOperatorImplContainer): def __init__(self, op_type): diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_layer_norm.py b/python/paddle/distributed/auto_parallel/static/operators/dist_layer_norm.py new file mode 100644 index 0000000000000..dcd1518dcd13d --- /dev/null +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_layer_norm.py @@ -0,0 +1,138 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +import copy +import logging + +from paddle.base.log_helper import get_logger + +from ..completion import get_phi_spmd_rule +from ..utils import get_dist_tensor_spec, is_dim_shard +from .common import ( + DistributedOperatorImplContainer, + get_default_distributed_operator_impl, + merge_forward_backward_dims_mapping, + register_distributed_operator_impl_container, + update_op_dims_mapping, +) + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) + + +class DistributedLayerNorm(DistributedOperatorImplContainer): + def __init__(self, op_type): + super().__init__(op_type) + + @staticmethod + def update_dims_mapping(dist_op): + # step1: prepare inputs need for rule (order args as PHI definition and filter out unnecessary args) + op_desc = dist_op.serial_op.desc + + x_name = op_desc.input('X')[0] + scale_name = op_desc.input('Scale')[0] + bias_name = op_desc.input('Bias')[0] + y_name = op_desc.output('Y')[0] + var_name = op_desc.output('Variance')[0] + mean_name = op_desc.output('Mean')[0] + begin_norm_axis = op_desc.attr('begin_norm_axis') + + x_spec = get_dist_tensor_spec(dist_op, x_name) + scale_spec = get_dist_tensor_spec(dist_op, scale_name) + bias_spec = get_dist_tensor_spec(dist_op, bias_name) + y_spec = get_dist_tensor_spec(dist_op, y_name, False) + var_spec = get_dist_tensor_spec(dist_op, var_name, False) + mean_spec = get_dist_tensor_spec(dist_op, mean_name, False) + + # step2: infer spmd + rule = get_phi_spmd_rule("layer_norm") + # tensor order following order in PHI defition + fw_results = rule.infer_forward( + x_spec, scale_spec, bias_spec, 1.0, begin_norm_axis + ) + bw_results = rule.infer_backward( + x_spec, + scale_spec, + bias_spec, + y_spec, + var_spec, + mean_spec, + 1.0, + begin_norm_axis, + ) + + # step3: merge fw & bw results + ( + infered_input_dims_mappings, + infered_output_dims_mappings, + ) = merge_forward_backward_dims_mapping(fw_results, bw_results) + + # step4: update dist_attr + # tensor order following order in PHI defition + changed = update_op_dims_mapping( + dist_op, + [x_name, scale_name, bias_name], + infered_input_dims_mappings, + [y_name, var_name, mean_name], + infered_output_dims_mappings, + ) + + return changed + + @staticmethod + def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + begin_norm_axis = op_desc.attr('begin_norm_axis') + + # sharded on begin_norm_axis + x_name = op_desc.input('X')[0] + x_dims_mapping = copy.deepcopy( + op_dist_attr.get_input_dims_mapping(x_name) + ) + if (begin_norm_axis > 0) and is_dim_shard( + x_dims_mapping[begin_norm_axis] + ): + # TODO (ljz) support sharding on `begin_norm_axis` + _logger.info( + "sharding on `begin_norm_axis` is not supported yet, we resharded it as replicated" + ) + x_dims_mapping[begin_norm_axis] = -1 + op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping) + + param_names = [op_desc.input('Scale')[0], op_desc.input('Bias')[0]] + for p_name in param_names: + p_dims_mapping = copy.deepcopy( + op_dist_attr.get_input_dims_mapping(p_name) + ) + p_dims_mapping[begin_norm_axis] = -1 + op_dist_attr.set_input_dims_mapping(p_name, p_dims_mapping) + + y_name = op_desc.output('Y')[0] + y_dims_mapping = copy.deepcopy( + op_dist_attr.get_output_dims_mapping(y_name) + ) + y_dims_mapping[begin_norm_axis] = -1 + op_dist_attr.set_input_dims_mapping(y_name, y_dims_mapping) + + # 
default impl + default_impl = get_default_distributed_operator_impl() + op_dist_attr.impl_type = default_impl.type + op_dist_attr.impl_idx = default_impl.idx + + return False + + +register_distributed_operator_impl_container(DistributedLayerNorm("layer_norm")) diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py index 1386c5e661cc8..3568c928c16e7 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py @@ -23,6 +23,7 @@ from paddle.framework import core from paddle.utils import unique_name +from ..completion import get_phi_spmd_rule from ..cost import ( MatmulGradOpCost, MatmulOpCost, @@ -43,6 +44,7 @@ _get_corresponding_rank, compute_compatible_and_update_dim_mapping, compute_compatible_dims_mapping, + get_dist_tensor_spec, is_dim_replicate, is_dim_shard, is_valid_list_index, @@ -54,9 +56,11 @@ gradient_synchronization, infer_shape, is_parameter_related, + merge_forward_backward_dims_mapping, register_distributed_operator_impl, register_distributed_operator_impl_container, set_comm_op_dist_attr_for_program, + update_op_dims_mapping, ) from .dist_default import DistributedDefaultImpl0 @@ -518,10 +522,112 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id): ) +def update_dims_mapping_matmul(dist_op): + # TODO (zhangyichen) provide a clean api for this. + # step1: prepare inputs need for rule (order args as PHI definition and filter out unnecessary args) + op_desc = dist_op.serial_op.desc + x_name = op_desc.input('X')[0] + y_name = op_desc.input('Y')[0] + out_name = op_desc.output('Out')[0] + if op_desc.type() == "matmul_v2": + trans_x = op_desc.attr('trans_x') + trans_y = op_desc.attr('trans_y') + elif op_desc.type() == "matmul": + trans_x = op_desc.attr('transpose_X') + trans_y = op_desc.attr('transpose_Y') + else: # mul + trans_x = False + trans_y = False + + # TODO (zhangyichen) replace dist tensor spece by dist tensor in future. 
+ x_spec = get_dist_tensor_spec(dist_op, x_name) + y_spec = get_dist_tensor_spec(dist_op, y_name) + out_spec = get_dist_tensor_spec(dist_op, out_name, False) + + # step2: infer spmd + rule = get_phi_spmd_rule("matmul") + # tensor order following order in PHI defition + fw_results = rule.infer_forward(x_spec, y_spec, trans_x, trans_y) + bw_results = rule.infer_backward(x_spec, y_spec, out_spec, trans_x, trans_y) + + # step3: merge fw & bw results + ( + infered_input_dims_mappings, + infered_output_dims_mappings, + ) = merge_forward_backward_dims_mapping(fw_results, bw_results) + + # step4: update dist_attr + # tensor order following order in PHI defition + input_arg_names = [x_name, y_name] + output_arg_names = [out_name] + changed = update_op_dims_mapping( + dist_op, + input_arg_names, + infered_input_dims_mappings, + output_arg_names, + infered_output_dims_mappings, + ) + + return changed + + +def mapping_to_dist_operator_impl_matmul(dist_op, original_op_dist_attr): + reverted = False + op_dist_attr = dist_op.dist_attr + op_desc = dist_op.serial_op.desc + x_name = op_desc.input('X')[0] + y_name = op_desc.input('Y')[0] + x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name)) + y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name)) + if op_desc.type() == "matmul_v2": + trans_x = op_desc.attr('trans_x') + trans_y = op_desc.attr('trans_y') + elif op_desc.type() == "matmul": + trans_x = op_desc.attr('transpose_X') + trans_y = op_desc.attr('transpose_Y') + else: # mul + trans_x = False + trans_y = False + + op_dist_attr.impl_type = op_desc.type() + + # [m,k] * [k,n] --> [m, n] + # m_axis_dim = x_dims_mapping[-1] if trans_x else x_dims_mapping[-2] + k_axis_dim = x_dims_mapping[-2] if trans_x else x_dims_mapping[-1] + n_axis_dim = y_dims_mapping[-2] if trans_y else y_dims_mapping[-1] + + # col parallel matmul + if is_dim_replicate(k_axis_dim) and is_dim_shard(n_axis_dim): + op_dist_attr.impl_idx = 0 + # row parallel matmul + elif is_dim_shard(k_axis_dim) and is_dim_replicate(n_axis_dim): + op_dist_attr.impl_idx = 1 + # k, n unsharded matmul + elif is_dim_replicate(n_axis_dim) and is_dim_replicate(k_axis_dim): + op_dist_attr.impl_idx = 2 + # TODO support new dist op impl: m (not broadcast axis) sharded, backward need allreduce on Y + else: + dist_op.dist_attr = original_op_dist_attr + reverted = True + + return reverted + + class DistributedMatmul(DistributedOperatorImplContainer): def __init__(self, op_type): super().__init__(op_type) + @staticmethod + def update_dims_mapping(dist_op): + return update_dims_mapping_matmul(dist_op) + + # NOTE this function will be remove once we use local reshard to replace distopimpls + @staticmethod + def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr): + return mapping_to_dist_operator_impl_matmul( + dist_op, original_op_dist_attr + ) + register_distributed_operator_impl_container(DistributedMatmul("matmul")) @@ -1335,6 +1441,17 @@ class DistributedMatmulV2(DistributedOperatorImplContainer): def __init__(self, op_type): super().__init__(op_type) + @staticmethod + def update_dims_mapping(dist_op): + return update_dims_mapping_matmul(dist_op) + + # NOTE this function will be remove once we use local reshard to replace distopimpls + @staticmethod + def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr): + return mapping_to_dist_operator_impl_matmul( + dist_op, original_op_dist_attr + ) + register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2")) @@ -2154,6 +2271,17 @@ class 
DistributedMul(DistributedOperatorImplContainer): def __init__(self, op_type): super().__init__(op_type) + @staticmethod + def update_dims_mapping(dist_op): + return update_dims_mapping_matmul(dist_op) + + # NOTE this function will be remove once we use local reshard to replace distopimpls + @staticmethod + def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr): + return mapping_to_dist_operator_impl_matmul( + dist_op, original_op_dist_attr + ) + register_distributed_operator_impl_container(DistributedMul("mul")) diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_reduce_sum_p.py b/python/paddle/distributed/auto_parallel/static/operators/dist_reduce_sum_p.py index ba74be866c1ee..85abed9558f4e 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_reduce_sum_p.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_reduce_sum_p.py @@ -12,19 +12,126 @@ # See the License for the specific language governing permissions and # limitations under the License +import copy + from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole +from ..completion import get_phi_spmd_rule from ..dist_attribute import OperatorDistAttr from ..process_group import new_process_group -from ..utils import set_dist_op_desc_original_id +from ..utils import ( + get_dist_tensor_spec, + is_dim_shard, + set_dist_op_desc_original_id, +) from .common import ( DistributedOperatorImpl, DistributedOperatorImplContainer, + get_default_distributed_operator_impl, + merge_forward_backward_dims_mapping, register_distributed_operator_impl, register_distributed_operator_impl_container, + update_op_dims_mapping, ) +class DistributedReduceSum(DistributedOperatorImplContainer): + def __init__(self, op_type): + super().__init__(op_type) + + @staticmethod + def update_dims_mapping(dist_op): + # step1: prepare inputs need for rule (order args as PHI definition and filter out unnecessary args) + + op_desc = dist_op.serial_op.desc + assert ( + len(op_desc.input_arg_names()) == 1 + ), "reduce_sum op [{}] has [{}] inputs".format( + op_desc.type, len(op_desc.input_arg_names()) + ) + input_arg_name = op_desc.input_arg_names()[0] + assert ( + len(op_desc.output_arg_names()) == 1 + ), "reduce_sum op [{}] has [{}] outputs".format( + op_desc.type, len(op_desc.output_arg_names()) + ) + output_arg_name = op_desc.output_arg_names()[0] + keep_dim = op_desc.attr('keep_dim') + dims = op_desc.attr('dim') + + # TODO (zhangyichen) replace dist tensor spece by dist tensor in future. 
+ input_spec = get_dist_tensor_spec(dist_op, input_arg_name) + output_spec = get_dist_tensor_spec(dist_op, output_arg_name, False) + # len(dims) == 0 means reduce_all + if len(dims) == 0: + dims = list(range(len(input_spec.shape))) + + # step2: infer spmd + rule = get_phi_spmd_rule("reduce_sum") + fw_results = rule.infer_forward(input_spec, dims, keep_dim) + bw_results = rule.infer_backward( + input_spec, output_spec, dims, keep_dim + ) + # step3: merge fw & bw results + ( + infered_input_dims_mappings, + infered_output_dims_mappings, + ) = merge_forward_backward_dims_mapping(fw_results, bw_results) + + # step4: update dist_attr + # tensor order following order in PHI defition + changed = update_op_dims_mapping( + dist_op, + [input_arg_name], + infered_input_dims_mappings, + [output_arg_name], + infered_output_dims_mappings, + ) + + return changed + + # NOTE this function will be remove once we use local reshard to replace distopimpls + @staticmethod + def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr): + op_dist_attr = dist_op.dist_attr + op_desc = dist_op.serial_op.desc + input_name = op_desc.input_arg_names()[0] + input_dims_mapping = copy.deepcopy( + op_dist_attr.get_input_dims_mapping(input_name) + ) + axes = op_desc.attr('dim') + + op_dist_attr = dist_op.dist_attr + reverted = False + + def is_partial_reduce(axes, dims_mapping): + # FIXME(ljz) Hack for performance: + # if the reduce result is a scalar, it is the loss reduce in GPT case, + # and if any axis of reduce input is sharded, the result loss would be partial. + # BUT we keep the loss as partial instead of allreduce it for performance, since it would effect the backward. + # we should use an optimization pass for the Hack in future. + if len(axes) != 0 and (len(axes) < len(dims_mapping)): + for axis in axes: + if is_dim_shard(dims_mapping[axis]): + return True # reverted + return False + + # if reduce_axis is sharded, the output is partial and need to be allreduce + if is_partial_reduce(axes, input_dims_mapping): + # TODO (ljz) support reduce where the reduce_axis is sharded + dist_op.dist_attr = original_op_dist_attr + reverted = True + # if reduce_axis is unsharded, NO extra operator need. 
+ else: + default_impl = get_default_distributed_operator_impl() + op_dist_attr.impl_type = default_impl.type + op_dist_attr.impl_idx = default_impl.idx + return reverted + + +register_distributed_operator_impl_container(DistributedReduceSum("reduce_sum")) + + class DistributedReduceSumPrimtive(DistributedOperatorImplContainer): def __init__(self, op_type): super().__init__(op_type) diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_scale.py b/python/paddle/distributed/auto_parallel/static/operators/dist_scale.py index 66a35b1eadb68..b4a35a0382a69 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_scale.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_scale.py @@ -25,8 +25,6 @@ DistributedOperatorImpl, DistributedOperatorImplContainer, is_parameter_related, - register_distributed_operator_impl, - register_distributed_operator_impl_container, ) from .dist_default import DistributedDefaultImpl0 @@ -36,10 +34,11 @@ def __init__(self, op_type): super().__init__(op_type) -register_distributed_operator_impl_container(DistributedScale("scale")) -register_distributed_operator_impl_container(DistributedScale("fill_any_like")) -register_distributed_operator_impl_container(DistributedScale("where")) -register_distributed_operator_impl_container(DistributedScale("tanh")) +# TODO reomve assign dist op +# register_distributed_operator_impl_container(DistributedScale("scale")) +# register_distributed_operator_impl_container(DistributedScale("fill_any_like")) +# register_distributed_operator_impl_container(DistributedScale("where")) +# register_distributed_operator_impl_container(DistributedScale("tanh")) class DistributedScaleImpl(DistributedOperatorImpl): @@ -185,9 +184,9 @@ def backward(ctx, *args, **kwargs): DistributedDefaultImpl0.backward(ctx, *args, **kwargs) -register_distributed_operator_impl("scale", DistributedScaleImpl("scale")) -register_distributed_operator_impl( - "fill_any_like", DistributedScaleImpl("fill_any_like") -) -register_distributed_operator_impl("where", DistributedScaleImpl("where")) -register_distributed_operator_impl("tanh", DistributedScaleImpl("tanh")) +# register_distributed_operator_impl("scale", DistributedScaleImpl("scale")) +# register_distributed_operator_impl( +# "fill_any_like", DistributedScaleImpl("fill_any_like") +# ) +# register_distributed_operator_impl("where", DistributedScaleImpl("where")) +# register_distributed_operator_impl("tanh", DistributedScaleImpl("tanh")) diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py index 2e41c6de99802..fac4df3d45144 100644 --- a/python/paddle/distributed/auto_parallel/static/utils.py +++ b/python/paddle/distributed/auto_parallel/static/utils.py @@ -2499,3 +2499,12 @@ def wrap_data_for_completion( attrs[attr_name] = serial_op.desc.attr(attr_name) return input_specs, output_specs, attrs + + +def get_dist_tensor_spec(dist_op, name, is_input=True): + tensor_shape = dist_op.serial_op.block._var_recursive(name).shape + if is_input: + tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(name) + else: + tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(name) + return DistTensorSpec(tensor_shape, tensor_dist_attr) diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt index 78b8ff0728c6d..cf034e33678aa 100644 --- a/test/auto_parallel/spmd_rules/CMakeLists.txt +++ b/test/auto_parallel/spmd_rules/CMakeLists.txt @@ -17,6 +17,7 @@ 
if(WITH_DISTRIBUTE) py_test_modules(test_reshape_rule MODULES test_reshape_rule) py_test_modules(test_default_data_parallel_rule MODULES test_default_data_parallel_rule) + py_test_modules(test_layer_norm_rule MODULES test_layer_norm_rule) # End of unittests WITH single card WITHOUT timeout endif() diff --git a/test/auto_parallel/test_dist_assign.py b/test/auto_parallel/test_dist_assign.py index 030a6b1513888..5dfbffbce60b5 100644 --- a/test/auto_parallel/test_dist_assign.py +++ b/test/auto_parallel/test_dist_assign.py @@ -65,7 +65,7 @@ def test_dist_assign(self): for op in ops: if op.type == "assign": dist_op = dist_context.get_dist_op_for_program(op) - assert dist_op.dist_attr.impl_type == "assign" + assert dist_op.dist_attr.impl_type == "default" x_name = op.input_arg_names[0] out_name = op.output_arg_names[0]
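
Note for reviewers: every container touched by this patch (dist_dropout, dist_layer_norm, dist_reduce_sum, dist_matmul, dist_eltwise) follows the same four steps in update_dims_mapping plus a mapping_to_dist_operator_impl fallback. The sketch below restates that pattern for a hypothetical single-input op so the flow is visible outside the diff. It uses only helpers introduced or exported by this patch (get_phi_spmd_rule, get_dist_tensor_spec, merge_forward_backward_dims_mapping, update_op_dims_mapping, get_default_distributed_operator_impl, register_distributed_operator_impl_container); the op name "my_unary_op" and the assumed file location (beside the other dist_*.py modules under auto_parallel/static/operators/) are illustrative, not part of the patch.

# Illustrative sketch -- "my_unary_op" is a placeholder op, not one adapted by this PR.
from ..completion import get_phi_spmd_rule
from ..utils import get_dist_tensor_spec
from .common import (
    DistributedOperatorImplContainer,
    get_default_distributed_operator_impl,
    merge_forward_backward_dims_mapping,
    register_distributed_operator_impl_container,
    update_op_dims_mapping,
)


class DistributedMyUnaryOp(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)

    @staticmethod
    def update_dims_mapping(dist_op):
        # step1: collect dist tensor specs in PHI argument order
        op_desc = dist_op.serial_op.desc
        x_name = op_desc.input('X')[0]
        out_name = op_desc.output('Out')[0]
        x_spec = get_dist_tensor_spec(dist_op, x_name)
        out_spec = get_dist_tensor_spec(dist_op, out_name, False)

        # step2: run the PHI SPMD rule in both directions
        # (assumes a rule for "my_unary_op" is registered in spmd_rules/rules.h)
        rule = get_phi_spmd_rule("my_unary_op")
        fw_results = rule.infer_forward(x_spec)
        bw_results = rule.infer_backward(x_spec, out_spec)

        # step3: merge forward and backward inference into compatible dims mappings
        (
            infered_input_dims_mappings,
            infered_output_dims_mappings,
        ) = merge_forward_backward_dims_mapping(fw_results, bw_results)

        # step4: write the merged dims mappings back into the op dist_attr
        return update_op_dims_mapping(
            dist_op,
            [x_name],
            infered_input_dims_mappings,
            [out_name],
            infered_output_dims_mappings,
        )

    @staticmethod
    def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr):
        # no specialized impl needed: fall back to the default dist operator impl
        default_impl = get_default_distributed_operator_impl()
        dist_op.dist_attr.impl_type = default_impl.type
        dist_op.dist_attr.impl_idx = default_impl.idx
        return False


register_distributed_operator_impl_container(
    DistributedMyUnaryOp("my_unary_op")
)

To actually take this path, the op would also have to be listed in __adapted_ops__ inside _can_apply_infer_spmd_rule (completion.py) and have an SPMD rule registered on the PHI side, mirroring the dropout registration added to rules.h above.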
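
For completeness, a minimal way to check, with this patch applied, which adapted ops the completer would route through the new path. The script only uses the contains_spmd_rule binding added in auto_parallel_py.cc and mirrors the __adapted_ops__ allow-list; it is an illustrative sanity check, not part of the patch.

# Illustrative sanity check -- requires a Paddle build that contains this patch.
import os

from paddle.base.core import contains_spmd_rule

# environment switch read by _can_apply_infer_spmd_rule (defaults to enabled)
os.environ["FLAGS_infer_spmd_enable"] = "true"

# mirrors __adapted_ops__ in completion.py
adapted_ops = [
    "matmul_v2", "elementwise_div", "gelu", "fused_softmax_mask_upper_triangle",
    "elementwise_add", "elementwise_mul", "assign", "scale", "dropout",
    "reduce_sum", "layer_norm",
]
for op_type in adapted_ops:
    routed = "new InferSPMD path" if contains_spmd_rule(op_type) else "legacy DistOp rule"
    print(f"{op_type}: {routed}")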