diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc
index fe555cacb3e18..f6596f3db31d5 100644
--- a/paddle/fluid/pybind/auto_parallel_py.cc
+++ b/paddle/fluid/pybind/auto_parallel_py.cc
@@ -45,6 +45,7 @@
 #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h"
 #include "paddle/phi/core/enforce.h"

 #ifdef PADDLE_WITH_DISTRIBUTE
@@ -200,6 +201,10 @@ void BindAutoParallel(py::module *m) {
       *m, "SameNdMeshReshardFunction", ReshardFunction)
       .def(py::init<>());

+  py::class_<phi::distributed::SameStatusReshardFunction>(
+      *m, "SameStatusReshardFunction", ReshardFunction)
+      .def(py::init<>());
+
   py::class_<ProcessMesh>(*m, "ProcessMesh")
       .def(py::init<>())
       .def(py::init<const std::vector<int64_t> &,
diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
index 038888196f09f..92e69e0dc7657 100644
--- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
+++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -17,4 +17,5 @@ collect_srcs(
   r_to_p_reshard_function.cc
   p_to_r_reshard_function.cc
   s_to_s_reshard_function.cc
-  nd_mesh_reshard_function.cc)
+  nd_mesh_reshard_function.cc
+  same_status_reshard_function.cc)
diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
index 6edc0bf188ee5..94d611e8043aa 100644
--- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
+++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -18,6 +18,7 @@
 #include "paddle/phi/backends/context_pool.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
+#include "paddle/phi/core/distributed/store/store_utils.h"

 namespace phi {
 namespace distributed {
@@ -35,20 +36,27 @@ inline void check_defined(const DistTensor& dist_tensor,
 DistTensor::DistTensor(const phi::DenseTensor& global_value,
                        const TensorDistAttr& dist_attr)
     : dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {
-  if (value_.initialized() && !dist_attr.is_replicated()) {
-    // 1. create replicated global tensor
-    int64_t dims_size = global_value.dims().size();
-    std::vector<int64_t> dims_mapping(dims_size, -1);
-    dist_attr_.set_dims_mapping(dims_mapping);
-    if (dist_attr_.is_partial()) {
-      dist_attr_.clean_partial_status();
+  // TODO(liyurui): This is a temporary solution. We need to support an
+  // infer-meta-only path for the case where the input dense tensor is empty,
+  // i.e. the value held by DistTensor may carry only DenseTensor meta
+  // without actual data, so that its meta attributes can be visited even
+  // when the tensor is undefined.
+  if (IsCurRankInMesh(dist_attr.process_mesh())) {
+    if (value_.initialized() && !dist_attr.is_replicated()) {
+      // 1. create replicated global tensor
+      int64_t dims_size = global_value.dims().size();
+      std::vector<int64_t> dims_mapping(dims_size, -1);
+      dist_attr_.set_dims_mapping(dims_mapping);
+      if (dist_attr_.is_partial()) {
+        dist_attr_.clean_partial_status();
+      }
+      dist_attr_.set_dims_mapping(dims_mapping);
+
+      // 2. reshard from replicated to other state
+      auto* func = ChooseProperReshardFunction(*this, dist_attr);
+      auto* dev_ctx = DeviceContextPool::Instance().Get(global_value.place());
+      func->Eval(dev_ctx, *this, dist_attr, this);
     }
-    dist_attr_.set_dims_mapping(dims_mapping);
-
-    // 2. reshard from replicated to other state
-    auto* func = ChooseProperReshardFunction(*this, dist_attr);
-    auto* dev_ctx = DeviceContextPool::Instance().Get(global_value.place());
-    func->Eval(dev_ctx, *this, dist_attr, this);
   }
 }
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
index 2767dfa836394..60c9cbdda3b67 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
@@ -24,13 +24,6 @@ namespace phi {
 namespace distributed {
 namespace {

-int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids) {
-  int64_t cur_global_rank = GetCurGlobalRank();
-  auto iter =
-      std::find(process_ids.begin(), process_ids.end(), cur_global_rank);
-  return iter - process_ids.begin();
-}
-
 std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
   std::string unique_comm_key = "ReshardGroup";
   for (const auto& id : process_ids) {
@@ -40,6 +33,20 @@ std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
 }
 }  // namespace

+int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids,
+                                  int64_t global_rank) {
+  if (global_rank == -1) {
+    global_rank = GetCurGlobalRank();
+  }
+  auto iter = std::find(process_ids.begin(), process_ids.end(), global_rank);
+  PADDLE_ENFORCE_NE(
+      iter,
+      process_ids.end(),
+      phi::errors::NotFound("Global rank %lld cannot be found in process_mesh",
+                            global_rank));
+  return iter - process_ids.begin();
+}
+
 std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) {
   const auto& process_shape = process_mesh.shape();
   const auto& process_ids = process_mesh.process_ids();
@@ -132,5 +139,12 @@ std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces) {
   return result;
 }

+bool IsCurRankInMesh(const ProcessMesh& process_mesh) {
+  int64_t cur_global_rank = GetCurGlobalRank();
+  const auto& process_ids = process_mesh.process_ids();
+  return (std::find(process_ids.begin(), process_ids.end(), cur_global_rank) !=
+          process_ids.end());
+}
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
index 831a4c6e0d2af..652840976194f 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
@@ -30,6 +30,11 @@ class DeviceContext;
 namespace distributed {
 class ProcessMesh;

+bool IsCurRankInMesh(const ProcessMesh& process_mesh);
+
+int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids,
+                                  int64_t global_rank = -1);
+
 // Get the coordinate of cur rank in process mesh. For example, the process mesh
 // is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will
 // return [2, 0]; if the current rank is 3, then will return [1, 1].
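The two helpers above only do index arithmetic over the mesh's process ids. A pure-Python sketch of their semantics follows; the function names mirror the C++ ones, but this is illustrative, not a Paddle API.

# Pure-Python sketch of GetLocalRankInParticipate / IsCurRankInMesh semantics.

def get_local_rank_in_participate(process_ids, global_rank):
    # Map a global rank to its index among the participating ranks,
    # failing loudly when the rank is not part of the group.
    if global_rank not in process_ids:
        raise ValueError(
            f"Global rank {global_rank} cannot be found in process_mesh"
        )
    return process_ids.index(global_rank)

def is_cur_rank_in_mesh(process_ids, cur_global_rank):
    # True when the current rank participates in the mesh.
    return cur_global_rank in process_ids

if __name__ == "__main__":
    union_ids = [0, 1]  # e.g. the union of source and destination mesh ranks
    assert get_local_rank_in_participate(union_ids, 1) == 1
    assert is_cur_rank_in_mesh(union_ids, 0)
    assert not is_cur_rank_in_mesh(union_ids, 5)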
diff --git a/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc
new file mode 100644
index 0000000000000..a6f49268c5612
--- /dev/null
+++ b/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc
@@ -0,0 +1,121 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h"
+
+#include <algorithm>
+
+#include "glog/logging.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
+#include "paddle/phi/core/distributed/store/store_utils.h"
+#include "paddle/phi/kernels/p_recv_kernel.h"
+#include "paddle/phi/kernels/p_send_kernel.h"
+
+namespace phi {
+namespace distributed {
+
+namespace {
+
+std::vector<int64_t> GetUnionProcessIds(std::vector<int64_t> in_process_ids,
+                                        std::vector<int64_t> out_process_ids) {
+  std::vector<int64_t> result;
+  std::sort(in_process_ids.begin(), in_process_ids.end());
+  std::sort(out_process_ids.begin(), out_process_ids.end());
+  std::set_union(in_process_ids.begin(),
+                 in_process_ids.end(),
+                 out_process_ids.begin(),
+                 out_process_ids.end(),
+                 std::back_inserter(result));
+  return result;
+}
+
+}  // namespace
+
+bool SameStatusReshardFunction::IsSuitable(
+    const DistTensor& in, const TensorDistAttr& out_dist_attr) {
+  bool flag = true;
+  const auto& in_dist_attr = in.dist_attr();
+
+  flag &= (in_dist_attr.dims_mapping() == out_dist_attr.dims_mapping());
+  flag &= (in_dist_attr.partial_dims() == out_dist_attr.partial_dims());
+
+  const auto& in_process_mesh = in_dist_attr.process_mesh();
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+  flag &= (in_process_mesh != out_process_mesh);
+  flag &= (in_process_mesh.shape() == out_process_mesh.shape());
+
+  return flag;
+}
+
+void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx,
+                                     const DistTensor& in,
+                                     const TensorDistAttr& out_dist_attr,
+                                     DistTensor* out) {
+  const auto& in_dist_attr = in.dist_attr();
+  const auto& in_process_mesh = in_dist_attr.process_mesh();
+  const auto& in_process_ids = in_process_mesh.process_ids();
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+  const auto& out_process_ids = out_process_mesh.process_ids();
+  auto all_process_ids = GetUnionProcessIds(in_process_ids, out_process_ids);
+  auto dtype = in.dtype();
+  // TODO(liyurui): Using dynamic shape leads to poor performance, but we
+  // don't have a better idea for now, for the following reasons:
+  // 1. We cannot guarantee that the meta deduced by infermeta is right.
+  // 2. The meta of some kernels can't be decided at compile time.
+  // 3. A DenseTensor with empty value only needs infermeta and skips the
+  //    real kernel execution.
+  bool dynamic_shape = true;
+
+  std::vector<std::pair<int64_t, int64_t>> p2p_pair;
+  for (size_t i = 0; i < out_process_ids.size(); ++i) {
+    p2p_pair.emplace_back(
+        std::make_pair(in_process_ids[i], out_process_ids[i]));
+  }
+
+  int64_t cur_global_rank = GetCurGlobalRank();
+  for (const auto& iter : p2p_pair) {
+    int64_t src = iter.first;
+    int64_t dst = iter.second;
+    VLOG(3) << "Send/Recv from src " << src << " to dst " << dst;
+    if (src == cur_global_rank) {
+      int64_t dst_local_rank = GetLocalRankInParticipate(all_process_ids, dst);
+      // Since the send kernel only has an input, we don't actually need
+      // infermeta here, so we just call the kernel directly.
+      RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
+                                PSendKernel,
+                                dtype,
+                                all_process_ids,
+                                in.value(),
+                                dst_local_rank,
+                                dynamic_shape);
+    } else if (dst == cur_global_rank) {
+      int64_t src_local_rank = GetLocalRankInParticipate(all_process_ids, src);
+      RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
+                                PRecv,
+                                dtype,
+                                all_process_ids,
+                                src_local_rank,
+                                dynamic_shape,
+                                GetMutableTensor(out));
+    }
+  }
+  SetDistProps(out, in.dims(), out_dist_attr);
+}
+
+REGISTER_RESHARD_FUNC(SameStatusReshardFunction);
+
+}  // namespace distributed
+}  // namespace phi
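Eval pairs the i-th rank of the source mesh with the i-th rank of the destination mesh, and each rank acts only on the pairs it appears in. A hedged pure-Python sketch of that pairing logic (illustrative names, not the C++ API):

def build_p2p_pairs(in_process_ids, out_process_ids):
    # Mirror the pairing loop in SameStatusReshardFunction::Eval.
    return list(zip(in_process_ids, out_process_ids))

def actions_for(cur_rank, pairs):
    # Decide what the current rank does for each send/recv pair.
    actions = []
    for src, dst in pairs:
        if src == cur_rank:
            actions.append(("send", dst))
        elif dst == cur_rank:
            actions.append(("recv", src))
    return actions

if __name__ == "__main__":
    # The nd-mesh case from the test below: mesh [[0], [1]] -> [[1], [0]].
    pairs = build_p2p_pairs([0, 1], [1, 0])
    assert actions_for(0, pairs) == [("send", 1), ("recv", 1)]
    assert actions_for(1, pairs) == [("recv", 0), ("send", 0)]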
diff --git a/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h
new file mode 100644
index 0000000000000..38c044e083a09
--- /dev/null
+++ b/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
+
+namespace phi {
+namespace distributed {
+
+class SameStatusReshardFunction final : public ReshardFunction {
+ public:
+  bool IsSuitable(const DistTensor& in,
+                  const TensorDistAttr& out_dist_attr) override;
+
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
+};
+
+}  // namespace distributed
+}  // namespace phi
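REGISTER_RESHARD_FUNC above adds the class to the reshard registry, and ChooseProperReshardFunction (used in dist_tensor.cc) consults IsSuitable to pick a function. A sketch of that selection pattern, assuming candidates are tried in registration order and the first suitable one wins — illustrative Python, not the real C++ registry:

_RESHARD_FUNCS = []

def register_reshard_func(func):
    # Analogue of REGISTER_RESHARD_FUNC: append to the candidate list.
    _RESHARD_FUNCS.append(func)

def choose_proper_reshard_function(tensor, out_dist_attr):
    # First candidate whose is_suitable() accepts the pair wins.
    for func in _RESHARD_FUNCS:
        if func.is_suitable(tensor, out_dist_attr):
            return func
    raise RuntimeError("no suitable reshard function for the given attrs")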
diff --git a/paddle/phi/core/distributed/store/store_utils.cc b/paddle/phi/core/distributed/store/store_utils.cc
index c2679ef2192a3..7730b23301af3 100644
--- a/paddle/phi/core/distributed/store/store_utils.cc
+++ b/paddle/phi/core/distributed/store/store_utils.cc
@@ -48,19 +48,17 @@ std::string GetMasterEndpoint() {

 int64_t GetCurGlobalRank() {
   const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
-  PADDLE_ENFORCE_NOT_NULL(
-      cur_rank,
-      phi::errors::NotFound(
-          "The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
+  if (cur_rank == nullptr) {
+    return 0;
+  }
   return std::atoi(cur_rank);
 }

 int64_t GetGlobalWorldSize() {
   const char* world_size = std::getenv("PADDLE_TRAINERS_NUM");
-  PADDLE_ENFORCE_NOT_NULL(
-      world_size,
-      phi::errors::NotFound(
-          "The environment variable 'PADDLE_TRAINERS_NUM' cannot be found."));
+  if (world_size == nullptr) {
+    return 1;
+  }
   return std::atoi(world_size);
 }
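The new fallback makes plain single-process runs work without launcher-exported environment variables. The equivalent behavior in Python terms, for reference:

import os

# Rank defaults to 0 and world size to 1 when the launcher did not export
# the variables, instead of raising NotFound.
def get_cur_global_rank():
    return int(os.getenv("PADDLE_TRAINER_ID", "0"))

def get_global_world_size():
    return int(os.getenv("PADDLE_TRAINERS_NUM", "1"))

if __name__ == "__main__":
    print(get_cur_global_rank(), get_global_world_size())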
#include "paddle/phi/core/tensor_array.h" +#include "paddle/phi/infermeta/nullary.h" namespace phi { @@ -26,6 +27,19 @@ void PRecvKernel(const Context& dev_ctx, bool dynamic_shape, DenseTensor* out); +template +void PRecv(const Context& dev_ctx, + int peer, + bool dynamic_shape, + DenseTensor* out) { + MetaTensor out_meta(*out); + MetaTensor* out_meta_ptr = &out_meta; + DataType dtype = phi::CppTypeToDataType::Type(); + + PRecvInferMeta(peer, dtype, out_meta_ptr); + PRecvKernel(dev_ctx, peer, dtype, dynamic_shape, out); +} + template void PRecvArrayKernel(const Context& dev_ctx, int peer, diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 48df714387854..8efa6f6a5e400 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -106,6 +106,11 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_reshard_nd_mesh MODULES test_reshard_nd_mesh) set_tests_properties(test_reshard_nd_mesh PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) + + py_test_modules(test_reshard_same_status MODULES test_reshard_same_status) + set_tests_properties(test_reshard_same_status + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) + py_test_modules(test_semi_auto_parallel_basic MODULES test_semi_auto_parallel_basic) set_tests_properties(test_semi_auto_parallel_basic diff --git a/test/auto_parallel/reshard_same_status.py b/test/auto_parallel/reshard_same_status.py new file mode 100644 index 0000000000000..f6c7c6eaff166 --- /dev/null +++ b/test/auto_parallel/reshard_same_status.py @@ -0,0 +1,174 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/test/auto_parallel/reshard_same_status.py b/test/auto_parallel/reshard_same_status.py
new file mode 100644
index 0000000000000..f6c7c6eaff166
--- /dev/null
+++ b/test/auto_parallel/reshard_same_status.py
@@ -0,0 +1,174 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle.framework import core
+
+
+def get_coord(mesh_list, rank):
+    x = 0
+    y = 0
+    for sub_list in mesh_list:
+        if rank in sub_list:
+            y = sub_list.index(rank)
+            return x, y
+        x += 1
+    return -1, -1
+
+
+class TestReshardSameStatus:
+    def __init__(self):
+        self._shape = eval(os.getenv("shape"))
+        self._dtype = os.getenv("dtype")
+        self._seeds = eval(os.getenv("seeds"))
+        self._backend = os.getenv("backend")
+
+    def test_diff_1d_mesh_shard(self, dev_ctx):
+        paddle.seed(self._seeds)
+
+        in_mesh_list = [0]
+        out_mesh_list = [1]
+        in_mesh = dist.ProcessMesh(in_mesh_list, dim_names=["x"])
+        value = paddle.uniform(self._shape, self._dtype)
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        in_shard_specs[0] = "x"
+        dist_attr = dist.DistAttr(mesh=in_mesh, sharding_specs=in_shard_specs)
+
+        in_expected_local_tensor_list = paddle.split(
+            value, num_or_sections=in_mesh.shape[0], axis=0
+        )
+        if dist.get_rank() in in_mesh_list:
+            index = in_mesh_list.index(dist.get_rank()) % in_mesh.shape[0]
+        elif dist.get_rank() in out_mesh_list:
+            index = out_mesh_list.index(dist.get_rank()) % in_mesh.shape[0]
+
+        input_tensor = dist.shard_tensor(value, dist_attr=dist_attr)
+
+        if dist.get_rank() in in_mesh_list:
+            # check the value of input tensor
+            in_expected_local_tensor_list = paddle.split(
+                value, num_or_sections=in_mesh.shape[0], axis=0
+            )
+            np.testing.assert_equal(
+                input_tensor._local_value().numpy(),
+                in_expected_local_tensor_list[index].numpy(),
+            )
+
+        out_mesh = dist.ProcessMesh(out_mesh_list, dim_names=["x"])
+        out_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs[0] = "x"
+        out_dist_attr = dist.DistAttr(
+            mesh=out_mesh, sharding_specs=out_shard_specs
+        )
+
+        reshard_func = core.SameStatusReshardFunction()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+
+        if dist.get_rank() in out_mesh_list:
+            np.testing.assert_equal(
+                out._local_value().numpy(),
+                in_expected_local_tensor_list[index].numpy(),
+            )
+
+    def test_diff_nd_mesh_shard_partial(self, dev_ctx):
+        paddle.seed(self._seeds)
+
+        in_mesh_list = [[0], [1]]
+        out_mesh_list = [[1], [0]]
+        in_mesh = dist.ProcessMesh(in_mesh_list, dim_names=["x", "y"])
+        value = paddle.uniform(self._shape, self._dtype)
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        in_shard_specs[0] = "x"
+        dist_attr = dist.DistAttr(mesh=in_mesh, sharding_specs=in_shard_specs)
+        dist_attr._set_partial_dims([1])
+
+        input_tensor = dist.shard_tensor(value, dist_attr=dist_attr)
+
+        in_expected_local_tensor_list = paddle.split(
+            value, num_or_sections=in_mesh.shape[0], axis=0
+        )
+
+        in_flatten_list = [
+            item for sub_list in in_mesh_list for item in sub_list
+        ]
+        out_flatten_list = [
+            item for sub_list in out_mesh_list for item in sub_list
+        ]
+
+        in_x, in_y = get_coord(in_mesh_list, dist.get_rank())
+        out_x, out_y = get_coord(out_mesh_list, dist.get_rank())
+
+        if dist.get_rank() in in_flatten_list:
+            if in_y == 0:
+                np.testing.assert_equal(
+                    input_tensor._local_value().numpy(),
+                    in_expected_local_tensor_list[in_x].numpy(),
+                )
+            else:
+                zeros = paddle.zeros(input_tensor._local_shape)
+                np.testing.assert_equal(
+                    input_tensor._local_value().numpy(),
+                    zeros.numpy(),
+                )
+
+        out_mesh = dist.ProcessMesh(out_mesh_list, dim_names=["x", "y"])
+        out_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs[0] = "x"
+        out_dist_attr = dist.DistAttr(
+            mesh=out_mesh, sharding_specs=out_shard_specs
+        )
+        out_dist_attr._set_partial_dims([1])
+
+        reshard_func = core.SameStatusReshardFunction()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+
+        if dist.get_rank() in out_flatten_list:
+            if out_y == 0:
+                np.testing.assert_equal(
+                    out._local_value().numpy(),
+                    in_expected_local_tensor_list[out_x].numpy(),
+                )
+            else:
+                zeros = paddle.zeros(out._local_shape)
+                np.testing.assert_equal(
+                    out._local_value().numpy(),
+                    zeros.numpy(),
+                )
+
+    def run_test_case(self):
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+            place = paddle.CPUPlace()
+        elif self._backend == "gpu":
+            place = paddle.CUDAPlace(dist.get_rank())
+
+        dev_ctx = core.DeviceContext.create(place)
+
+        self.test_diff_1d_mesh_shard(dev_ctx)
+        self.test_diff_nd_mesh_shard_partial(dev_ctx)
+
+
+if __name__ == '__main__':
+    TestReshardSameStatus().run_test_case()
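For reference, get_coord above returns the (row, column) coordinate of a rank inside the nested mesh list, or (-1, -1) when the rank does not participate. A compact copy with a couple of worked cases:

# Compact copy of the get_coord helper from reshard_same_status.py,
# shown standalone with worked examples.
def get_coord(mesh_list, rank):
    x = 0
    for sub_list in mesh_list:
        if rank in sub_list:
            return x, sub_list.index(rank)
        x += 1
    return -1, -1

assert get_coord([[0], [1]], 1) == (1, 0)
assert get_coord([[0, 1], [2, 3]], 2) == (1, 0)
assert get_coord([[0], [1]], 5) == (-1, -1)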
diff --git a/test/auto_parallel/test_reshard_same_status.py b/test/auto_parallel/test_reshard_same_status.py
new file mode 100644
index 0000000000000..795c5b0e67520
--- /dev/null
+++ b/test/auto_parallel/test_reshard_same_status.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import collective.test_communication_api_base as test_base
+
+
+class TestReshardSameStatus(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=120)
+        self._default_envs = {
+            "shape": "(6, 10, 20, 12)",
+            "dtype": "float32",
+            "seeds": "100",
+        }
+        self._changeable_envs = {
+            "backend": ["gpu"],
+        }
+
+    def test_reshard_same_status(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "reshard_same_status.py",
+                user_defined_envs=envs,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/cpp/auto_parallel/dist_tensor_test.cc b/test/cpp/auto_parallel/dist_tensor_test.cc
index c190c0e7b17ca..9882a4b831bb5 100644
--- a/test/cpp/auto_parallel/dist_tensor_test.cc
+++ b/test/cpp/auto_parallel/dist_tensor_test.cc
@@ -36,6 +36,12 @@ TEST(dist_tensor, constructor) {

   auto dist_attr = TensorDistAttr(phi::vectorize(dims));

+  std::vector<int64_t> mesh_shape = {1};
+  std::vector<int64_t> process_ids = {0};
+  std::vector<std::string> dim_names = {"x"};
+  ProcessMesh mesh(mesh_shape, process_ids, dim_names);
+  dist_attr.set_process_mesh(mesh);
+
   // copy construct
   DenseTensor x1(alloc, meta);
   DistTensor dist_x1(x1, dist_attr);
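Putting the pieces together, a minimal usage sketch distilled from the tests above. It assumes a two-rank GPU launch (the collective test harness arranges this); the shape and meshes are illustrative, not part of the patch.

# Minimal sketch of exercising the new binding; run under a 2-rank launch.
import paddle
import paddle.distributed as dist
from paddle.framework import core

dev_ctx = core.DeviceContext.create(paddle.CUDAPlace(dist.get_rank()))

value = paddle.uniform([6, 12], "float32")
src = dist.DistAttr(mesh=dist.ProcessMesh([0], dim_names=["x"]),
                    sharding_specs=["x", None])
dst = dist.DistAttr(mesh=dist.ProcessMesh([1], dim_names=["x"]),
                    sharding_specs=["x", None])

t = dist.shard_tensor(value, dist_attr=src)

func = core.SameStatusReshardFunction()
assert func.is_suitable(t, dst)   # same placements, different 1-D mesh
out = func.eval(dev_ctx, t, dst)  # rank 0 sends, rank 1 receives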