diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc
index 6401712457b348..d51ae660fd2c74 100644
--- a/paddle/fluid/pybind/auto_parallel_py.cc
+++ b/paddle/fluid/pybind/auto_parallel_py.cc
@@ -47,6 +47,7 @@
 #include "paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h"
@@ -215,6 +216,10 @@ void BindAutoParallel(py::module *m) {
       *m, "SToSReshardFunction", ReshardFunction)
       .def(py::init<>());
 
+  py::class_<phi::distributed::SToPReshardFunction>(
+      *m, "SToPReshardFunction", ReshardFunction)
+      .def(py::init<>());
+
   py::class_<phi::distributed::PToSReshardFunction>(
       *m, "PToSReshardFunction", ReshardFunction)
       .def(py::init<>());
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/reshard/CMakeLists.txt
index 10f9887e4d2a46..0a5753fe840472 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/CMakeLists.txt
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/CMakeLists.txt
@@ -9,5 +9,6 @@ collect_srcs(
   p_to_r_reshard_function.cc
   s_to_s_reshard_function.cc
   p_to_s_reshard_function.cc
+  s_to_p_reshard_function.cc
   nd_mesh_reshard_function.cc
   same_status_reshard_function.cc)
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.cc
new file mode 100644
index 00000000000000..fc51482ee8660e
--- /dev/null
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.cc
@@ -0,0 +1,69 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.h"
+
+#include "glog/logging.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
+#include "paddle/phi/kernels/reduce_scatter_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
+
+namespace phi {
+namespace distributed {
+
+bool SToPReshardFunction::IsSuitable(const DistTensor& in,
+                                     const TensorDistAttr& out_dist_attr) {
+  const auto& in_dist_attr = in.dist_attr();
+
+  RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_shard());
+  RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_partial());
+
+  const auto& in_process_mesh = in_dist_attr.process_mesh();
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1);
+  RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1);
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh);
+
+  return true;
+}
+
+void SToPReshardFunction::Eval(DeviceContext* dev_ctx,
+                               const DistTensor& in,
+                               const TensorDistAttr& out_dist_attr,
+                               DistTensor* out) {
+  VLOG(3) << "Call SToPReshardFunction Eval";
+
+  // step 1, create tmp dist attr and tmp dist tensor
+  TensorDistAttr tmp_attr(out_dist_attr);
+  DistTensor tmp_tensor;
+  tmp_attr.clean_partial_status();
+
+  // step 2, do s to r reshard on `in` to `tmp`
+  SToRReshardFunction s_to_r;
+  s_to_r.Eval(dev_ctx, in, tmp_attr, &tmp_tensor);
+
+  // step 3, do r to p reshard on `tmp` to `out`
+  RToPReshardFunction r_to_p;
+  r_to_p.Eval(dev_ctx, tmp_tensor, out_dist_attr, out);
+}
+
+REGISTER_RESHARD_FUNC(SToPReshardFunction);
+
+}  // namespace distributed
+}  // namespace phi
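As a side note on the composition implemented in `s_to_p_reshard_function.cc` above (S->R via all-gather, then R->P), here is a minimal NumPy sketch of the per-rank local buffers on a 1-D mesh of two ranks. It assumes the sum-partial convention in which one rank keeps the value and the others are zeroed, which matches what `reshard_s_to_p.py` below asserts per rank; it is illustrative only and does not use the Paddle runtime.

```python
# Illustrative sketch: simulate the S->P composition on a 2-rank mesh with
# plain NumPy. The real code runs SToRReshardFunction followed by
# RToPReshardFunction on DistTensor / DeviceContext objects.
import numpy as np

full = np.ones((10, 20), dtype=np.float32)   # logical (global) tensor
shards = np.split(full, 2, axis=0)           # S: each rank owns one row shard

# s -> r: all-gather, so every rank holds the full tensor (replicated).
replicated = [np.concatenate(shards, axis=0) for _ in range(2)]

# r -> p: under a sum-partial placement, rank 0 keeps the value and the other
# rank is zeroed, so summing the local values reproduces the global tensor.
partial = [replicated[0], np.zeros_like(replicated[1])]
assert np.array_equal(partial[0] + partial[1], full)
```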
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.h
new file mode 100644
index 00000000000000..7a72bcb6716e70
--- /dev/null
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_p_reshard_function.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h"
+
+namespace phi {
+namespace distributed {
+
+class SToPReshardFunction final : public ReshardFunction {
+ public:
+  bool IsSuitable(const DistTensor& in,
+                  const TensorDistAttr& out_dist_attr) override;
+
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
+
+  std::string Name() override { return "SToPReshard"; }
+};
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt
index 83df9ba7b622bf..374455dd5b3a5d 100644
--- a/test/auto_parallel/CMakeLists.txt
+++ b/test/auto_parallel/CMakeLists.txt
@@ -96,6 +96,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
   set_tests_properties(test_reshard_s_to_r
                        PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
+  py_test_modules(test_reshard_s_to_p MODULES test_reshard_s_to_p)
+  set_tests_properties(test_reshard_s_to_p
+                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
   py_test_modules(test_reshard_p_to_s MODULES test_reshard_p_to_s)
   set_tests_properties(test_reshard_p_to_s
                        PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
diff --git a/test/auto_parallel/reshard_s_to_p.py b/test/auto_parallel/reshard_s_to_p.py
new file mode 100644
index 00000000000000..4868d663b8fffb
--- /dev/null
+++ b/test/auto_parallel/reshard_s_to_p.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle.base import core
+
+
+class TestReshardSToP:
+    def __init__(self):
+        self._shape = eval(os.getenv("shape"))
+        self._dtype = os.getenv("dtype")
+        self._seeds = eval(os.getenv("seeds"))
+        self._backend = os.getenv("backend")
+        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+
+    def run_test_case(self):
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+            place = paddle.CPUPlace()
+        elif self._backend == "gpu":
+            place = paddle.CUDAPlace(dist.get_rank())
+
+        dev_ctx = core.DeviceContext.create(place)
+        a = paddle.ones(self._shape)
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        in_shard_specs[0] = "x"
+        out_shard_specs = [None for i in range(len(self._shape))]
+
+        dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=in_shard_specs
+        )
+        out_dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=out_shard_specs
+        )
+        out_dist_attr._set_partial_dims([0])
+
+        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+
+        assert input_tensor._local_shape[0] == self._shape[0] // 2
+
+        reshard_func = core.SToPReshardFunction()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+
+        if dist.get_rank() == 0:
+            np.testing.assert_equal(out._local_value().numpy(), a.numpy())
+        else:
+            zeros = paddle.zeros(self._shape)
+            np.testing.assert_equal(out._local_value().numpy(), zeros.numpy())
+
+        assert np.equal(out.shape, input_tensor.shape).all()
+        assert np.equal(out._local_shape, input_tensor.shape).all()
+
+
+if __name__ == '__main__':
+    TestReshardSToP().run_test_case()
diff --git a/test/auto_parallel/semi_auto_parallel_shard_optimizer.py b/test/auto_parallel/semi_auto_parallel_shard_optimizer.py
index 090a244c886d1e..97073db86bcc74 100644
--- a/test/auto_parallel/semi_auto_parallel_shard_optimizer.py
+++ b/test/auto_parallel/semi_auto_parallel_shard_optimizer.py
@@ -171,8 +171,7 @@ def run_test_case(self):
         if self._backend == "gpu":
             self.test_adamw_mp()
             self.test_adamw_shard_optimizer(stage1=True)
-            # A problem has to be addressed if not shard batch.
-            # self.test_adamw_shard_optimizer(stage1=False)
+            self.test_adamw_shard_optimizer(stage1=False)
 
 
 if __name__ == '__main__':
diff --git a/test/auto_parallel/test_reshard_s_to_p.py b/test/auto_parallel/test_reshard_s_to_p.py
new file mode 100644
index 00000000000000..ad2e6228f8e729
--- /dev/null
+++ b/test/auto_parallel/test_reshard_s_to_p.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import collective.test_communication_api_base as test_base
+
+
+class TestReshardSToP(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=120)
+        self._default_envs = {
+            "shape": "(10, 20)",
+            "dtype": "float32",
+            "seeds": "1234",
+        }
+        self._changeable_envs = {
+            "backend": ["cpu", "gpu"],
+        }
+
+    def test_reshard_s_to_p(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "reshard_s_to_p.py",
+                user_defined_envs=envs,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
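For local debugging, a hedged sketch of how the `reshard_s_to_p.py` case could be driven by hand instead of through `CommunicationTestDistBase`. The environment variable names mirror the ones the test file reads; using `paddle.distributed.launch` with `--devices 0,1` is an assumption about the local two-GPU setup rather than part of this change.

```python
# Hypothetical manual launch of reshard_s_to_p.py on two GPUs; CI normally
# drives this through CommunicationTestDistBase / py_test_modules instead.
import os
import subprocess

env = dict(os.environ)
env.update(
    {"shape": "(10, 20)", "dtype": "float32", "seeds": "1234", "backend": "gpu"}
)
subprocess.check_call(
    [
        "python",
        "-m",
        "paddle.distributed.launch",
        "--devices",
        "0,1",
        "reshard_s_to_p.py",
    ],
    env=env,
)
```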