Skip to content

Commit

Permalink
fix bugs and modify apis in unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
pkuzyc committed Sep 22, 2023
1 parent 3b7f35d commit 385e423
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 34 deletions.
11 changes: 6 additions & 5 deletions paddle/phi/infermeta/spmd_rules/split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,11 @@ SpmdInfo SplitWithNumInferSpmd(const DistMetaTensor& x, int num, int axis) {
return {{x_dist_attr_dst}, out_dist_attrs};
}

SpmdInfo SplitWithNumInferSpmdReverse(const DistMetaTensor& x,
const std::vector<DistMetaTensor*>& outs,
int num,
int axis) {
SpmdInfo SplitWithNumInferSpmdReverse(
const DistMetaTensor& x,
const std::vector<const DistMetaTensor*>& outs,
int num,
int axis) {
// Step0: Verify input args based on split logic
int nouts = outs.size();
int out_ndim = phi::vectorize(outs[0]->dims()).size();
Expand Down Expand Up @@ -204,7 +205,7 @@ SpmdInfo SplitInferSpmd(const DistMetaTensor& x,
}

SpmdInfo SplitInferSpmdReverse(const DistMetaTensor& x,
const std::vector<DistMetaTensor*>& outs,
const std::vector<const DistMetaTensor*>& outs,
const std::vector<int>& sections,
int axis) {
int num = sections.size();
Expand Down
11 changes: 6 additions & 5 deletions paddle/phi/infermeta/spmd_rules/split.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,17 @@ SpmdInfo SplitInferSpmd(const DistMetaTensor& x,
int axis);

SpmdInfo SplitInferSpmdReverse(const DistMetaTensor& x,
const std::vector<DistMetaTensor*>& outs,
const std::vector<const DistMetaTensor*>& outs,
const std::vector<int>& sections,
int axis);

SpmdInfo SplitWithNumInferSpmd(const DistMetaTensor& x, int num, int axis);

SpmdInfo SplitWithNumInferSpmdReverse(const DistMetaTensor& x,
const std::vector<DistMetaTensor*>& outs,
int num,
int axis);
SpmdInfo SplitWithNumInferSpmdReverse(
const DistMetaTensor& x,
const std::vector<const DistMetaTensor*>& outs,
int num,
int axis);

} // namespace distributed
} // namespace phi
54 changes: 30 additions & 24 deletions test/auto_parallel/spmd_rules/test_split_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class TestReductionSPMDRule(unittest.TestCase):
"""

def setUp(self):
self.rule = core.get_phi_spmd_rule("split")

x_shape = [64, 32, 48]
process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

Expand All @@ -48,7 +46,7 @@ def test_single_mesh_dim(self):
self.attrs['axis'] = 1
self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], list(self.attrs.values())
self.x_dist_tensor_spec, self.attrs['num'], self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand All @@ -69,7 +67,7 @@ def test_single_mesh_dim(self):
self.attrs['axis'] = 2
self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], list(self.attrs.values())
self.x_dist_tensor_spec, self.attrs['sections'], self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand All @@ -90,7 +88,7 @@ def test_single_mesh_dim(self):
self.attrs['axis'] = 2
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], list(self.attrs.values())
self.x_dist_tensor_spec, self.attrs['sections'], self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand Down Expand Up @@ -118,7 +116,7 @@ def test_single_mesh_dim(self):
self.attrs['axis'] = -2
self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], list(self.attrs.values())
self.x_dist_tensor_spec, self.attrs['num'], self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand All @@ -144,7 +142,7 @@ def test_multi_mesh_dim(self):
self.attrs['axis'] = -1
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], list(self.attrs.values())
self.x_dist_tensor_spec, self.attrs['num'], self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand Down Expand Up @@ -174,7 +172,7 @@ def test_multi_mesh_dim(self):
self.attrs['axis'] = 0
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], list(self.attrs.values())
self.x_dist_tensor_spec, self.attrs['sections'], self.attrs['axis']
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand Down Expand Up @@ -214,9 +212,10 @@ def test_backward_single_mesh_dim(self):
self.out_spec_list[0].set_dims_mapping([0, -1, -1])
self.out_spec_list[1].set_dims_mapping([0, -1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec],
self.x_dist_tensor_spec,
self.out_spec_list,
list(self.attrs.values()),
self.attrs['num'],
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand Down Expand Up @@ -248,9 +247,10 @@ def test_backward_single_mesh_dim(self):
self.out_spec_list[1].set_dims_mapping([0, -1, -1])
self.out_spec_list[2].set_dims_mapping([0, -1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec],
self.x_dist_tensor_spec,
self.out_spec_list,
list(self.attrs.values()),
self.attrs['sections'],
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand Down Expand Up @@ -282,9 +282,10 @@ def test_backward_single_mesh_dim(self):
self.out_spec_list[1].set_dims_mapping([-1, -1, -1])
self.out_spec_list[2].set_dims_mapping([-1, -1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec],
self.x_dist_tensor_spec,
self.out_spec_list,
list(self.attrs.values()),
self.attrs['sections'],
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand Down Expand Up @@ -319,9 +320,10 @@ def test_backward_single_mesh_dim(self):
self.out_spec_list[0].set_dims_mapping([0, -1, -1])
self.out_spec_list[1].set_dims_mapping([0, -1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec],
self.x_dist_tensor_spec,
self.out_spec_list,
list(self.attrs.values()),
self.attrs['num'],
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand Down Expand Up @@ -349,9 +351,10 @@ def test_backward_single_mesh_dim(self):
self.out_spec_list[0].set_dims_mapping([-1, 0, -1])
self.out_spec_list[1].set_dims_mapping([-1, -1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec],
self.x_dist_tensor_spec,
self.out_spec_list,
list(self.attrs.values()),
self.attrs['num'],
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand Down Expand Up @@ -408,9 +411,10 @@ def test_backward_multi_mesh_dim(self):
self.out_spec_list[1].set_dims_mapping([0, 1, -1, -1])
self.out_spec_list[2].set_dims_mapping([0, 1, -1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec],
self.x_dist_tensor_spec,
self.out_spec_list,
list(self.attrs.values()),
self.attrs['num'],
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand Down Expand Up @@ -451,9 +455,10 @@ def test_backward_multi_mesh_dim(self):
self.out_spec_list[1].set_dims_mapping([-1, 1, -1, -1])
self.out_spec_list[2].set_dims_mapping([-1, 1, -1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec],
self.x_dist_tensor_spec,
self.out_spec_list,
list(self.attrs.values()),
self.attrs['sections'],
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand Down Expand Up @@ -494,9 +499,10 @@ def test_backward_multi_mesh_dim(self):
self.out_spec_list[1].set_dims_mapping([-1, 1, -1, -1])
self.out_spec_list[2].set_dims_mapping([-1, -1, -1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec],
self.x_dist_tensor_spec,
self.out_spec_list,
list(self.attrs.values()),
self.attrs['sections'],
self.attrs['axis'],
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
Expand Down

0 comments on commit 385e423

Please sign in to comment.