diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc
index ac02cd0fc87ac..3276a6afd2d00 100644
--- a/paddle/fluid/pybind/auto_parallel_py.cc
+++ b/paddle/fluid/pybind/auto_parallel_py.cc
@@ -183,6 +183,10 @@ void BindAutoParallel(py::module *m) {
       *m, "RToSReshardFunction", ReshardFunction)
       .def(py::init<>());
 
+  py::class_<phi::distributed::RToSReshardFunctionCrossMesh>(
+      *m, "RToSReshardFunctionCrossMesh", ReshardFunction)
+      .def(py::init<>());
+
   py::class_<phi::distributed::SToRReshardFunction>(
       *m, "SToRReshardFunction", ReshardFunction)
       .def(py::init<>());
diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
index 3adf488efca4e..d342dc53a86c7 100644
--- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
@@ -18,6 +18,7 @@
 #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
+#include "paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h"
 #include "paddle/phi/kernels/split_kernel.h"
 
 namespace phi {
@@ -85,7 +86,57 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
   SetDistProps(out, in.dims(), out_dist_attr);
 }
 
+bool RToSReshardFunctionCrossMesh::IsSuitable(
+    const DistTensor& in, const TensorDistAttr& out_dist_attr) {
+  const auto& in_dist_attr = in.dist_attr();
+
+  RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_replicated());
+  RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_shard());
+
+  const auto& in_process_mesh = in_dist_attr.process_mesh();
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1);
+  RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1);
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.shape() ==
+                            out_process_mesh.shape());
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh != out_process_mesh);
+
+  return true;
+}
+
+void RToSReshardFunctionCrossMesh::Eval(phi::DeviceContext* dev_ctx,
+                                        const DistTensor& in,
+                                        const TensorDistAttr& out_dist_attr,
+                                        DistTensor* out) {
+  VLOG(3) << "Call RToSReshardFunctionCrossMesh Eval";
+  const auto& in_dist_attr = in.dist_attr();
+
+  DistTensor tmp_result;
+  TensorDistAttr in_dist_attr_shard = in_dist_attr;
+  in_dist_attr_shard.set_dims_mapping(out_dist_attr.dims_mapping());
+  RToSReshardFunction r_to_s_func;
+  PADDLE_ENFORCE(
+      r_to_s_func.IsSuitable(in, in_dist_attr_shard),
+      phi::errors::InvalidArgument(
+          "Invoke the r to s reshard function is not valid from %s to %s.",
+          tmp_result.dist_attr(),
+          out_dist_attr));
+  r_to_s_func.Eval(dev_ctx, in, in_dist_attr_shard, &tmp_result);
+
+  // Step 2: Same status from the input mesh to output mesh
+  SameStatusReshardFunction same_status_func;
+  PADDLE_ENFORCE(
+      same_status_func.IsSuitable(tmp_result, out_dist_attr),
+      phi::errors::InvalidArgument("Invoke the same status reshard function "
+                                   "is not valid from %s to %s.",
+                                   tmp_result.dist_attr(),
+                                   out_dist_attr));
+  same_status_func.Eval(dev_ctx, tmp_result, out_dist_attr, out);
+}
+
 REGISTER_RESHARD_FUNC(RToSReshardFunction);
+REGISTER_RESHARD_FUNC(RToSReshardFunctionCrossMesh);
 
 }  // namespace distributed
 }  // namespace phi
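Reviewer sketch (not part of the patch): the new RToSReshardFunctionCrossMesh::Eval above is simply a composition of two existing passes: a replicated-to-shard reshard on the source mesh, followed by a same-status transfer onto the destination mesh. The same decomposition can be driven by hand from Python through the pybind wrappers, as a minimal sketch under two GPU ranks launched like the new test below; the SameStatusReshardFunction binding is assumed to exist alongside the others and is not shown in this patch.

import paddle
import paddle.distributed as dist
from paddle.base import core

in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"])

# Replicated input on the source mesh.
value = paddle.uniform([10, 20], "float32")
replicated = dist.shard_tensor(
    value, dist_attr=dist.DistAttr(mesh=in_mesh, sharding_specs=[None, None])
)
dev_ctx = core.DeviceContext.create(paddle.CUDAPlace(dist.get_rank()))

# Step 1: replicated -> shard along "x", still on the source mesh.
shard_on_in_mesh = dist.DistAttr(mesh=in_mesh, sharding_specs=["x", None])
r_to_s = core.RToSReshardFunction()
tmp = r_to_s.eval(dev_ctx, replicated, shard_on_in_mesh)

# Step 2: keep the sharding, move each local shard to the destination mesh.
shard_on_out_mesh = dist.DistAttr(mesh=out_mesh, sharding_specs=["x", None])
same_status = core.SameStatusReshardFunction()  # assumed binding, see note above
out = same_status.eval(dev_ctx, tmp, shard_on_out_mesh)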
diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h
index 3adf488efca4e..4149426aacf66 100644
--- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h
+++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h
@@ -21,9 +21,17 @@ namespace distributed {
 class RToSReshardFunction final : public ReshardFunction {
  public:
-  RToSReshardFunction() = default;
-  ~RToSReshardFunction() = default;
+  bool IsSuitable(const DistTensor& in,
+                  const TensorDistAttr& out_dist_attr) override;
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
+};
+
+class RToSReshardFunctionCrossMesh final : public ReshardFunction {
+ public:
   bool IsSuitable(const DistTensor& in,
                   const TensorDistAttr& out_dist_attr) override;
diff --git a/test/auto_parallel/reshard_r_to_s_cross_mesh.py b/test/auto_parallel/reshard_r_to_s_cross_mesh.py
new file mode 100644
index 0000000000000..68db1bcd7ef0c
--- /dev/null
+++ b/test/auto_parallel/reshard_r_to_s_cross_mesh.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle.base import core
+
+
+class TestReshardRToSCrossMesh:
+    def __init__(self):
+        self._shape = eval(os.getenv("shape"))
+        self._dtype = os.getenv("dtype")
+        self._seeds = eval(os.getenv("seeds"))
+        self._shard = eval(os.getenv("shard"))
+        self._backend = os.getenv("backend")
+        self._in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+        self._out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"])
+
+    def run_test_case(self):
+        # cpu does not support send/recv
+        if self._backend == "cpu":
+            return
+        elif self._backend == "gpu":
+            place = paddle.CUDAPlace(dist.get_rank())
+
+        dev_ctx = core.DeviceContext.create(place)
+
+        paddle.seed(self._seeds)
+        value = paddle.uniform(self._shape, self._dtype)
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs[self._shard] = "x"
+
+        dist_attr = dist.DistAttr(
+            mesh=self._in_mesh, sharding_specs=in_shard_specs
+        )
+        out_dist_attr = dist.DistAttr(
+            mesh=self._out_mesh, sharding_specs=out_shard_specs
+        )
+
+        input_tensor = dist.shard_tensor(value, dist_attr=dist_attr)
+
+        reshard_func = core.RToSReshardFunctionCrossMesh()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+        out_shape = list(self._shape)
+
+        if out_shape[self._shard] % 2 == 0:
+            out_shape[self._shard] = out_shape[self._shard] // 2
+            split_shape = self._in_mesh.shape[0]
+        else:
+            split_shape = [
+                out_shape[self._shard] // 2 + 1,
+                out_shape[self._shard] // 2,
+            ]
+            out_shape[self._shard] = (
+                split_shape[0] if dist.get_rank() == 1 else split_shape[1]
+            )
+
+        out_expected_local_tensor_list = paddle.split(
+            value, num_or_sections=split_shape, axis=self._shard
+        )
+
+        np.testing.assert_equal(
+            out._local_value().numpy(),
+            out_expected_local_tensor_list[0].numpy()
+            if dist.get_rank() == 1
+            else out_expected_local_tensor_list[1].numpy(),
+        )
+
+        assert np.equal(out.shape, input_tensor.shape).all()
+        assert np.equal(out._local_shape, out_shape).all()
+
+
+if __name__ == '__main__':
+    TestReshardRToSCrossMesh().run_test_case()
diff --git a/test/auto_parallel/test_reshard_r_to_s.py b/test/auto_parallel/test_reshard_r_to_s.py
index 68699885094de..b951508f8c1c9 100644
--- a/test/auto_parallel/test_reshard_r_to_s.py
+++ b/test/auto_parallel/test_reshard_r_to_s.py
@@ -22,7 +22,7 @@ def setUp(self):
         super().setUp(num_of_devices=2, timeout=120)
         self._default_envs = {
             "dtype": "float32",
-            "seeds": str(self._seeds),
+            "seeds": "2023",
         }
         self._changeable_envs = {
             "shape": ["(10, 20)", "(5, 7)"],
@@ -40,6 +40,17 @@ def test_reshard_r_to_s(self):
                 user_defined_envs=envs,
             )
 
+    def test_reshard_r_to_s_cross_mesh(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            if envs["backend"] != "cpu":
+                self.run_test_case(
+                    "reshard_r_to_s_cross_mesh.py",
+                    user_defined_envs=envs,
+                )
+
 
 if __name__ == "__main__":
     unittest.main()
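For reference, the local shapes asserted in reshard_r_to_s_cross_mesh.py follow the sections the test hands to paddle.split: the sharded dimension is divided over the two ranks of the mesh, and when it is odd the larger piece goes to the first process of the output mesh (rank 1 here, since the output mesh is [1, 0]). A small stand-alone sketch of that arithmetic; the helper name is made up for illustration and is not part of the patch.

def expected_split_sizes(dim_size, num_ranks=2):
    # Same sizes the test passes to paddle.split: earlier ranks receive the remainder.
    base, rem = divmod(dim_size, num_ranks)
    return [base + 1 if i < rem else base for i in range(num_ranks)]

print(expected_split_sizes(10))  # [5, 5], e.g. shape (10, 20) sharded on dim 0
print(expected_split_sizes(7))   # [4, 3], e.g. shape (5, 7) sharded on dim 1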