[Semi-Auto] Add reshape spmd rule #55177
Conversation
Your PR has been submitted successfully. Thank you for your contribution to this open-source project!
Force-pushed from a5efe9d to 4370e87
Sorry to inform you that 4370e87's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Sorry to inform you that 7b96d26's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
const std::vector<int64_t>& src_shape,
const std::vector<int64_t>& tgt_shape) {
  std::vector<DimTrans*> ret;
  int64_t src_size = std::accumulate(
src_size --> src_numel / src_nelem; src_size is ambiguous with src_shape.
Done, use total_elem_num_src now.
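As a side note, a minimal self-contained sketch of that computation (the fold that produces total_elem_num_src), assuming the shape is a plain std::vector<int64_t>:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Total element count of a shape: fold the extents with multiplication.
// This mirrors the std::accumulate call in the diff excerpt above.
int64_t TotalElemNum(const std::vector<int64_t>& shape) {
  return std::accumulate(
      shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
}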
// get the size of each output dimension, and get the
// map from sharded input dimensions to output dimensions.
std::vector<int64_t> dim_map_src2tgt(ndim, -1);
std::vector<int64_t> out_shape(dim_trans.size());
It is redundant to compute the output shape again here; only DimTrans::Type::SPLIT needs to maintain the output shape segment.
Done, removed the output-shape computation now.
int64_t split_id() const;

// get the splitted shape of the split_id_ dimension
int64_t local_split_shape();
Would it be better to rename local_split_shape --> local_axis_size?
shape: [a, b, c]
axis_size: the value of a, b, or c
Done, rename to "local_splitted_shape_value".
}

// if one input dimension is sharded on a
// unshardable mesh we need to reshard the input.
Too tricky to calculate the input_dims_mapping_dst. No need to introduce "reshard" into InferSPMD; directly use the shardable vector to remove the "sharded" dims in input_dims_mapping_src.
Done, remove the "reshard" word.
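To illustrate the suggested approach, here is a minimal self-contained sketch; the name RemoveUnshardable and the per-dimension shardable flags are illustrative assumptions, not the PR's actual API. Instead of resharding, any dim whose mesh axis is unshardable simply becomes replicated (-1) when deriving the destination mapping.

#include <cstdint>
#include <vector>

// Derive input_dims_mapping_dst from input_dims_mapping_src by clearing
// shards on dimensions flagged as unshardable (simplified here to one
// flag per tensor dimension).
std::vector<int64_t> RemoveUnshardable(
    const std::vector<int64_t>& input_dims_mapping_src,
    const std::vector<bool>& shardable) {
  std::vector<int64_t> dst(input_dims_mapping_src);
  for (size_t i = 0; i < dst.size(); ++i) {
    if (dst[i] != -1 && !shardable[i]) {
      dst[i] = -1;  // mark as replicated instead of resharding
    }
  }
  return dst;
}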
std::vector<int64_t> dim_map_src2tgt(ndim, -1);
std::vector<int64_t> out_shape(dim_trans.size());
for (int64_t i = 0, n = dim_trans.size(); i < n; i++) {
  std::pair<int64_t, DimTrans*> dim_size =
Too tricky to calculate the output_dims_mapping. We should maintain dim_map_tgt2src and an unshardable map for the output axes.
Done, removed the redundant code. Using "dim_map_tgt2src" runs into bugs when input_dims_mapping should be set to replicated, so "dim_map_src2tgt" is kept for now; it is also more intuitive.
TensorDistAttr output_dist_attr(input_specs[0].dist_attr());
output_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "Reshape: input_shape: [" << str_join(src_shape)
TODO: think about printing useful info about tensor axes and dims_mapping for debugging:
idea 1: construct an einsum notation for debugging, giving corresponding axes between input and output a specific character, so the user is notified that those axes are related.
idea 2: print out the DimTrans and make the info readable.
Done, print the transformation info.
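For reference, a minimal sketch of what such debug output could look like (idea 2); representing the transformation as a map from each output dim to its source input dims is an assumption here, not the PR's DimTrans interface.

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Render each output dimension's source input dimensions as readable
// text, e.g. "out 0 <- in [0, 1]; out 1 <- in [2]; ".
std::string DescribeTrans(
    const std::vector<std::vector<int64_t>>& src_dims_per_out) {
  std::ostringstream os;
  for (size_t out = 0; out < src_dims_per_out.size(); ++out) {
    os << "out " << out << " <- in [";
    for (size_t j = 0; j < src_dims_per_out[out].size(); ++j) {
      if (j > 0) os << ", ";
      os << src_dims_per_out[out][j];
    }
    os << "]; ";
  }
  return os.str();
}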
self.attrs = {"shape": [1, 72, 48, 4, 6]}

def test_reshape_infer_forward(self):
The unit tests should include the following cases:
- input axis maps directly to an output axis
- multiple input axes merge into a single output axis, with the shard on the first input axis
- multiple input axes merge into a single output axis, with the shard on an axis other than the first input axis
- a single input axis splits into multiple output axes, with the first output axis divisible
- a single input axis splits into multiple output axes, with the first output axis non-divisible
- multiple input axes transform into multiple output axes, with the shard on the first input axis / on an input axis other than the first, and with the first output axis divisible / non-divisible
Done
namespace distributed {
namespace auto_parallel {

static std::vector<DimTrans*> all_dim_trans;
Why use a static global vector?
Here "static" indicates that the global variable can only be used in this file. The global vector is used to store all transformation objects so that we can free them after inferring distributed attributes.
PR types
New features
PR changes
Others
Description
Pcard-70448
Add reshape spmd rule for auto parallel. This rule infers the output's distributed attribute with the following two steps:
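The two steps themselves are truncated in this excerpt. As a toy worked example of the rule's effect (assuming, per the test cases discussed above, that merged axes inherit the shard placed on the first merged input axis):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // reshape [6, 12] -> [72]: both input dims flatten into output dim 0.
  std::vector<int64_t> src_dims_mapping = {0, -1};  // dim 0 sharded on mesh axis 0
  // The merged output dim inherits the mesh axis of the first merged
  // input dim, so the output dims_mapping becomes [0].
  std::vector<int64_t> tgt_dims_mapping = {src_dims_mapping[0]};
  std::cout << "output dims_mapping: [" << tgt_dims_mapping[0] << "]\n";
  return 0;
}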