
[Semi-Auto] add split spmd rule #55397

Merged

merged 3 commits into PaddlePaddle:develop from pkuzyc:split_rule on Jul 24, 2023

Conversation

@pkuzyc (Contributor) commented Jul 13, 2023

PR types

New features

PR changes

Others

Description

Pcard-70448

Add the split spmd rule for auto parallel. It infers the output dims mapping as follows: the split axis cannot be sharded, so the dims mapping of the split axis in the input and outputs is set to -1. For the other axes in the outputs, the dims mapping is set equal to the input's.
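
Below is a minimal standalone sketch of this inference, using plain C++ containers instead of Paddle's actual dist-attr types (the function and variable names are illustrative, not part of the PR):

// Sketch of the split dims mapping inference described above.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Returns {inferred input dims mapping, dims mapping shared by every output}.
std::pair<std::vector<int64_t>, std::vector<int64_t>> InferSplitDimsMapping(
    std::vector<int64_t> input_dims_mapping, int64_t axis) {
  // The split axis cannot be sharded: force it to -1 (replicated) on the
  // input as well, which triggers a reshard if it was sharded before.
  input_dims_mapping[axis] = -1;
  // Every other axis keeps the input's mapping, so each output simply
  // copies the (adjusted) input dims mapping.
  std::vector<int64_t> output_dims_mapping = input_dims_mapping;
  return {input_dims_mapping, output_dims_mapping};
}

int main() {
  // Example: input sharded as [0, -1, 1], split along axis 2.
  auto result = InferSplitDimsMapping({0, -1, 1}, /*axis=*/2);
  for (int64_t d : result.second) std::cout << d << " ";  // prints: 0 -1 -1
  std::cout << std::endl;
  return 0;
}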

paddle-bot commented Jul 13, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@@ -145,6 +146,9 @@ REGISTER_SPMD_RULE(lookup_table_v2, EmbeddingSPMDRule);
REGISTER_SPMD_RULE(softmax, SoftmaxSPMDRule);
REGISTER_SPMD_RULE(log_softmax, SoftmaxSPMDRule);

// split rule
REGISTER_SPMD_RULE(split, SplitSPMDRule);
Contributor

The name string for registration should be the op names defined in the phi yaml; for split they are: split_with_num & split.

You should also take care of the differences in attributes between these two ops.
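
For illustration, the registration after this change could look roughly like the following (a sketch reusing the REGISTER_SPMD_RULE macro from the diff above; whether both ops share one rule class depends on how the attribute difference is handled):

// split takes an IntArray `sections`, while split_with_num takes an int `num`,
// so the rule must read different attributes depending on which op it serves.
REGISTER_SPMD_RULE(split, SplitSPMDRule);
REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule);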

Contributor Author

Done


// step1: Build Einsum Notation
int64_t ndim = input_specs[0].shape().size();
Attribute section_attr = GetAttr("num_or_sections", attrs);
Contributor

Follow the definitions in the phi yaml:

  • op : split_with_num
    args : (Tensor x, int num, Scalar(int) axis)
  • op : split
    args : (Tensor x, IntArray sections, Scalar(int) axis)

There is no num_or_sections attribute.
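
A rough, self-contained illustration of that difference, assuming the rule ultimately only needs the number of outputs (the helper below is hypothetical and does not use the Paddle attribute API):

#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper: derive the number of outputs from whichever
// attribute the op actually carries.
int64_t NumSplitOutputs(const std::string& op_type,
                        const std::vector<int64_t>& sections,  // split: IntArray sections
                        int64_t num) {                         // split_with_num: int num
  if (op_type == "split") {
    return static_cast<int64_t>(sections.size());
  }
  return num;  // split_with_num gives the output count directly
}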

Contributor Author

Done

input_axes_vec.emplace_back(input_axes);

// get einsum notation for output
std::string output_axes(input_axes);
Contributor

We should support the input tensor being sharded on split_axis.
The main logic for the split spmd forward inference should be:

  1. For axes other than the split_axis, treat them as broadcast axes and copy their dims mapping from the input to the outputs.
  2. For the split_axis:
    the inferred dim_mapping for all outputs is replicated;
    the inferred dst_dim_mapping for all outputs is also replicated.

Contributor

Use a special char for the split axis, like "k".
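
A small standalone sketch of that idea (illustrative only, not the PR's implementation): build the axes notation with a reserved character for the split axis so it is never matched for sharding.

#include <cstdint>
#include <iostream>
#include <string>

// One letter per tensor axis, starting from 'a'; the split axis gets the
// reserved letter 'k' so it is treated as replicated rather than sharded.
std::string BuildSplitNotation(int64_t ndim, int64_t split_axis) {
  std::string alphabet = "abcdefghijlmnopqrstuvwxyz";  // note: 'k' is excluded
  std::string axes;
  for (int64_t i = 0; i < ndim; ++i) {
    axes.push_back(i == split_axis ? 'k' : alphabet[i]);
  }
  return axes;
}

int main() {
  std::cout << BuildSplitNotation(4, 2) << std::endl;  // prints: abkd
  return 0;
}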

Contributor Author

Done

"""

def setUp(self):
self.rule = get_spmd_rule("split")
Contributor

Test for both split and split_with_num.

Contributor Author

Done

paddle-ci-bot commented Jul 21, 2023

Sorry to inform you that fe38b69's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

@JZ-LIANG (Contributor) left a comment

LGTM

@JZ-LIANG merged commit cf76e7a into PaddlePaddle:develop on Jul 24, 2023
cqulilujia pushed a commit to cqulilujia/Paddle that referenced this pull request Jul 24, 2023
* add split spmd rule

* add pytest in cmake file

* small fix
wz1qqx pushed a commit to wz1qqx/Paddle that referenced this pull request Jul 31, 2023
* add split spmd rule

* add pytest in cmake file

* small fix
@pkuzyc deleted the split_rule branch on August 14, 2023 02:18
jinjidejinmuyan pushed a commit to jinjidejinmuyan/Paddle that referenced this pull request Aug 30, 2023
* add split spmd rule

* add pytest in cmake file

* small fix