Skip to content

Commit

Permalink
【Semi-Auto】Adapt split spmd rule to phi (PaddlePaddle#57467)
Browse files Browse the repository at this point in the history
* adapt split rule to phi

* fix bugs and modify apis in unit test

* fix codestyle

* bug fix
  • Loading branch information
pkuzyc authored and Frida-a committed Oct 14, 2023
1 parent 9138b69 commit ed4812f
Show file tree
Hide file tree
Showing 7 changed files with 338 additions and 312 deletions.
5 changes: 0 additions & 5 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h"

// TODO(ljz) Automatic this process in cmake file.
Expand All @@ -37,10 +36,6 @@ REGISTER_SPMD_RULE(log_softmax, SoftmaxSPMDRule);
REGISTER_SPMD_RULE(cross_entropy_with_softmax, CrossEntropyWithSoftmaxSPMDRule);
REGISTER_SPMD_RULE(softmax_with_cross_entropy, CrossEntropyWithSoftmaxSPMDRule);

// split rule
REGISTER_SPMD_RULE(split, SplitSPMDRule);
REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule);

// transpose rule
REGISTER_SPMD_RULE(transpose, TransposeSPMDRule);

Expand Down
218 changes: 0 additions & 218 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc

This file was deleted.

This file was deleted.

10 changes: 10 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/infermeta/spmd_rules/reduction.h"
#include "paddle/phi/infermeta/spmd_rules/replicated.h"
#include "paddle/phi/infermeta/spmd_rules/reshape.h"
#include "paddle/phi/infermeta/spmd_rules/split.h"

/**
* Design Notes:
Expand Down Expand Up @@ -485,5 +486,14 @@ PD_REGISTER_SPMD_RULE(
PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd),
PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse));

// split rule
PD_REGISTER_SPMD_RULE(split,
PD_INFER_SPMD(phi::distributed::SplitInferSpmd),
PD_INFER_SPMD(phi::distributed::SplitInferSpmdReverse));
PD_REGISTER_SPMD_RULE(
split_with_num,
PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmd),
PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmdReverse));

} // namespace distributed
} // namespace phi
Loading

0 comments on commit ed4812f

Please sign in to comment.