
[Semi-Auto] Add reshape spmd rule #55177

Merged: 8 commits merged into PaddlePaddle:develop from reshape_rule on Aug 14, 2023

Conversation

pkuzyc
Contributor

@pkuzyc pkuzyc commented Jul 5, 2023

PR types

New features

PR changes

Others

Description

Pcard-70448
Add reshape spmd rule for auto parallel. This rule infers the output's distributed attribute with the following two steps:

  1. Compute the transformation from the original shape to the target shape.
  2. Compute the output's distributed attribute according to the transformation from step 1.
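
To make the two steps concrete, here is a minimal, self-contained sketch (illustrative only, not the PR's implementation; all names are made up). It aligns source and target dims into groups by matching prefix products (step 1), then forwards a shard on the first source dim of each group to the first target dim of that group (step 2); the real rule additionally checks divisibility against the sharded mesh dimension, which is omitted here.

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> InferReshapeOutDimsMapping(
    const std::vector<int64_t>& src_shape,
    const std::vector<int64_t>& tgt_shape,
    const std::vector<int64_t>& src_dims_mapping) {
  std::vector<int64_t> out_dims_mapping(tgt_shape.size(), -1);
  size_t i = 0, j = 0;
  while (i < src_shape.size() && j < tgt_shape.size()) {
    size_t gi = i, gj = j;  // first source/target dim of the current group
    // Step 1: find the transformation groups by aligning prefix products.
    int64_t prod_src = src_shape[i++], prod_tgt = tgt_shape[j++];
    while (prod_src != prod_tgt) {
      if (prod_src < prod_tgt) {
        prod_src *= src_shape[i++];
      } else {
        prod_tgt *= tgt_shape[j++];
      }
    }
    // Step 2: a shard on the group's first source dim is carried to the
    // group's first target dim; the real rule also checks that this dim is
    // divisible by the sharded mesh dimension (omitted in this sketch).
    if (src_dims_mapping[gi] != -1) {
      out_dims_mapping[gj] = src_dims_mapping[gi];
    }
  }
  return out_dims_mapping;
}

int main() {
  // [6, 12] sharded on dim 0, reshaped to [6, 4, 3]: the shard stays on dim 0.
  for (int64_t d : InferReshapeOutDimsMapping({6, 12}, {6, 4, 3}, {0, -1})) {
    std::cout << d << " ";  // prints: 0 -1 -1
  }
  std::cout << "\n";
  return 0;
}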

@paddle-bot

paddle-bot bot commented Jul 5, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@pkuzyc pkuzyc closed this Jul 5, 2023
@pkuzyc pkuzyc reopened this Jul 5, 2023
@pkuzyc pkuzyc force-pushed the reshape_rule branch 2 times, most recently from a5efe9d to 4370e87 Compare July 12, 2023 09:21
@paddle-ci-bot

paddle-ci-bot bot commented Jul 20, 2023

Sorry to inform you that it has been more than 7 days since 4370e87's CIs passed. To prevent PR conflicts, you need to re-run all CIs manually.

@paddle-ci-bot

paddle-ci-bot bot commented Aug 1, 2023

Sorry to inform you that it has been more than 7 days since 7b96d26's CIs passed. To prevent PR conflicts, you need to re-run all CIs manually.

const std::vector<int64_t>& src_shape,
const std::vector<int64_t>& tgt_shape) {
std::vector<DimTrans*> ret;
int64_t src_size = std::accumulate(
Contributor

src_size --> src_numel / src_nelem,
since src_size is ambiguous alongside src_shape.

Contributor Author

Done, use total_elem_num_src now.
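
For reference, a small sketch of how the renamed count can be computed (the helper name is made up; only the std::accumulate usage mirrors the snippet above):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Total number of elements of a shape, e.g. what total_elem_num_src holds.
int64_t TotalElemNum(const std::vector<int64_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(),
                         static_cast<int64_t>(1), std::multiplies<int64_t>());
}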

// get the size of each output dimension, and get the
// map from sharded input dimensions to output dimensions.
std::vector<int64_t> dim_map_src2tgt(ndim, -1);
std::vector<int64_t> out_shape(dim_trans.size());
Contributor

It is redundant to compute the output shape again here;
only DimTrans::Type::SPLIT needs to maintain the output shape segment.

Contributor Author

Done, removed the code that recomputes the output shape.

int64_t split_id() const;

// get the splitted shape of the split_id_ dimension
int64_t local_split_shape();
Contributor

Would local_split_shape --> local_axis_size be better?
shape: [a, b, c]
axis_size: the value of a, b, or c

Contributor Author

Done, renamed to "local_splitted_shape_value".
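
For illustration, a simplified sketch of what such a split node could look like after the rename (assumed layout and member names; the actual class in the PR may differ):

#include <cstdint>
#include <utility>
#include <vector>

// A "split" transformation: one input dim is split into several output dims,
// and this node represents the split_id_-th piece of the split.
class Split {
 public:
  Split(std::vector<int64_t> splitted_shape, int64_t split_id)
      : splitted_shape_(std::move(splitted_shape)), split_id_(split_id) {}

  int64_t split_id() const { return split_id_; }

  // Size of the piece this node stands for: for a split into [a, b, c],
  // split_id_ == 1 returns b.
  int64_t local_splitted_shape_value() const {
    return splitted_shape_[split_id_];
  }

 private:
  std::vector<int64_t> splitted_shape_;  // sizes of all pieces of the split
  int64_t split_id_;                     // index of the piece this node is
};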

}

// if one input dimension is sharded on an
// unshardable mesh we need to reshard the input.
Contributor

It is too tricky to calculate input_dims_mapping_dst this way.
There is no need to introduce "reshard" into InferSPMD;
directly use the shardable vector to remove the "sharded" dims in input_dims_mapping_src.

Contributor Author

Done, removed the word "reshard".
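
A sketch of the approach suggested above (the function and variable names are assumptions, not the PR's API): dims that are sharded but marked unshardable are simply reset to replicated when building the destination dims mapping, so no resharding concept is needed inside InferSPMD.

#include <cstdint>
#include <vector>

// Reset sharded-but-unshardable dims back to replicated (-1).
std::vector<int64_t> ResetUnshardableDims(
    const std::vector<int64_t>& dims_mapping_src,
    const std::vector<bool>& shardable) {
  std::vector<int64_t> dims_mapping_dst(dims_mapping_src);
  for (size_t i = 0; i < dims_mapping_dst.size(); ++i) {
    if (dims_mapping_dst[i] != -1 && !shardable[i]) {
      dims_mapping_dst[i] = -1;  // drop the shard on this dim
    }
  }
  return dims_mapping_dst;
}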

std::vector<int64_t> dim_map_src2tgt(ndim, -1);
std::vector<int64_t> out_shape(dim_trans.size());
for (int64_t i = 0, n = dim_trans.size(); i < n; i++) {
std::pair<int64_t, DimTrans*> dim_size =
Contributor

It is too tricky to calculate output_dims_mapping this way.
Should maintain dim_map_tgt2src and an unshardable map for the output axes.

Contributor Author

@pkuzyc pkuzyc Aug 11, 2023

Done, removed the redundant code. Using "dim_map_tgt2src" hits some bugs when input_dims_mapping needs to be set to replicated, so "dim_map_src2tgt" is kept for now; it is also more intuitive.
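
As a reading aid, a sketch of how dim_map_src2tgt can be used to fill the output dims mapping (names are illustrative, not the exact PR code): each input dim records which output dim it ends up in, and its shard is forwarded there.

#include <cstdint>
#include <vector>

// dim_map_src2tgt[i] is the output dim that input dim i maps to, or -1.
std::vector<int64_t> BuildOutDimsMapping(
    const std::vector<int64_t>& in_dims_mapping,
    const std::vector<int64_t>& dim_map_src2tgt,
    size_t out_ndim) {
  std::vector<int64_t> out_dims_mapping(out_ndim, -1);
  for (size_t i = 0; i < in_dims_mapping.size(); ++i) {
    if (in_dims_mapping[i] != -1 && dim_map_src2tgt[i] != -1) {
      out_dims_mapping[dim_map_src2tgt[i]] = in_dims_mapping[i];
    }
  }
  return out_dims_mapping;
}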

TensorDistAttr output_dist_attr(input_specs[0].dist_attr());
output_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "Reshape: input_shape: [" << str_join(src_shape)
Contributor

TODO: think about printing useful info about tensor axes and dims_mapping for debugging:

Idea 1: construct an einsum-like notation for debugging, giving corresponding axes of the input and output the same character, so the user can see which axes are related.

Idea 2: print out the DimTrans and make the info readable.

Contributor Author

Done, print the transformation info.
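
For example, the transformation log could look roughly like the following (a hedged sketch: VLOG and str_join appear in the snippet above, but a to_string() method on DimTrans is only assumed here for illustration):

// Log the shapes plus one readable line per output dim's transformation.
VLOG(4) << "Reshape: input_shape: [" << str_join(src_shape)
        << "] output_shape: [" << str_join(tgt_shape) << "]";
for (int64_t i = 0, n = static_cast<int64_t>(dim_trans.size()); i < n; i++) {
  VLOG(4) << "  output dim " << i << ": " << dim_trans[i]->to_string();
}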


self.attrs = {"shape": [1, 72, 48, 4, 6]}

def test_reshape_infer_forward(self):
Contributor

The unit test should include the following cases:

  1. an input axis maps directly to an output axis
  2. multiple input axes merge into a single output axis, with the shard on the first input axis
  3. multiple input axes merge into a single output axis, with the shard on an axis other than the first input axis
  4. a single input axis splits into multiple output axes, with the first output axis divisible
  5. a single input axis splits into multiple output axes, with the first output axis non-divisible
  6. multiple input axes transform into multiple output axes, combining: shard on the first input axis / shard on an input axis other than the first / first output axis divisible / first output axis non-divisible

Contributor Author

Done

namespace distributed {
namespace auto_parallel {

static std::vector<DimTrans*> all_dim_trans;
Contributor

Why use a static global vector?

Contributor Author

@pkuzyc pkuzyc Aug 11, 2023

Here "static" indicates that the global variable can only be used in this file. The global vector is used to store all transformation objects so that we can free them after inferring distributed attributes.

@JZ-LIANG JZ-LIANG merged commit a97b507 into PaddlePaddle:develop Aug 14, 2023
@pkuzyc pkuzyc deleted the reshape_rule branch February 6, 2024 02:43