diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 067eb83f2c646..2a8df1e51a3ee 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -354,6 +354,7 @@
   infer_meta :
     func : UnchangedInferMeta
     param: [x]
+    spmd_rule : ReductionGradInferSpmd
   kernel :
     func : max_grad
   composite : max_grad(x, out, out_grad, axis, keepdim, reduce_all, x_grad)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 7b422ce0fe285..77a9a57559a1a 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -712,6 +712,7 @@
   output : Tensor(out)
   infer_meta :
     func : ReduceIntArrayAxisInferMeta
+    spmd_rule: ReductionMaxInferSpmdDynamic
   kernel :
     func : max
   backward : max_grad
diff --git a/paddle/phi/infermeta/spmd_rules/reduction.cc b/paddle/phi/infermeta/spmd_rules/reduction.cc
index a45ae6822940f..a1fc0873a244a 100644
--- a/paddle/phi/infermeta/spmd_rules/reduction.cc
+++ b/paddle/phi/infermeta/spmd_rules/reduction.cc
@@ -152,6 +152,13 @@ SpmdInfo ReductionSumInferSpmdDynamic(const DistMetaTensor& x,
       x, axis.GetData(), keep_dim, static_cast<int>(ReduceType::kRedSum));
 }
 
+SpmdInfo ReductionMaxInferSpmdDynamic(const DistMetaTensor& x,
+                                      const IntArray& axis,
+                                      bool keep_dim) {
+  return ReductionInferSpmdBase(
+      x, axis.GetData(), keep_dim, static_cast<int>(ReduceType::kRedMax));
+}
+
 SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x,
                                    const DistMetaTensor& out,
                                    const std::vector<int64_t>& axis,
@@ -246,5 +253,25 @@ SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x,
   return {{x_dist_attr, out_grad_dist_attr}, {x_grad_dist_attr}};
 }
 
+SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x,
+                                const DistMetaTensor& out,
+                                const DistMetaTensor& out_grad,
+                                const IntArray& axis,
+                                bool keep_dim,
+                                bool reduce_all) {
+  SpmdInfo spmd_info =
+      ReductionGradInferSpmd(x, out_grad, axis, keep_dim, reduce_all);
+  // NOTE(zhonghui): If the dist_attr of max/min's out is Partial, it must be
+  // changed to Replicate; otherwise each shard would produce a gradient with
+  // a 1 at its own local max, while the true gradient of max has a 1 at only
+  // one position and zeros everywhere else.
+  TensorDistAttr out_dist_attr = out_grad.dist_attr();
+  if (out_dist_attr.is_partial()) {
+    out_dist_attr.clean_partial_status();
+  }
+  spmd_info.first.insert(spmd_info.first.begin() + 1, out_dist_attr);
+  return spmd_info;
+}
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/reduction.h b/paddle/phi/infermeta/spmd_rules/reduction.h
index e010abbb1f60c..30144e6d7ca46 100644
--- a/paddle/phi/infermeta/spmd_rules/reduction.h
+++ b/paddle/phi/infermeta/spmd_rules/reduction.h
@@ -40,6 +40,10 @@ SpmdInfo ReductionSumInferSpmdDynamic(const DistMetaTensor& x,
                                       DataType dtype,
                                       bool keep_dim);
 
+SpmdInfo ReductionMaxInferSpmdDynamic(const DistMetaTensor& x,
+                                      const IntArray& axis,
+                                      bool keep_dim);
+
 SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x,
                                    const DistMetaTensor& out,
                                    const std::vector<int64_t>& axis,
@@ -51,5 +55,12 @@ SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x,
                                 bool keep_dim,
                                 bool reduce_all);
 
+SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x,
+                                const DistMetaTensor& out,
+                                const DistMetaTensor& out_grad,
+                                const IntArray& axis,
+                                bool keep_dim,
+                                bool reduce_all);
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/test/auto_parallel/semi_auto_parallel_for_reduction.py b/test/auto_parallel/semi_auto_parallel_for_reduction.py
index 4b2e7d4bb026b..e96566075498e 100644
--- a/test/auto_parallel/semi_auto_parallel_for_reduction.py
+++ b/test/auto_parallel/semi_auto_parallel_for_reduction.py
@@ -93,6 +93,26 @@ def test_mean_x_shard(self):
             op_func=paddle.mean,
         )
 
+    def test_max_x_shard(self):
+        self.test_body(
+            x_shape=[4, 8, 6],
+            out_shape=[4, 6],
+            x_specs=['x', None, None],
+            axis=1,
+            keepdim=False,
+            op_func=paddle.max,
+        )
+
+    def test_max_x_shard_on_axis(self):
+        self.test_body(
+            x_shape=[4, 8, 6],
+            out_shape=[4, 6],
+            x_specs=[None, 'x', None],
+            axis=1,
+            keepdim=False,
+            op_func=paddle.max,
+        )
+
     def run_test_case(self):
         if self._backend == "cpu":
             paddle.set_device("cpu")
@@ -105,6 +125,8 @@ def run_test_case(self):
         self.test_sum_x_shard_on_axis()
         self.test_sum_x_shard_on_axis_keepdim()
         self.test_mean_x_shard()
+        self.test_max_x_shard()
+        self.test_max_x_shard_on_axis()
 
 
 if __name__ == '__main__':
diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc
index 77c0c555b4564..ff6e60c556c66 100644
--- a/test/cpp/auto_parallel/spmd_rule_test.cc
+++ b/test/cpp/auto_parallel/spmd_rule_test.cc
@@ -974,6 +974,38 @@ TEST(WhereRule, Ctor) {
   check_partial_dims(infered_dist_attrs.second[1], {0});
 }
 
+TEST(ReduceMaxRule, Ctor) {
+  std::vector<int64_t> mesh_shape = {2};
+  std::vector<int64_t> process_ids = {0, 1};
+  std::vector<std::string> dim_names = {"x"};
+  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
+
+  // test forward
+  auto t_dist_attr = TensorDistAttr();
+  t_dist_attr.set_process_mesh(process_mesh);
+  t_dist_attr.set_dims_mapping({-1, 0, -1});
+  t_dist_attr.set_dynamic_dims({false, false, false});
+  phi::distributed::DistMetaTensor x =
+      phi::distributed::DistMetaTensor(phi::make_ddim({4, 6, 8}), t_dist_attr);
+  IntArray axis = {1};
+  bool keep_dim = false;
+  phi::distributed::SpmdInfo forward_info =
+      phi::distributed::ReductionMaxInferSpmdDynamic(x, axis, keep_dim);
+  check_dim_mapping(forward_info.second[0], {-1, -1});
+  check_partial_dims(forward_info.second[0], {0});
+  // test backward
+  phi::distributed::DistMetaTensor out = phi::distributed::DistMetaTensor(
+      phi::make_ddim({4, 8}),
+      PADDLE_GET_CONST(TensorDistAttr, forward_info.second[0]));
+  phi::distributed::DistMetaTensor out_grad = out;
+  phi::distributed::SpmdInfo backward_info =
+      phi::distributed::ReductionGradInferSpmd(
+          x, out, out_grad, {1}, false, false);
+  check_partial_dims(backward_info.first[1], {});
+  check_dim_mapping(backward_info.second[0], {-1, -1, -1});
+  check_partial_dims(backward_info.second[0], {});
+}
+
 TEST(Numel, Ctor) {
   std::vector<int64_t> mesh_shape = {2, 2};
   std::vector<int64_t> process_ids = {0, 1, 2, 3};
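
For context, the sketch below shows roughly how the new rules get exercised from Python in dynamic semi-auto parallel mode. It is not part of the patch: it uses the public dist.shard_tensor / dist.Shard API rather than the SemiAutoParallelTestBase harness used in the test above, it assumes a two-rank launch, and the script name in the launch command is hypothetical.

# Rough usage sketch (not part of the patch). Assumes two ranks, e.g.
#   python -m paddle.distributed.launch --devices=0,1 max_shard_demo.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

x = paddle.randn([4, 8, 6])
# Shard dim 0 (a non-reduced axis), as in test_max_x_shard. Sharding dim 1
# instead would make the forward output Partial(MAX), which the new grad rule
# resets to Replicate before max_grad runs.
dist_x = dist.shard_tensor(x, mesh, [dist.Shard(0)])
dist_x.stop_gradient = False

out = paddle.max(dist_x, axis=1, keepdim=False)  # forward: ReductionMaxInferSpmdDynamic
out.backward()                                   # backward: ReductionGradInferSpmd

print(out.shape)          # [4, 6]
print(dist_x.grad.shape)  # [4, 8, 6], gradient keeps the input's sharding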