diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 7be497318443a7..5e39b764fa96d7 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1817,6 +1817,7 @@
   infer_meta :
     func : UnchangedInferMeta
     param : [out]
+    spmd_rule : ElementwiseUnaryGradInferSpmd
   kernel :
     func : relu_grad
   backward: relu_double_grad
@@ -2234,6 +2235,7 @@
   infer_meta :
     func : UnchangedInferMeta
     param : [x]
+    spmd_rule : ElementwiseUnaryGradInferSpmd
   kernel :
     func : square_grad
   backward : square_double_grad
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index a6ab3c3ec954f0..57703dbf6659bf 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -2073,6 +2073,7 @@
   output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
+    spmd_rule : ElementwiseUnaryInferSpmd
  kernel :
     func : relu
   inplace : (x -> out)
@@ -2458,6 +2459,7 @@
   output : Tensor
   infer_meta :
     func : UnchangedInferMeta
+    spmd_rule : ElementwiseUnaryInferSpmd
   kernel :
     func : square {dense -> dense},
            square_sr {selected_rows -> selected_rows}
diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.cc b/paddle/phi/infermeta/spmd_rules/elementwise.cc
index 24d6bed03c52d0..3a9e422320210f 100644
--- a/paddle/phi/infermeta/spmd_rules/elementwise.cc
+++ b/paddle/phi/infermeta/spmd_rules/elementwise.cc
@@ -309,6 +309,11 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x,
   return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr}};
 }
 
+SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x,
+                                       const DistMetaTensor& out_grad) {
+  return {{out_grad.dist_attr(), out_grad.dist_attr()}, {out_grad.dist_attr()}};
+}
+
 SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x,
                                         const DistMetaTensor& y,
                                         const DistMetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.h b/paddle/phi/infermeta/spmd_rules/elementwise.h
index 736aeec35ed0a0..188e557e6737b0 100644
--- a/paddle/phi/infermeta/spmd_rules/elementwise.h
+++ b/paddle/phi/infermeta/spmd_rules/elementwise.h
@@ -27,6 +27,9 @@ SpmdInfo ElementwiseUnaryInferSpmd(const DistMetaTensor& x);
 SpmdInfo ElementwiseUnaryInferSpmdReverse(const DistMetaTensor& x,
                                           const DistMetaTensor& out);
 
+SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x,
+                                       const DistMetaTensor& out_grad);
+
 SpmdInfo ElementwiseBinaryInferSpmd(const DistMetaTensor& x,
                                     const DistMetaTensor& y);
diff --git a/test/auto_parallel/semi_auto_parallel_for_elementwise.py b/test/auto_parallel/semi_auto_parallel_for_elementwise.py
index b7e3e30b89e562..24bf0c8be9e88b 100644
--- a/test/auto_parallel/semi_auto_parallel_for_elementwise.py
+++ b/test/auto_parallel/semi_auto_parallel_for_elementwise.py
@@ -18,6 +18,7 @@
 
 import paddle
 import paddle.distributed as dist
+import paddle.nn.functional as F
 
 
 class TestElementwiseApiForSemiAutoParallel:
@@ -27,17 +28,34 @@ def __init__(self):
         self._seed = eval(os.getenv("seed"))
         self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
 
+        paddle.seed(self._seed)
+        np.random.seed(self._seed)
+
     def check_tensor_eq(self, a, b):
         np1 = a.numpy()
         np2 = b.numpy()
         np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True)
 
+    def test_unary_body(self, x_shape, out_shape, x_specs, unary_func):
+        x = paddle.randn(x_shape, self._dtype)
+        x.stop_gradient = False
+
+        x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs)
+
+        dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr)
+        dist_x.stop_gradient = False
+
+        dist_out = unary_func(dist_x)
+        out = unary_func(x)
+        self.check_tensor_eq(out, dist_out)
+
+        dist_out.backward()
+        out.backward()
+        self.check_tensor_eq(x.grad, dist_x.grad)
+
     def test_binary_body(
         self, x_shape, y_shape, out_shape, x_specs, y_specs, binary_func
     ):
-        paddle.seed(self._seed)
-        np.random.seed(self._seed)
-
         x = paddle.randn(x_shape, self._dtype)
         y = paddle.randn(y_shape, self._dtype)
         x.stop_gradient = False
@@ -129,6 +147,22 @@ def test_sub_x_y_shard_broadcast(self):
             binary_func=paddle.subtract,
         )
 
+    def test_square_x_shard(self):
+        self.test_unary_body(
+            x_shape=[4, 16],
+            out_shape=[4, 16],
+            x_specs=['x', None],
+            unary_func=paddle.square,
+        )
+
+    def test_relu_x_shard(self):
+        self.test_unary_body(
+            x_shape=[4, 16],
+            out_shape=[4, 16],
+            x_specs=['x', None],
+            unary_func=F.relu,
+        )
+
     def run_test_case(self):
         if self._backend == "cpu":
             paddle.set_device("cpu")
@@ -141,6 +175,10 @@ def run_test_case(self):
         self.test_add_x_shard_broadcast()
         self.test_add_x_y_shard()
         self.test_add_x_y_shard_broadcast()
+        self.test_sub_x_shard()
+        self.test_sub_x_y_shard_broadcast()
+        self.test_square_x_shard()
+        self.test_relu_x_shard()
 
 
 if __name__ == '__main__':
diff --git a/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py b/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py
index b83d9ffb87e7e8..68c44f206aa00a 100644
--- a/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py
+++ b/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py
@@ -18,7 +18,6 @@
 
 import paddle
 import paddle.distributed as dist
-import paddle.nn.functional as F
 
 
 class TestReplicatedSPmdApiForSemiAutoParallel:
@@ -49,29 +48,6 @@ def create_local_and_dist_tensor_pair(self, np_array, sharding_specs):
 
         return local_t, dist_t
 
-    # input: phi::Tensor
-    # output: phi::Tensor
-    def test_relu(self):
-        x = np.random.random(size=[4, 4]).astype(self._dtype)
-        local_in, dist_in = self.create_local_and_dist_tensor_pair(
-            x, ['x', None]
-        )
-        local_out = F.relu(local_in)
-        dist_out = F.relu(dist_in)
-        np.testing.assert_equal(
-            dist_out.dist_attr.dims_mapping, [-1, -1], verbose=True
-        )
-        self.check_tensor_eq(local_out, dist_out)
-
-        # test backward
-        local_out.backward()
-        dist_out.backward()
-        np.testing.assert_equal(dist_in.grad._local_shape, [2, 4], verbose=True)
-        np.testing.assert_equal(
-            dist_in.grad.dist_attr.dims_mapping, [0, -1], verbose=True
-        )
-        self.check_tensor_eq(local_in.grad, dist_in.grad)
-
     def test_mse_loss(self):
         x = np.random.random(size=[4, 4]).astype(self._dtype)
         y = np.random.random(size=[4]).astype(self._dtype)
@@ -104,7 +80,6 @@ def run_test_case(self):
         else:
             raise ValueError("Only support cpu or gpu backend.")
 
-        self.test_relu()
         self.test_mse_loss()
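
Note: below is a minimal usage sketch of the behavior this patch enables, not part of the diff. It assumes a two-rank launch (e.g. via python -m paddle.distributed.launch --devices=0,1 demo.py); the shapes, dtype, mesh name, and script name are illustrative assumptions.

    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    x = paddle.randn([4, 16], "float32")
    x.stop_gradient = False

    # Shard dim 0 of x along the "x" mesh axis, as in the new tests.
    dist_x = dist.shard_tensor(
        x, dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=["x", None])
    )
    dist_x.stop_gradient = False

    # With ElementwiseUnaryInferSpmd registered for square/relu, the output
    # keeps the input's sharding instead of falling back to the replicated rule.
    out = paddle.square(dist_x)

    # ElementwiseUnaryGradInferSpmd assigns out_grad's dist_attr to the rule's
    # inputs and to x_grad, so the gradient keeps the same placement as out_grad.
    out.backward()
    print(dist_x.grad.dist_attr.dims_mapping)  # expected: [0, -1]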