From 7f5769d46d2fc1c18d779affa9ffe824d033ed38 Mon Sep 17 00:00:00 2001 From: LiYuRio Date: Mon, 17 Jul 2023 20:21:27 +0800 Subject: [PATCH 1/4] Implement replicated to split reshard function --- paddle/fluid/pybind/auto_parallel_py.cc | 4 + .../distributed/auto_parallel/CMakeLists.txt | 6 +- .../distributed/auto_parallel/dist_tensor.h | 2 +- .../auto_parallel/r_to_s_reshard_function.cc | 105 ++++++++++++ .../auto_parallel/r_to_s_reshard_function.h | 39 +++++ .../auto_parallel/reshard_function.cc | 24 +++ .../auto_parallel/reshard_function.h | 44 ++++++ .../auto_parallel/reshard_split_functor.cc | 64 ++++++++ .../auto_parallel/reshard_split_functor.h | 56 +++++++ .../auto_parallel/reshard_utils.cc | 77 +++++++++ .../distributed/auto_parallel/reshard_utils.h | 40 +++++ test/cpp/auto_parallel/CMakeLists.txt | 5 + test/cpp/auto_parallel/test_reshard_r_to_s.cc | 149 ++++++++++++++++++ 13 files changed, 613 insertions(+), 2 deletions(-) create mode 100644 paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc create mode 100644 paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h create mode 100644 paddle/phi/core/distributed/auto_parallel/reshard_function.cc create mode 100644 paddle/phi/core/distributed/auto_parallel/reshard_function.h create mode 100644 paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc create mode 100644 paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h create mode 100644 paddle/phi/core/distributed/auto_parallel/reshard_utils.cc create mode 100644 paddle/phi/core/distributed/auto_parallel/reshard_utils.h create mode 100644 test/cpp/auto_parallel/test_reshard_r_to_s.cc diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index bdb8a763a91fd..9549e2081097e 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -22,6 +22,7 @@ #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" +#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" #include "paddle/utils/optional.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" @@ -47,6 +48,7 @@ using phi::distributed::auto_parallel::Link; using phi::distributed::auto_parallel::LinkCapability; using phi::distributed::auto_parallel::Machine; using phi::distributed::auto_parallel::ProcessMesh; +using phi::distributed::auto_parallel::RToSReshardFunction; using phi::distributed::auto_parallel::TensorDistAttr; PyTypeObject *g_tensor_dist_attr_pytype = nullptr; @@ -107,6 +109,8 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) { } void BindAutoParallel(py::module *m) { + py::class_(*m, "RToSReshardFunction").def(py::init<>()); + py::class_(*m, "ProcessMesh") .def(py::init<>()) .def(py::init &, diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt index db639bba5f400..9364c14e4c8c3 100644 --- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt +++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt @@ -7,4 +7,8 @@ collect_srcs( process_mesh.cc dist_attr.cc dist_mapper.cc - dist_tensor.cc) + dist_tensor.cc + reshard_function.cc + reshard_split_functor.cc + reshard_utils.cc + r_to_s_reshard_function.cc) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h index eb3a6dbbe3e66..63a7438a6ae7a 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h @@ -31,7 +31,7 @@ class DistTensor final public: /// \brief Construct a dist tensor and allocate space. /// \param a The allocator used to allocate space. - /// \param meta The meta data of dense tensor. + /// \param meta The meta data of dist tensor. DistTensor(Allocator* a, const DenseTensorMeta& meta, const std::shared_ptr& dist_attr) diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc new file mode 100644 index 0000000000000..e8d4bae1000e5 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" +#include "glog/logging.h" +#include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/core/kernel_factory.h" + +namespace phi { +namespace distributed { +namespace auto_parallel { + +bool RToSReshardFunction::Check( + const DistTensor& in, + const std::shared_ptr& out_dist_attr) { + bool flag = true; + const auto& in_dist_attr = in.dist_attr(); + + const auto& in_dims_mapping = in_dist_attr->dims_mapping(); + const auto& out_dims_mapping = out_dist_attr->dims_mapping(); + flag &= IsDimsMappingReplicated(in_dims_mapping); + flag &= IsDimsMappingShard(out_dims_mapping); + + const auto& in_process_mesh = in_dist_attr->process_mesh(); + const auto& out_process_mesh = out_dist_attr->process_mesh(); + + flag &= (in_process_mesh == out_process_mesh); + return flag; +} + +std::shared_ptr RToSReshardFunction::Eval( + const phi::KernelKey& kernel_key, + const DistTensor& in, + const std::shared_ptr& out_dist_attr) { + const auto& out_dims_mapping = out_dist_attr->dims_mapping(); + const auto& out_process_mesh = out_dist_attr->process_mesh(); + auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + const DenseTensor& origin_in_physical_tensor_cur_rank = in.value(); + const DenseTensor* in_physical_tensor_cur_rank = + &origin_in_physical_tensor_cur_rank; + + DenseTensor out_physical_tensor_cur_rank; + + std::map split_axis_to_mesh_axis = + GetSplitAxisWithDimsMapping(out_dims_mapping); + std::vector coord_in_mesh = GetCurRankCoordInMesh(out_process_mesh); + + for (const auto& iter : split_axis_to_mesh_axis) { + int64_t split_axis = iter.first; + int64_t mesh_axis = iter.second; + + PADDLE_ENFORCE_LT( + mesh_axis, + out_process_mesh.ndim(), + phi::errors::OutOfRange( + "The mesh axis %lld exceed the size of process mesh %lld.", + mesh_axis, + out_process_mesh.ndim())); + + int64_t num_of_process = out_process_mesh.shape()[mesh_axis]; + VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis + << ". Split will use axis " << mesh_axis << " of process_mesh." + << " There will have " << num_of_process + << " process participate in."; + + // TODO(liyurui): Consider the tensor can not be balanced split, + // for example, the shape of tensor is {6} but want to split it by 4 + // process. + IntArray sections(std::vector( + num_of_process, in.dims()[split_axis] / num_of_process)); + + ReshardSplitFunctor split_func(kernel_key, sections, split_axis); + + std::vector split_out_vec(sections.size(), DenseTensor()); + split_func(*dev_ctx, *in_physical_tensor_cur_rank, &split_out_vec); + + VLOG(3) << "The current process will remain the idx " + << coord_in_mesh[mesh_axis] << " piece of tensor"; + out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]]; + in_physical_tensor_cur_rank = &out_physical_tensor_cur_rank; + } + + return std::make_shared( + std::make_shared(out_physical_tensor_cur_rank), + out_dist_attr); +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h new file mode 100644 index 0000000000000..214aebb04bad6 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h @@ -0,0 +1,39 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" + +namespace phi { +namespace distributed { +namespace auto_parallel { + +class RToSReshardFunction final : public ReshardFunction { + public: + RToSReshardFunction() = default; + ~RToSReshardFunction() = default; + + bool Check(const DistTensor& in, + const std::shared_ptr& out_dist_attr) override; + + std::shared_ptr Eval( + const KernelKey& kernel_key, + const DistTensor& in, + const std::shared_ptr& out_dist_attr) override; +}; + +} // namespace auto_parallel +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc new file mode 100644 index 0000000000000..f2b273e3eabbb --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc @@ -0,0 +1,24 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" + +namespace phi { +namespace distributed { +namespace auto_parallel {} // namespace auto_parallel +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard_function.h new file mode 100644 index 0000000000000..93b57ee81c6c2 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.h @@ -0,0 +1,44 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace phi { +class KernelKey; +class DeviceContext; + +namespace distributed { +namespace auto_parallel { + +class TensorDistAttr; +class DistTensor; + +class ReshardFunction { + public: + ReshardFunction() = default; + virtual ~ReshardFunction() = default; + + virtual bool Check(const DistTensor& in, + const std::shared_ptr& out_dist_attr) = 0; + + virtual std::shared_ptr Eval( + const KernelKey& kernel_key, + const DistTensor& in, + const std::shared_ptr& out_dist_attr) = 0; +}; + +} // namespace auto_parallel +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc new file mode 100644 index 0000000000000..b6f62f01586e8 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h" +#include "glog/logging.h" +#include "paddle/phi/api/lib/api_gen_utils.h" +#include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/core/kernel_factory.h" +#include "paddle/phi/infermeta/unary.h" + +namespace phi { +namespace distributed { +namespace auto_parallel { + +ReshardSplitFunctor::ReshardSplitFunctor(const phi::KernelKey& kernel_key, + const IntArray& sections, + int64_t axis) + : sections_(sections), axis_(axis) { + KernelResult kernel_result = + phi::KernelFactory::Instance().SelectKernelOrThrowError("split", + kernel_key); + const Kernel& kernel = kernel_result.kernel; + VLOG(3) << "Select split kernel: " << kernel; + functor_ = kernel.GetVariadicKernelFn(); +} + +void ReshardSplitFunctor::operator()(const DeviceContext& dev_ctx, + const DenseTensor& input, + std::vector* output) { + std::vector out_ptr_vec; + for (size_t i = 0; i < output->size(); ++i) { + out_ptr_vec.emplace_back(&(output->at(i))); + } + PrepareOutput(input, out_ptr_vec); + (*functor_)(dev_ctx, input, sections_, axis_, out_ptr_vec); +} + +void ReshardSplitFunctor::PrepareOutput( + const DenseTensor& input, const std::vector& output) { + auto out_meta_vec = paddle::experimental::MakeMetaTensor(output); + + std::vector out_metas(out_meta_vec.size()); + for (size_t i = 0; i < out_meta_vec.size(); ++i) { + out_metas[i] = output[i] ? &out_meta_vec[i] : nullptr; + } + + phi::SplitInferMeta( + paddle::experimental::MakeMetaTensor(input), sections_, axis_, out_metas); +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h new file mode 100644 index 0000000000000..3081ee3f92bf0 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h @@ -0,0 +1,56 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/phi/api/include/tensor.h" + +namespace phi { +class DeviceContext; +class DenseTensor; +class KernelKey; + +namespace distributed { +namespace auto_parallel { + +class ReshardSplitFunctor final { + public: + using SPLIT_KERNEL_SIG = void (*)(const DeviceContext&, + const DenseTensor&, + const phi::IntArray&, + const phi::Scalar&, + std::vector); + + ReshardSplitFunctor(const KernelKey& kernel_key, + const IntArray& sections, + int64_t axis); + + void operator()(const DeviceContext& dev_ctx, + const DenseTensor& input, + std::vector* output); + + private: + IntArray sections_; + int64_t axis_; + SPLIT_KERNEL_SIG functor_; + + void PrepareOutput(const DenseTensor& input, + const std::vector& output); +}; + +} // namespace auto_parallel +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc new file mode 100644 index 0000000000000..08a2e89990557 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include +#include "glog/logging.h" +#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" + +namespace phi { +namespace distributed { +namespace auto_parallel { + +bool IsDimsMappingShard(const std::vector& dims_mapping) { + return std::any_of(dims_mapping.begin(), + dims_mapping.end(), + [](int64_t value) { return value != -1; }); +} + +bool IsDimsMappingReplicated(const std::vector& dims_mapping) { + return std::all_of(dims_mapping.begin(), + dims_mapping.end(), + [](int64_t value) { return value == -1; }); +} + +int64_t GetCurGlobalRank() { + const char* cur_rank = std::getenv("PADDLE_TRAINER_ID"); + PADDLE_ENFORCE_NOT_NULL( + cur_rank, + phi::errors::NotFound( + "The environment variable 'PADDLE_TRAINER_ID' cannot be found.")); + return std::atoi(cur_rank); +} + +std::vector GetCurRankCoordInMesh(const ProcessMesh& process_mesh) { + const auto& process_shape = process_mesh.shape(); + const auto& process_ids = process_mesh.process_ids(); + int64_t ndims_mesh = process_shape.size(); + int64_t cur_global_rank = GetCurGlobalRank(); + VLOG(3) << "Current global rank is " << cur_global_rank << " with ndims_mesh " + << ndims_mesh; + int64_t flat_idx_in_mesh = + std::find(process_ids.begin(), process_ids.end(), cur_global_rank) - + process_ids.begin(); + + std::vector coord(ndims_mesh, -1); + for (int64_t i = ndims_mesh - 1; i >= 0; --i) { + coord[i] = flat_idx_in_mesh % process_shape[i]; + flat_idx_in_mesh /= process_shape[i]; + } + return coord; +} + +std::map GetSplitAxisWithDimsMapping( + const std::vector& dims_mapping) { + std::map split_axis_to_mesh_axis; + for (size_t i = 0; i < dims_mapping.size(); ++i) { + if (dims_mapping[i] != -1) { + split_axis_to_mesh_axis.emplace(i, dims_mapping[i]); + } + } + return split_axis_to_mesh_axis; +} + +} // namespace auto_parallel +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h new file mode 100644 index 0000000000000..c02e0a78ce690 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h @@ -0,0 +1,40 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace phi { +namespace distributed { +namespace auto_parallel { + +class ProcessMesh; + +bool IsDimsMappingShard(const std::vector& dims_mapping); + +bool IsDimsMappingReplicated(const std::vector& dims_mapping); + +int64_t GetCurGlobalRank(); + +std::vector GetCurRankCoordInMesh(const ProcessMesh& process_mesh); + +std::map GetSplitAxisWithDimsMapping( + const std::vector& dims_mapping); + +} // namespace auto_parallel +} // namespace distributed +} // namespace phi diff --git a/test/cpp/auto_parallel/CMakeLists.txt b/test/cpp/auto_parallel/CMakeLists.txt index c5912a6fa1021..6e0ea8db1e045 100644 --- a/test/cpp/auto_parallel/CMakeLists.txt +++ b/test/cpp/auto_parallel/CMakeLists.txt @@ -9,6 +9,11 @@ if(WITH_DISTRIBUTE) dist_tensor_test SRCS dist_tensor_test.cc DEPS phi) + + cc_test( + test_reshard_r_to_s + SRCS test_reshard_r_to_s.cc + DEPS phi) endif() cc_test_old(dist_mapper_test SRCS dist_mapper_test.cc DEPS phi) diff --git a/test/cpp/auto_parallel/test_reshard_r_to_s.cc b/test/cpp/auto_parallel/test_reshard_r_to_s.cc new file mode 100644 index 0000000000000..13d160e7957e6 --- /dev/null +++ b/test/cpp/auto_parallel/test_reshard_r_to_s.cc @@ -0,0 +1,149 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" +#include "paddle/phi/core/kernel_factory.h" +#include "test/cpp/phi/core/allocator.h" + +namespace phi { +namespace distributed { +namespace auto_parallel { +namespace tests { + +std::shared_ptr ConstructBroadcastDistTensor( + const std::vector& shape, + const DataLayout& layout, + const DataType& dtype, + const ProcessMesh& mesh) { + const DDim dims(shape.data(), shape.size()); + const LoD lod{}; + DenseTensorMeta meta(dtype, dims, layout, lod); + + auto fancy_allocator = + std::unique_ptr(new phi::tests::FancyAllocator); + auto* alloc = fancy_allocator.get(); + std::shared_ptr dist_attr = + std::make_shared(shape); + + std::vector dims_mapping(shape.size(), -1); + dist_attr->set_dims_mapping(dims_mapping); + dist_attr->set_process_mesh(mesh); + + return std::make_shared(alloc, meta, dist_attr); +} + +TEST(reshard_r_to_s, r_to_s_same_placement_1d_mesh) { + setenv("PADDLE_TRAINER_ID", "1", 1); + + std::vector tensor_shape = {6, 12}; + const DataType dtype{DataType::FLOAT32}; + const DataLayout layout{DataLayout::NHWC}; + + std::vector mesh_shape = {4}; + std::vector process_ids = {0, 1, 2, 3}; + std::vector dim_names = {"x"}; + ProcessMesh mesh(mesh_shape, process_ids, dim_names); + + std::shared_ptr input = + ConstructBroadcastDistTensor(tensor_shape, layout, dtype, mesh); + int64_t split_axis = 1; + + // Use process mesh axis 0 to split tensor axis 1 + std::shared_ptr out_dist_attr = + std::make_shared(tensor_shape); + std::vector out_dims_mapping(tensor_shape.size(), -1); + out_dims_mapping[split_axis] = 0; + out_dist_attr->set_dims_mapping(out_dims_mapping); + out_dist_attr->set_process_mesh(mesh); + + RToSReshardFunction r_to_s_func; + KernelKey kernel_key = {Backend::CPU, layout, dtype}; + std::shared_ptr output = + r_to_s_func.Eval(kernel_key, *input, out_dist_attr); + + CHECK_EQ(r_to_s_func.Check(*input, out_dist_attr), true); + CHECK_EQ(output->numel(), 18); + CHECK_EQ(output->dims(), DDim({6, 3})); +} + +TEST(reshard_r_to_s, r_to_s_diff_placement) { + std::vector tensor_shape = {6, 12}; + const DataType dtype{DataType::FLOAT32}; + const DataLayout layout{DataLayout::NHWC}; + + std::vector mesh_shape = {4}; + std::vector process_ids = {0, 1, 2, 3}; + std::vector dim_names = {"x"}; + ProcessMesh mesh(mesh_shape, process_ids, dim_names); + + std::shared_ptr input = + ConstructBroadcastDistTensor(tensor_shape, layout, dtype, mesh); + int64_t split_axis = 1; + + std::vector out_process_ids = {2, 3, 4, 5}; + ProcessMesh out_mesh(mesh_shape, out_process_ids, dim_names); + std::shared_ptr out_dist_attr = + std::make_shared(tensor_shape); + std::vector out_dims_mapping(tensor_shape.size(), -1); + out_dims_mapping[split_axis] = 0; + out_dist_attr->set_dims_mapping(out_dims_mapping); + out_dist_attr->set_process_mesh(out_mesh); + + RToSReshardFunction r_to_s_func; + CHECK_EQ(r_to_s_func.Check(*input, out_dist_attr), false); +} + +TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) { + setenv("PADDLE_TRAINER_ID", "6", 1); + + std::vector tensor_shape = {6, 12}; + const DataType dtype{DataType::FLOAT32}; + const DataLayout layout{DataLayout::NHWC}; + + std::vector mesh_shape = {4, 2}; + std::vector process_ids = {0, 1, 2, 3, 4, 5, 6, 7}; + std::vector dim_names = {"x", "y"}; + ProcessMesh mesh(mesh_shape, process_ids, dim_names); + + std::shared_ptr input = + ConstructBroadcastDistTensor(tensor_shape, layout, dtype, mesh); + + // Use process mesh axis 0 to split tensor axis 1, use process mesh axis 1 to + // split tensor axis 0 + std::shared_ptr out_dist_attr = + std::make_shared(tensor_shape); + std::vector out_dims_mapping = {1, 0}; + out_dist_attr->set_dims_mapping(out_dims_mapping); + out_dist_attr->set_process_mesh(mesh); + + RToSReshardFunction r_to_s_func; + KernelKey kernel_key = {Backend::CPU, layout, dtype}; + std::shared_ptr output = + r_to_s_func.Eval(kernel_key, *input, out_dist_attr); + + CHECK_EQ(r_to_s_func.Check(*input, out_dist_attr), true); + CHECK_EQ(output->numel(), 9); + CHECK_EQ(output->dims(), DDim({3, 3})); +} + +} // namespace tests +} // namespace auto_parallel +} // namespace distributed +} // namespace phi From 2904c298afdd396b532cc214a15a16e9427dd726 Mon Sep 17 00:00:00 2001 From: LiYuRio Date: Thu, 20 Jul 2023 11:51:21 +0800 Subject: [PATCH 2/4] fix link error in clang --- paddle/fluid/pybind/auto_parallel_py.cc | 8 ++++++-- .../distributed/auto_parallel/CMakeLists.txt | 19 ++++++++++++++----- .../auto_parallel/reshard_utils.cc | 18 +++++++++++++----- .../distributed/auto_parallel/reshard_utils.h | 8 ++++++++ test/cpp/auto_parallel/test_reshard_r_to_s.cc | 8 ++++---- 5 files changed, 45 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index 9549e2081097e..2d9b81cd17b4f 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -22,11 +22,13 @@ #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" -#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" #include "paddle/utils/optional.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" +#endif namespace py = pybind11; @@ -48,7 +50,6 @@ using phi::distributed::auto_parallel::Link; using phi::distributed::auto_parallel::LinkCapability; using phi::distributed::auto_parallel::Machine; using phi::distributed::auto_parallel::ProcessMesh; -using phi::distributed::auto_parallel::RToSReshardFunction; using phi::distributed::auto_parallel::TensorDistAttr; PyTypeObject *g_tensor_dist_attr_pytype = nullptr; @@ -109,7 +110,10 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) { } void BindAutoParallel(py::module *m) { +#ifdef PADDLE_WITH_DISTRIBUTE + using phi::distributed::auto_parallel::RToSReshardFunction; py::class_(*m, "RToSReshardFunction").def(py::init<>()); +#endif py::class_(*m, "ProcessMesh") .def(py::init<>()) diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt index 9364c14e4c8c3..d4af259a5906c 100644 --- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt +++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt @@ -1,5 +1,18 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto) +set(DISTRIBUTED_SRCS "") + +if(WITH_DISTRIBUTE) + list( + APPEND + DISTRIBUTED_SRCS + dist_tensor.cc + reshard_function.cc + reshard_split_functor.cc + reshard_utils.cc + r_to_s_reshard_function.cc) +endif() + collect_srcs( core_srcs SRCS @@ -7,8 +20,4 @@ collect_srcs( process_mesh.cc dist_attr.cc dist_mapper.cc - dist_tensor.cc - reshard_function.cc - reshard_split_functor.cc - reshard_utils.cc - r_to_s_reshard_function.cc) + ${DISTRIBUTED_SRCS}) diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc index 08a2e89990557..357ed35918a2d 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc @@ -47,11 +47,19 @@ std::vector GetCurRankCoordInMesh(const ProcessMesh& process_mesh) { const auto& process_ids = process_mesh.process_ids(); int64_t ndims_mesh = process_shape.size(); int64_t cur_global_rank = GetCurGlobalRank(); - VLOG(3) << "Current global rank is " << cur_global_rank << " with ndims_mesh " - << ndims_mesh; - int64_t flat_idx_in_mesh = - std::find(process_ids.begin(), process_ids.end(), cur_global_rank) - - process_ids.begin(); + + VLOG(3) << "Searching current global rank " << cur_global_rank + << " in process_mesh " << process_mesh; + + auto iter = + std::find(process_ids.begin(), process_ids.end(), cur_global_rank); + PADDLE_ENFORCE_NE( + iter, + process_ids.end(), + phi::errors::NotFound("Rank %lld cannot be found in process_mesh", + cur_global_rank)); + + int64_t flat_idx_in_mesh = iter - process_ids.begin(); std::vector coord(ndims_mesh, -1); for (int64_t i = ndims_mesh - 1; i >= 0; --i) { diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h index c02e0a78ce690..2a8d79053f13a 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h @@ -30,8 +30,16 @@ bool IsDimsMappingReplicated(const std::vector& dims_mapping); int64_t GetCurGlobalRank(); +// Get the coordinate of cur rank in process mesh. For example, the process mesh +// is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will +// return [2, 0]; if the current rank is 3, then will return [1, 1]. std::vector GetCurRankCoordInMesh(const ProcessMesh& process_mesh); +// If the index i's value in dims_mapping is x ( x != -1), means the ith axis of +// tensor need be split by xth axis of process_mesh. The function analyze the +// input vector, return a key-value map of tensor_split_axis and +// process_mesh_split_axis. +// For example, if dims_mapping is [-1, 1, -1, 0], will return {1: 1, 3: 0}. std::map GetSplitAxisWithDimsMapping( const std::vector& dims_mapping); diff --git a/test/cpp/auto_parallel/test_reshard_r_to_s.cc b/test/cpp/auto_parallel/test_reshard_r_to_s.cc index 13d160e7957e6..1fdbb161beecb 100644 --- a/test/cpp/auto_parallel/test_reshard_r_to_s.cc +++ b/test/cpp/auto_parallel/test_reshard_r_to_s.cc @@ -27,7 +27,7 @@ namespace distributed { namespace auto_parallel { namespace tests { -std::shared_ptr ConstructBroadcastDistTensor( +std::shared_ptr ConstructReplicatedDistTensor( const std::vector& shape, const DataLayout& layout, const DataType& dtype, @@ -62,7 +62,7 @@ TEST(reshard_r_to_s, r_to_s_same_placement_1d_mesh) { ProcessMesh mesh(mesh_shape, process_ids, dim_names); std::shared_ptr input = - ConstructBroadcastDistTensor(tensor_shape, layout, dtype, mesh); + ConstructReplicatedDistTensor(tensor_shape, layout, dtype, mesh); int64_t split_axis = 1; // Use process mesh axis 0 to split tensor axis 1 @@ -94,7 +94,7 @@ TEST(reshard_r_to_s, r_to_s_diff_placement) { ProcessMesh mesh(mesh_shape, process_ids, dim_names); std::shared_ptr input = - ConstructBroadcastDistTensor(tensor_shape, layout, dtype, mesh); + ConstructReplicatedDistTensor(tensor_shape, layout, dtype, mesh); int64_t split_axis = 1; std::vector out_process_ids = {2, 3, 4, 5}; @@ -123,7 +123,7 @@ TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) { ProcessMesh mesh(mesh_shape, process_ids, dim_names); std::shared_ptr input = - ConstructBroadcastDistTensor(tensor_shape, layout, dtype, mesh); + ConstructReplicatedDistTensor(tensor_shape, layout, dtype, mesh); // Use process mesh axis 0 to split tensor axis 1, use process mesh axis 1 to // split tensor axis 0 From 323c13c59f0effc090a683e91f3ef01119ecab09 Mon Sep 17 00:00:00 2001 From: LiYuRio Date: Mon, 24 Jul 2023 13:18:31 +0800 Subject: [PATCH 3/4] refine split functor --- .../auto_parallel/r_to_s_reshard_function.cc | 12 ++- .../auto_parallel/r_to_s_reshard_function.h | 7 +- .../auto_parallel/reshard_function.h | 8 +- .../auto_parallel/reshard_split_functor.cc | 84 +++++++++++-------- .../auto_parallel/reshard_split_functor.h | 32 ++----- paddle/phi/kernels/cpu/split_kernel.cc | 5 +- paddle/phi/kernels/gpu/split_kernel.cu | 5 +- test/cpp/auto_parallel/test_reshard_r_to_s.cc | 41 +++++---- 8 files changed, 103 insertions(+), 91 deletions(-) diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc index e8d4bae1000e5..a68fa8fc231f6 100644 --- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" + #include "glog/logging.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" @@ -25,7 +26,7 @@ namespace phi { namespace distributed { namespace auto_parallel { -bool RToSReshardFunction::Check( +bool RToSReshardFunction::IsSuitable( const DistTensor& in, const std::shared_ptr& out_dist_attr) { bool flag = true; @@ -44,12 +45,11 @@ bool RToSReshardFunction::Check( } std::shared_ptr RToSReshardFunction::Eval( - const phi::KernelKey& kernel_key, + const phi::DeviceContext& dev_ctx, const DistTensor& in, const std::shared_ptr& out_dist_attr) { const auto& out_dims_mapping = out_dist_attr->dims_mapping(); const auto& out_process_mesh = out_dist_attr->process_mesh(); - auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); const DenseTensor& origin_in_physical_tensor_cur_rank = in.value(); const DenseTensor* in_physical_tensor_cur_rank = &origin_in_physical_tensor_cur_rank; @@ -84,10 +84,8 @@ std::shared_ptr RToSReshardFunction::Eval( IntArray sections(std::vector( num_of_process, in.dims()[split_axis] / num_of_process)); - ReshardSplitFunctor split_func(kernel_key, sections, split_axis); - - std::vector split_out_vec(sections.size(), DenseTensor()); - split_func(*dev_ctx, *in_physical_tensor_cur_rank, &split_out_vec); + std::vector split_out_vec = ReshardSplitFunctor( + dev_ctx, *in_physical_tensor_cur_rank, sections, split_axis); VLOG(3) << "The current process will remain the idx " << coord_in_mesh[mesh_axis] << " piece of tensor"; diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h index 214aebb04bad6..de5ea94692d0c 100644 --- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h @@ -25,11 +25,12 @@ class RToSReshardFunction final : public ReshardFunction { RToSReshardFunction() = default; ~RToSReshardFunction() = default; - bool Check(const DistTensor& in, - const std::shared_ptr& out_dist_attr) override; + bool IsSuitable( + const DistTensor& in, + const std::shared_ptr& out_dist_attr) override; std::shared_ptr Eval( - const KernelKey& kernel_key, + const DeviceContext& dev_ctx, const DistTensor& in, const std::shared_ptr& out_dist_attr) override; }; diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard_function.h index 93b57ee81c6c2..641d42cbef6c2 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.h @@ -16,7 +16,6 @@ #include namespace phi { -class KernelKey; class DeviceContext; namespace distributed { @@ -30,11 +29,12 @@ class ReshardFunction { ReshardFunction() = default; virtual ~ReshardFunction() = default; - virtual bool Check(const DistTensor& in, - const std::shared_ptr& out_dist_attr) = 0; + virtual bool IsSuitable( + const DistTensor& in, + const std::shared_ptr& out_dist_attr) = 0; virtual std::shared_ptr Eval( - const KernelKey& kernel_key, + const DeviceContext& dev_ctx, const DistTensor& in, const std::shared_ptr& out_dist_attr) = 0; }; diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc index b6f62f01586e8..b76dfe3f1249f 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc @@ -13,50 +13,66 @@ // limitations under the License. #include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h" -#include "glog/logging.h" -#include "paddle/phi/api/lib/api_gen_utils.h" -#include "paddle/phi/api/lib/kernel_dispatch.h" -#include "paddle/phi/core/kernel_factory.h" -#include "paddle/phi/infermeta/unary.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/split_kernel.h" namespace phi { namespace distributed { namespace auto_parallel { -ReshardSplitFunctor::ReshardSplitFunctor(const phi::KernelKey& kernel_key, - const IntArray& sections, - int64_t axis) - : sections_(sections), axis_(axis) { - KernelResult kernel_result = - phi::KernelFactory::Instance().SelectKernelOrThrowError("split", - kernel_key); - const Kernel& kernel = kernel_result.kernel; - VLOG(3) << "Select split kernel: " << kernel; - functor_ = kernel.GetVariadicKernelFn(); -} +std::vector ReshardSplitFunctor(const DeviceContext& dev_ctx, + const DenseTensor& input, + const IntArray& sections, + int64_t axis) { + size_t out_number = sections.size(); + std::vector result(out_number); -void ReshardSplitFunctor::operator()(const DeviceContext& dev_ctx, - const DenseTensor& input, - std::vector* output) { - std::vector out_ptr_vec; - for (size_t i = 0; i < output->size(); ++i) { - out_ptr_vec.emplace_back(&(output->at(i))); - } - PrepareOutput(input, out_ptr_vec); - (*functor_)(dev_ctx, input, sections_, axis_, out_ptr_vec); -} + std::vector out_meta; + std::vector out_meta_ptr; -void ReshardSplitFunctor::PrepareOutput( - const DenseTensor& input, const std::vector& output) { - auto out_meta_vec = paddle::experimental::MakeMetaTensor(output); + out_meta.reserve(out_number); + out_meta_ptr.reserve(out_number); + for (size_t i = 0; i < out_number; ++i) { + out_meta.emplace_back(result[i]); + out_meta_ptr.emplace_back(&out_meta.back()); + } + SplitInferMeta(phi::MetaTensor(input), sections, axis, out_meta_ptr); - std::vector out_metas(out_meta_vec.size()); - for (size_t i = 0; i < out_meta_vec.size(); ++i) { - out_metas[i] = output[i] ? &out_meta_vec[i] : nullptr; + std::vector outs; + for (size_t i = 0; i < out_number; ++i) { + outs.emplace_back(&result[i]); } - phi::SplitInferMeta( - paddle::experimental::MakeMetaTensor(input), sections_, axis_, out_metas); + if (phi::CPUContext::classof(&dev_ctx)) { + PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] { + SplitKernel( + static_cast(dev_ctx), + input, + sections, + axis, + outs); + })); + return result; + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (phi::GPUContext::classof(&dev_ctx)) { + PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] { + SplitKernel( + static_cast(dev_ctx), + input, + sections, + axis, + outs); + })); + return result; + } +#endif + PADDLE_THROW(phi::errors::Unimplemented( + "The split in reshard only supported on CPU and GPU for now.")); } } // namespace auto_parallel diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h index 3081ee3f92bf0..78aa52caf5d7d 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h @@ -16,40 +16,18 @@ #include #include -#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/common/int_array.h" namespace phi { class DeviceContext; class DenseTensor; -class KernelKey; namespace distributed { namespace auto_parallel { - -class ReshardSplitFunctor final { - public: - using SPLIT_KERNEL_SIG = void (*)(const DeviceContext&, - const DenseTensor&, - const phi::IntArray&, - const phi::Scalar&, - std::vector); - - ReshardSplitFunctor(const KernelKey& kernel_key, - const IntArray& sections, - int64_t axis); - - void operator()(const DeviceContext& dev_ctx, - const DenseTensor& input, - std::vector* output); - - private: - IntArray sections_; - int64_t axis_; - SPLIT_KERNEL_SIG functor_; - - void PrepareOutput(const DenseTensor& input, - const std::vector& output); -}; +std::vector ReshardSplitFunctor(const DeviceContext& dev_ctx, + const DenseTensor& input, + const IntArray& sections, + int64_t axis); } // namespace auto_parallel } // namespace distributed diff --git a/paddle/phi/kernels/cpu/split_kernel.cc b/paddle/phi/kernels/cpu/split_kernel.cc index f277e0c39f375..13ac7eed3d577 100644 --- a/paddle/phi/kernels/cpu/split_kernel.cc +++ b/paddle/phi/kernels/cpu/split_kernel.cc @@ -29,8 +29,11 @@ PD_REGISTER_KERNEL(split, bool, uint8_t, int8_t, + int16_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_KERNEL(split_with_num, CPU, diff --git a/paddle/phi/kernels/gpu/split_kernel.cu b/paddle/phi/kernels/gpu/split_kernel.cu index 133734621360d..ea140b54eb170 100644 --- a/paddle/phi/kernels/gpu/split_kernel.cu +++ b/paddle/phi/kernels/gpu/split_kernel.cu @@ -29,8 +29,11 @@ PD_REGISTER_KERNEL(split, bool, uint8_t, int8_t, + int16_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::complex64, + phi::complex128) {} PD_REGISTER_KERNEL(split_with_num, GPU, diff --git a/test/cpp/auto_parallel/test_reshard_r_to_s.cc b/test/cpp/auto_parallel/test_reshard_r_to_s.cc index 1fdbb161beecb..8213223db8f6c 100644 --- a/test/cpp/auto_parallel/test_reshard_r_to_s.cc +++ b/test/cpp/auto_parallel/test_reshard_r_to_s.cc @@ -15,11 +15,11 @@ #include #include "glog/logging.h" #include "gtest/gtest.h" +#include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" -#include "paddle/phi/core/kernel_factory.h" #include "test/cpp/phi/core/allocator.h" namespace phi { @@ -28,6 +28,7 @@ namespace auto_parallel { namespace tests { std::shared_ptr ConstructReplicatedDistTensor( + Allocator* alloc, const std::vector& shape, const DataLayout& layout, const DataType& dtype, @@ -36,9 +37,6 @@ std::shared_ptr ConstructReplicatedDistTensor( const LoD lod{}; DenseTensorMeta meta(dtype, dims, layout, lod); - auto fancy_allocator = - std::unique_ptr(new phi::tests::FancyAllocator); - auto* alloc = fancy_allocator.get(); std::shared_ptr dist_attr = std::make_shared(shape); @@ -55,6 +53,9 @@ TEST(reshard_r_to_s, r_to_s_same_placement_1d_mesh) { std::vector tensor_shape = {6, 12}; const DataType dtype{DataType::FLOAT32}; const DataLayout layout{DataLayout::NHWC}; + auto fancy_allocator = + std::unique_ptr(new phi::tests::FancyAllocator); + auto* alloc = fancy_allocator.get(); std::vector mesh_shape = {4}; std::vector process_ids = {0, 1, 2, 3}; @@ -62,7 +63,7 @@ TEST(reshard_r_to_s, r_to_s_same_placement_1d_mesh) { ProcessMesh mesh(mesh_shape, process_ids, dim_names); std::shared_ptr input = - ConstructReplicatedDistTensor(tensor_shape, layout, dtype, mesh); + ConstructReplicatedDistTensor(alloc, tensor_shape, layout, dtype, mesh); int64_t split_axis = 1; // Use process mesh axis 0 to split tensor axis 1 @@ -73,12 +74,15 @@ TEST(reshard_r_to_s, r_to_s_same_placement_1d_mesh) { out_dist_attr->set_dims_mapping(out_dims_mapping); out_dist_attr->set_process_mesh(mesh); + phi::CPUPlace cpu_place; + CPUContext dev_ctx(cpu_place); + dev_ctx.SetAllocator(alloc); + RToSReshardFunction r_to_s_func; - KernelKey kernel_key = {Backend::CPU, layout, dtype}; std::shared_ptr output = - r_to_s_func.Eval(kernel_key, *input, out_dist_attr); + r_to_s_func.Eval(dev_ctx, *input, out_dist_attr); - CHECK_EQ(r_to_s_func.Check(*input, out_dist_attr), true); + CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true); CHECK_EQ(output->numel(), 18); CHECK_EQ(output->dims(), DDim({6, 3})); } @@ -87,6 +91,9 @@ TEST(reshard_r_to_s, r_to_s_diff_placement) { std::vector tensor_shape = {6, 12}; const DataType dtype{DataType::FLOAT32}; const DataLayout layout{DataLayout::NHWC}; + auto fancy_allocator = + std::unique_ptr(new phi::tests::FancyAllocator); + auto* alloc = fancy_allocator.get(); std::vector mesh_shape = {4}; std::vector process_ids = {0, 1, 2, 3}; @@ -94,7 +101,7 @@ TEST(reshard_r_to_s, r_to_s_diff_placement) { ProcessMesh mesh(mesh_shape, process_ids, dim_names); std::shared_ptr input = - ConstructReplicatedDistTensor(tensor_shape, layout, dtype, mesh); + ConstructReplicatedDistTensor(alloc, tensor_shape, layout, dtype, mesh); int64_t split_axis = 1; std::vector out_process_ids = {2, 3, 4, 5}; @@ -107,7 +114,7 @@ TEST(reshard_r_to_s, r_to_s_diff_placement) { out_dist_attr->set_process_mesh(out_mesh); RToSReshardFunction r_to_s_func; - CHECK_EQ(r_to_s_func.Check(*input, out_dist_attr), false); + CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), false); } TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) { @@ -116,6 +123,9 @@ TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) { std::vector tensor_shape = {6, 12}; const DataType dtype{DataType::FLOAT32}; const DataLayout layout{DataLayout::NHWC}; + auto fancy_allocator = + std::unique_ptr(new phi::tests::FancyAllocator); + auto* alloc = fancy_allocator.get(); std::vector mesh_shape = {4, 2}; std::vector process_ids = {0, 1, 2, 3, 4, 5, 6, 7}; @@ -123,7 +133,7 @@ TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) { ProcessMesh mesh(mesh_shape, process_ids, dim_names); std::shared_ptr input = - ConstructReplicatedDistTensor(tensor_shape, layout, dtype, mesh); + ConstructReplicatedDistTensor(alloc, tensor_shape, layout, dtype, mesh); // Use process mesh axis 0 to split tensor axis 1, use process mesh axis 1 to // split tensor axis 0 @@ -133,12 +143,15 @@ TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) { out_dist_attr->set_dims_mapping(out_dims_mapping); out_dist_attr->set_process_mesh(mesh); + phi::CPUPlace cpu_place; + CPUContext dev_ctx(cpu_place); + dev_ctx.SetAllocator(alloc); + RToSReshardFunction r_to_s_func; - KernelKey kernel_key = {Backend::CPU, layout, dtype}; std::shared_ptr output = - r_to_s_func.Eval(kernel_key, *input, out_dist_attr); + r_to_s_func.Eval(dev_ctx, *input, out_dist_attr); - CHECK_EQ(r_to_s_func.Check(*input, out_dist_attr), true); + CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true); CHECK_EQ(output->numel(), 9); CHECK_EQ(output->dims(), DDim({3, 3})); } From 1e363762709631ee69dc58c14f49f35b8c63193b Mon Sep 17 00:00:00 2001 From: LiYuRio Date: Mon, 24 Jul 2023 19:12:19 +0800 Subject: [PATCH 4/4] simplify reshard code --- paddle/fluid/pybind/auto_parallel_py.cc | 4 +- .../auto_parallel/r_to_s_reshard_function.cc | 71 ++++---- .../auto_parallel/r_to_s_reshard_function.h | 2 - .../auto_parallel/reshard_function.cc | 4 +- .../auto_parallel/reshard_function.h | 5 +- .../auto_parallel/reshard_split_functor.cc | 2 - .../auto_parallel/reshard_split_functor.h | 2 - .../auto_parallel/reshard_utils.cc | 3 +- .../distributed/auto_parallel/reshard_utils.h | 4 +- test/cpp/auto_parallel/test_reshard_r_to_s.cc | 156 ++++++++++++------ 10 files changed, 145 insertions(+), 108 deletions(-) diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index 2d9b81cd17b4f..96c49b4170519 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -111,8 +111,8 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) { void BindAutoParallel(py::module *m) { #ifdef PADDLE_WITH_DISTRIBUTE - using phi::distributed::auto_parallel::RToSReshardFunction; - py::class_(*m, "RToSReshardFunction").def(py::init<>()); + py::class_(*m, "RToSReshardFunction") + .def(py::init<>()); #endif py::class_(*m, "ProcessMesh") diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc index a68fa8fc231f6..a9db48e631cff 100644 --- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc @@ -24,7 +24,6 @@ namespace phi { namespace distributed { -namespace auto_parallel { bool RToSReshardFunction::IsSuitable( const DistTensor& in, @@ -34,13 +33,17 @@ bool RToSReshardFunction::IsSuitable( const auto& in_dims_mapping = in_dist_attr->dims_mapping(); const auto& out_dims_mapping = out_dist_attr->dims_mapping(); + flag &= IsDimsMappingReplicated(in_dims_mapping); flag &= IsDimsMappingShard(out_dims_mapping); const auto& in_process_mesh = in_dist_attr->process_mesh(); const auto& out_process_mesh = out_dist_attr->process_mesh(); + flag &= (in_process_mesh.ndim() == 1); + flag &= (out_process_mesh.ndim() == 1); flag &= (in_process_mesh == out_process_mesh); + return flag; } @@ -50,9 +53,7 @@ std::shared_ptr RToSReshardFunction::Eval( const std::shared_ptr& out_dist_attr) { const auto& out_dims_mapping = out_dist_attr->dims_mapping(); const auto& out_process_mesh = out_dist_attr->process_mesh(); - const DenseTensor& origin_in_physical_tensor_cur_rank = in.value(); - const DenseTensor* in_physical_tensor_cur_rank = - &origin_in_physical_tensor_cur_rank; + const DenseTensor& in_physical_tensor_cur_rank = in.value(); DenseTensor out_physical_tensor_cur_rank; @@ -60,44 +61,40 @@ std::shared_ptr RToSReshardFunction::Eval( GetSplitAxisWithDimsMapping(out_dims_mapping); std::vector coord_in_mesh = GetCurRankCoordInMesh(out_process_mesh); - for (const auto& iter : split_axis_to_mesh_axis) { - int64_t split_axis = iter.first; - int64_t mesh_axis = iter.second; - - PADDLE_ENFORCE_LT( - mesh_axis, - out_process_mesh.ndim(), - phi::errors::OutOfRange( - "The mesh axis %lld exceed the size of process mesh %lld.", - mesh_axis, - out_process_mesh.ndim())); - - int64_t num_of_process = out_process_mesh.shape()[mesh_axis]; - VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis - << ". Split will use axis " << mesh_axis << " of process_mesh." - << " There will have " << num_of_process - << " process participate in."; - - // TODO(liyurui): Consider the tensor can not be balanced split, - // for example, the shape of tensor is {6} but want to split it by 4 - // process. - IntArray sections(std::vector( - num_of_process, in.dims()[split_axis] / num_of_process)); - - std::vector split_out_vec = ReshardSplitFunctor( - dev_ctx, *in_physical_tensor_cur_rank, sections, split_axis); - - VLOG(3) << "The current process will remain the idx " - << coord_in_mesh[mesh_axis] << " piece of tensor"; - out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]]; - in_physical_tensor_cur_rank = &out_physical_tensor_cur_rank; - } + int64_t split_axis = split_axis_to_mesh_axis.begin()->first; + int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second; + + PADDLE_ENFORCE_LT( + mesh_axis, + out_process_mesh.ndim(), + phi::errors::OutOfRange( + "The mesh axis %lld exceed the size of process mesh %lld.", + mesh_axis, + out_process_mesh.ndim())); + + int64_t num_of_process = out_process_mesh.shape()[mesh_axis]; + VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis + << ". Split will use axis " << mesh_axis << " of process_mesh." + << " There will have " << num_of_process + << " process participate in."; + + // TODO(liyurui): Consider the tensor can not be balanced split, + // for example, the shape of tensor is {6} but want to split it by 4 + // process. + IntArray sections(std::vector( + num_of_process, in.dims()[split_axis] / num_of_process)); + + std::vector split_out_vec = ReshardSplitFunctor( + dev_ctx, in_physical_tensor_cur_rank, sections, split_axis); + + VLOG(3) << "The current process will remain the idx " + << coord_in_mesh[mesh_axis] << " piece of tensor"; + out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]]; return std::make_shared( std::make_shared(out_physical_tensor_cur_rank), out_dist_attr); } -} // namespace auto_parallel } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h index de5ea94692d0c..61b77820297e4 100644 --- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h @@ -18,7 +18,6 @@ namespace phi { namespace distributed { -namespace auto_parallel { class RToSReshardFunction final : public ReshardFunction { public: @@ -35,6 +34,5 @@ class RToSReshardFunction final : public ReshardFunction { const std::shared_ptr& out_dist_attr) override; }; -} // namespace auto_parallel } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc index f2b273e3eabbb..04bbc4a09fe1f 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc @@ -18,7 +18,5 @@ #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" namespace phi { -namespace distributed { -namespace auto_parallel {} // namespace auto_parallel -} // namespace distributed +namespace distributed {} // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard_function.h index 641d42cbef6c2..2c8574ca376ce 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_function.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.h @@ -20,9 +20,11 @@ class DeviceContext; namespace distributed { namespace auto_parallel { - class TensorDistAttr; +} // namespace auto_parallel + class DistTensor; +using auto_parallel::TensorDistAttr; class ReshardFunction { public: @@ -39,6 +41,5 @@ class ReshardFunction { const std::shared_ptr& out_dist_attr) = 0; }; -} // namespace auto_parallel } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc index b76dfe3f1249f..189738b81367f 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc @@ -22,7 +22,6 @@ namespace phi { namespace distributed { -namespace auto_parallel { std::vector ReshardSplitFunctor(const DeviceContext& dev_ctx, const DenseTensor& input, @@ -75,6 +74,5 @@ std::vector ReshardSplitFunctor(const DeviceContext& dev_ctx, "The split in reshard only supported on CPU and GPU for now.")); } -} // namespace auto_parallel } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h index 78aa52caf5d7d..87b9f2301ad0b 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h @@ -23,12 +23,10 @@ class DeviceContext; class DenseTensor; namespace distributed { -namespace auto_parallel { std::vector ReshardSplitFunctor(const DeviceContext& dev_ctx, const DenseTensor& input, const IntArray& sections, int64_t axis); -} // namespace auto_parallel } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc index 357ed35918a2d..b777b53c23043 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc @@ -13,13 +13,13 @@ // limitations under the License. #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" + #include #include "glog/logging.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" namespace phi { namespace distributed { -namespace auto_parallel { bool IsDimsMappingShard(const std::vector& dims_mapping) { return std::any_of(dims_mapping.begin(), @@ -80,6 +80,5 @@ std::map GetSplitAxisWithDimsMapping( return split_axis_to_mesh_axis; } -} // namespace auto_parallel } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h index 2a8d79053f13a..dceaa5150a6b0 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h @@ -23,6 +23,9 @@ namespace distributed { namespace auto_parallel { class ProcessMesh; +} // namespace auto_parallel + +using auto_parallel::ProcessMesh; bool IsDimsMappingShard(const std::vector& dims_mapping); @@ -43,6 +46,5 @@ std::vector GetCurRankCoordInMesh(const ProcessMesh& process_mesh); std::map GetSplitAxisWithDimsMapping( const std::vector& dims_mapping); -} // namespace auto_parallel } // namespace distributed } // namespace phi diff --git a/test/cpp/auto_parallel/test_reshard_r_to_s.cc b/test/cpp/auto_parallel/test_reshard_r_to_s.cc index 8213223db8f6c..03bd8d247781a 100644 --- a/test/cpp/auto_parallel/test_reshard_r_to_s.cc +++ b/test/cpp/auto_parallel/test_reshard_r_to_s.cc @@ -16,26 +16,35 @@ #include "glog/logging.h" #include "gtest/gtest.h" #include "paddle/phi/backends/all_context.h" +#include "paddle/phi/backends/context_pool.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/reshard_function.h" -#include "test/cpp/phi/core/allocator.h" +#include "paddle/phi/core/tensor_utils.h" namespace phi { namespace distributed { namespace auto_parallel { namespace tests { -std::shared_ptr ConstructReplicatedDistTensor( - Allocator* alloc, +std::shared_ptr ConstructReplicatedDistCPU( + phi::CPUContext* dev_ctx, const std::vector& shape, - const DataLayout& layout, - const DataType& dtype, const ProcessMesh& mesh) { + phi::CPUPlace cpu_place = dev_ctx->GetPlace(); const DDim dims(shape.data(), shape.size()); - const LoD lod{}; - DenseTensorMeta meta(dtype, dims, layout, lod); + + int64_t num_of_elems = 1; + for (const auto& value : shape) { + num_of_elems *= value; + } + + phi::DenseTensor input_dense; + float* input_dense_ptr = input_dense.mutable_data(dims, cpu_place); + + std::vector vec(num_of_elems); + memcpy(input_dense_ptr, vec.data(), num_of_elems * sizeof(float)); std::shared_ptr dist_attr = std::make_shared(shape); @@ -44,18 +53,50 @@ std::shared_ptr ConstructReplicatedDistTensor( dist_attr->set_dims_mapping(dims_mapping); dist_attr->set_process_mesh(mesh); - return std::make_shared(alloc, meta, dist_attr); + return std::make_shared( + std::make_shared(input_dense), dist_attr); } -TEST(reshard_r_to_s, r_to_s_same_placement_1d_mesh) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +std::shared_ptr ConstructReplicatedDistGPU( + phi::GPUContext* dev_ctx, + const std::vector& shape, + const ProcessMesh& mesh) { + phi::GPUPlace gpu_place = dev_ctx->GetPlace(); + phi::CPUPlace cpu_place; + const DDim dims(shape.data(), shape.size()); + + int64_t num_of_elems = 1; + for (const auto& value : shape) { + num_of_elems *= value; + } + + phi::DenseTensor input_dense; + phi::DenseTensor input_dense_gpu; + float* input_dense_ptr = input_dense.mutable_data(dims, cpu_place); + + std::vector vec(num_of_elems); + memcpy(input_dense_ptr, vec.data(), num_of_elems * sizeof(float)); + phi::Copy(*dev_ctx, input_dense, gpu_place, true, &input_dense_gpu); + + std::shared_ptr dist_attr = + std::make_shared(shape); + + std::vector dims_mapping(shape.size(), -1); + dist_attr->set_dims_mapping(dims_mapping); + dist_attr->set_process_mesh(mesh); + + return std::make_shared( + std::make_shared(input_dense_gpu), dist_attr); +} +#endif + +TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh) { setenv("PADDLE_TRAINER_ID", "1", 1); - std::vector tensor_shape = {6, 12}; - const DataType dtype{DataType::FLOAT32}; - const DataLayout layout{DataLayout::NHWC}; - auto fancy_allocator = - std::unique_ptr(new phi::tests::FancyAllocator); - auto* alloc = fancy_allocator.get(); + std::vector tensor_shape = {6, 8}; + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* context = reinterpret_cast(pool.Get(phi::CPUPlace())); std::vector mesh_shape = {4}; std::vector process_ids = {0, 1, 2, 3}; @@ -63,37 +104,59 @@ TEST(reshard_r_to_s, r_to_s_same_placement_1d_mesh) { ProcessMesh mesh(mesh_shape, process_ids, dim_names); std::shared_ptr input = - ConstructReplicatedDistTensor(alloc, tensor_shape, layout, dtype, mesh); - int64_t split_axis = 1; + ConstructReplicatedDistCPU(context, tensor_shape, mesh); - // Use process mesh axis 0 to split tensor axis 1 std::shared_ptr out_dist_attr = std::make_shared(tensor_shape); - std::vector out_dims_mapping(tensor_shape.size(), -1); - out_dims_mapping[split_axis] = 0; + std::vector out_dims_mapping = {-1, 0}; out_dist_attr->set_dims_mapping(out_dims_mapping); out_dist_attr->set_process_mesh(mesh); - phi::CPUPlace cpu_place; - CPUContext dev_ctx(cpu_place); - dev_ctx.SetAllocator(alloc); + RToSReshardFunction r_to_s_func; + std::shared_ptr output = + r_to_s_func.Eval(*context, *input, out_dist_attr); + + CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true); + CHECK_EQ(output->numel(), 12); + CHECK_EQ(output->dims(), DDim({6, 2})); +} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) { + setenv("PADDLE_TRAINER_ID", "0", 0); + + std::vector tensor_shape = {6, 8, 4}; + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* context = reinterpret_cast(pool.Get(phi::GPUPlace())); + + std::vector mesh_shape = {6}; + std::vector process_ids = {0, 1, 2, 3, 4, 5}; + std::vector dim_names = {"x"}; + ProcessMesh mesh(mesh_shape, process_ids, dim_names); + + std::shared_ptr out_dist_attr = + std::make_shared(tensor_shape); + std::vector out_dims_mapping = {0, -1}; + out_dist_attr->set_dims_mapping(out_dims_mapping); + out_dist_attr->set_process_mesh(mesh); + + std::shared_ptr input = + ConstructReplicatedDistGPU(context, tensor_shape, mesh); RToSReshardFunction r_to_s_func; std::shared_ptr output = - r_to_s_func.Eval(dev_ctx, *input, out_dist_attr); + r_to_s_func.Eval(*context, *input, out_dist_attr); CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true); - CHECK_EQ(output->numel(), 18); - CHECK_EQ(output->dims(), DDim({6, 3})); + CHECK_EQ(output->numel(), 32); + CHECK_EQ(output->dims(), DDim({1, 8, 4})); } +#endif TEST(reshard_r_to_s, r_to_s_diff_placement) { - std::vector tensor_shape = {6, 12}; - const DataType dtype{DataType::FLOAT32}; - const DataLayout layout{DataLayout::NHWC}; - auto fancy_allocator = - std::unique_ptr(new phi::tests::FancyAllocator); - auto* alloc = fancy_allocator.get(); + std::vector tensor_shape = {6, 8}; + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* context = reinterpret_cast(pool.Get(phi::CPUPlace())); std::vector mesh_shape = {4}; std::vector process_ids = {0, 1, 2, 3}; @@ -101,15 +164,13 @@ TEST(reshard_r_to_s, r_to_s_diff_placement) { ProcessMesh mesh(mesh_shape, process_ids, dim_names); std::shared_ptr input = - ConstructReplicatedDistTensor(alloc, tensor_shape, layout, dtype, mesh); - int64_t split_axis = 1; + ConstructReplicatedDistCPU(context, tensor_shape, mesh); std::vector out_process_ids = {2, 3, 4, 5}; ProcessMesh out_mesh(mesh_shape, out_process_ids, dim_names); std::shared_ptr out_dist_attr = std::make_shared(tensor_shape); - std::vector out_dims_mapping(tensor_shape.size(), -1); - out_dims_mapping[split_axis] = 0; + std::vector out_dims_mapping = {-1, 0}; out_dist_attr->set_dims_mapping(out_dims_mapping); out_dist_attr->set_process_mesh(out_mesh); @@ -118,14 +179,9 @@ TEST(reshard_r_to_s, r_to_s_diff_placement) { } TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) { - setenv("PADDLE_TRAINER_ID", "6", 1); - std::vector tensor_shape = {6, 12}; - const DataType dtype{DataType::FLOAT32}; - const DataLayout layout{DataLayout::NHWC}; - auto fancy_allocator = - std::unique_ptr(new phi::tests::FancyAllocator); - auto* alloc = fancy_allocator.get(); + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* context = reinterpret_cast(pool.Get(phi::CPUPlace())); std::vector mesh_shape = {4, 2}; std::vector process_ids = {0, 1, 2, 3, 4, 5, 6, 7}; @@ -133,27 +189,17 @@ TEST(reshard_r_to_s, r_to_s_same_placement_nd_mesh) { ProcessMesh mesh(mesh_shape, process_ids, dim_names); std::shared_ptr input = - ConstructReplicatedDistTensor(alloc, tensor_shape, layout, dtype, mesh); + ConstructReplicatedDistCPU(context, tensor_shape, mesh); - // Use process mesh axis 0 to split tensor axis 1, use process mesh axis 1 to - // split tensor axis 0 std::shared_ptr out_dist_attr = std::make_shared(tensor_shape); std::vector out_dims_mapping = {1, 0}; out_dist_attr->set_dims_mapping(out_dims_mapping); out_dist_attr->set_process_mesh(mesh); - phi::CPUPlace cpu_place; - CPUContext dev_ctx(cpu_place); - dev_ctx.SetAllocator(alloc); - RToSReshardFunction r_to_s_func; - std::shared_ptr output = - r_to_s_func.Eval(dev_ctx, *input, out_dist_attr); - CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true); - CHECK_EQ(output->numel(), 9); - CHECK_EQ(output->dims(), DDim({3, 3})); + CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), false); } } // namespace tests