Skip to content

Commit

Permalink
[auto parallel] enable spmd rules for maximum (PaddlePaddle#58360)
Browse files Browse the repository at this point in the history
  • Loading branch information
cxxly authored and zeroRains committed Nov 8, 2023
1 parent 06c087a commit 0eb5034
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 1 deletion.
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@
infer_meta :
func : GeneralBinaryGradInferMeta
param: [x, y]
spmd_rule: ElementwiseBinaryGradInferSpmd
kernel :
func : maximum_grad
composite : maximum_grad(x, y, out_grad, axis, x_grad, y_grad)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,7 @@
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
spmd_rule : ElementwiseBinaryInferSpmd
kernel :
func : maximum
backward : maximum_grad
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/spmd_rules/elementwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x,
SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& y,
const DistMetaTensor& out_grad,
int64_t axis);
int64_t axis = -1);

} // namespace distributed
} // namespace phi
43 changes: 43 additions & 0 deletions test/auto_parallel/semi_auto_parallel_for_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,49 @@ def test_relu_x_shard(self):
unary_func=F.relu,
)

def test_maximum_x_shard(self):
    """Elementwise maximum with x sharded on dim 0 and y fully replicated."""
    case = dict(
        x_shape=[16, 32],
        y_shape=[16, 32],
        out_shape=[16, 32],
        x_specs=['x', None],      # x split along the 'x' mesh axis
        y_specs=[None, None],     # y replicated on every rank
        binary_func=paddle.maximum,
    )
    self.test_binary_body(**case)

def test_maximum_x_shard_broadcast(self):
    """Elementwise maximum where replicated y broadcasts over sharded x.

    x is rank-2 and sharded on dim 0; y is rank-3, so x is broadcast up
    to y's shape and the output takes y's rank.
    """
    case = dict(
        x_shape=[16, 32],
        y_shape=[2, 16, 32],
        out_shape=[2, 16, 32],
        x_specs=['x', None],
        y_specs=[None, None, None],
        binary_func=paddle.maximum,
    )
    self.test_binary_body(**case)

def test_maximum_x_y_shard(self):
    """Elementwise maximum with x and y sharded on different dims.

    Skipped on the CPU backend.
    """
    if self._backend == "cpu":
        # This case only runs on non-CPU backends.
        return

    case = dict(
        x_shape=[16, 32],
        y_shape=[16, 32],
        out_shape=[16, 32],
        x_specs=['x', None],      # x sharded on dim 0
        y_specs=[None, 'x'],      # y sharded on dim 1 (same mesh axis)
        binary_func=paddle.maximum,
    )
    self.test_binary_body(**case)

def test_maximum_x_y_shard_broadcast(self):
    """Elementwise maximum where lower-rank replicated y broadcasts over sharded x.

    x is rank-3 and sharded on dim 0; rank-2 y is broadcast up to x's
    shape and the output keeps x's shape.
    """
    case = dict(
        x_shape=[4, 16, 32],
        y_shape=[16, 32],
        out_shape=[4, 16, 32],
        x_specs=['x', None, None],
        y_specs=[None, None],
        binary_func=paddle.maximum,
    )
    self.test_binary_body(**case)

def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
Expand Down

0 comments on commit 0eb5034

Please sign in to comment.