[Reshard] Implement reshard from s to r with same process_mesh #56039

Merged: 1 commit, Aug 10, 2023
42 changes: 41 additions & 1 deletion paddle/fluid/pybind/auto_parallel_py.cc
@@ -18,16 +18,20 @@
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/pybind/auto_parallel_py.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/pybind.h"

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
#endif

namespace py = pybind11;
@@ -111,7 +115,43 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {

void BindAutoParallel(py::module *m) {
#ifdef PADDLE_WITH_DISTRIBUTE
-py::class_<phi::distributed::RToSReshardFunction>(*m, "RToSReshardFunction")
auto ReshardFunction =
py::class_<phi::distributed::ReshardFunction>(*m, "ReshardFunction")
.def(
"is_suitable",
[](phi::distributed::ReshardFunction &self,
py::handle py_tensor,
const std::shared_ptr<phi::distributed::TensorDistAttr>
&dist_attr) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl());
return self.IsSuitable(*p_dist, dist_attr);
},
py::call_guard<py::gil_scoped_release>())
.def(
"eval",
[](phi::distributed::ReshardFunction &self,
phi::DeviceContext *dev_ctx,
py::handle py_tensor,
const std::shared_ptr<phi::distributed::TensorDistAttr>
&dist_attr) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl());
auto res_dist = self.Eval(dev_ctx, *p_dist, dist_attr);
return paddle::Tensor(res_dist);
},
py::call_guard<py::gil_scoped_release>());

py::class_<phi::distributed::RToSReshardFunction>(
*m, "RToSReshardFunction", ReshardFunction)
.def(py::init<>());

py::class_<phi::distributed::SToRReshardFunction>(
*m, "SToRReshardFunction", ReshardFunction)
.def(py::init<>());
#endif

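The two subclass bindings inherit the base-class methods, so Python only ever calls is_suitable and eval; those wrappers forward to the C++ API bound above. A minimal C++-side sketch of the same call sequence (dev_ctx, dist_tensor, and out_dist_attr are assumed to be constructed elsewhere):

// Sketch only: dev_ctx is a phi::DeviceContext*, dist_tensor a
// phi::distributed::DistTensor, out_dist_attr a
// std::shared_ptr<phi::distributed::TensorDistAttr>.
phi::distributed::SToRReshardFunction func;
if (func.IsSuitable(dist_tensor, out_dist_attr)) {
std::shared_ptr<phi::distributed::DistTensor> out =
func.Eval(dev_ctx, dist_tensor, out_dist_attr);
}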
11 changes: 9 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -3,8 +3,15 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto)
set(DISTRIBUTED_SRCS "")

if(WITH_DISTRIBUTE)
-list(APPEND DISTRIBUTED_SRCS dist_tensor.cc reshard_function.cc
-reshard_split_functor.cc r_to_s_reshard_function.cc)
+list(
+APPEND
+DISTRIBUTED_SRCS
+dist_tensor.cc
+reshard_function.cc
+reshard_split_functor.cc
+reshard_all_gather_functor.cc
+r_to_s_reshard_function.cc
+s_to_r_reshard_function.cc)
endif()

collect_srcs(
paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
@@ -48,7 +48,7 @@ bool RToSReshardFunction::IsSuitable(
}

std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
-const phi::DeviceContext& dev_ctx,
+phi::DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
const auto& out_dims_mapping = out_dist_attr->dims_mapping();
@@ -85,7 +85,7 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
num_of_process, in.dims()[split_axis] / num_of_process));

std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
-dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);
+*dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);

VLOG(3) << "The current process will remain the idx "
<< coord_in_mesh[mesh_axis] << " piece of tensor";
paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h
@@ -29,7 +29,7 @@ class RToSReshardFunction final : public ReshardFunction {
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;

std::shared_ptr<DistTensor> Eval(
-const DeviceContext& dev_ctx,
+DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
};
paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.cc (new file)
@@ -0,0 +1,62 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/all_gather_kernel.h"

namespace phi {
namespace distributed {

DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& process_ids) {
DenseTensor out;

int64_t world_size = process_ids.size();
auto* comm_context = CreateOrGetCommContext(*dev_ctx, process_ids);
dev_ctx->SetCommContext(comm_context);

if (phi::CPUContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const CPUContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const GPUContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The all_gather in reshard is only supported on CPU and GPU for now."));
}

} // namespace distributed
} // namespace phi
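Shape-wise, assuming the AllGather kernel concatenates along axis 0 (so the output's first dimension is world_size times the input's), a hedged usage sketch; dev_ctx and local_slice are assumed to exist:

// Sketch only: gather a per-rank [3, 8] slice across ranks {0, 1, 2, 3}.
std::vector<int64_t> process_ids = {0, 1, 2, 3};
phi::DenseTensor full = phi::distributed::ReshardAllGatherFunctor(
dev_ctx, local_slice, process_ids);
// full holds the gathered [12, 8] tensor on every participating rank.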
paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h (new file)
@@ -0,0 +1,31 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <vector>

namespace phi {
class DenseTensor;
class DeviceContext;

namespace distributed {

DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& process_ids);

} // namespace distributed
} // namespace phi
paddle/phi/core/distributed/auto_parallel/reshard_function.h
@@ -36,7 +36,7 @@ class ReshardFunction {
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;

virtual std::shared_ptr<DistTensor> Eval(
-const DeviceContext& dev_ctx,
+DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;
};
56 changes: 56 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
@@ -16,7 +16,12 @@

#include <cstdlib>
#include "glog/logging.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"

namespace phi {
namespace distributed {
@@ -109,6 +114,21 @@ std::string GetMasterEndpoint() {
return master_endpoint;
}

std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
std::string unique_comm_key = "ReshardGroup";
for (const auto& id : process_ids) {
unique_comm_key += "/" + std::to_string(id);
}
return unique_comm_key;
}

Review thread on GenUniqueCommKey:

Member: Does this only support a 1-D process mesh? If a high-dimensional mesh and a low-dimensional mesh carry the same process id numbering, are the keys identical?

Contributor Author: This function is only responsible for turning the given vector of process_ids into a unique comm_key. There are two cases:

1. If the input and output process_mesh are the same, the collective communication op can be invoked directly; it suffices to pass in the flattened process ids.
2. If the input and output process_mesh differ, the communication groups must be created group by group before this function is called, according to the concrete situation; these are usually point-to-point groups.

Whether the mesh is high- or low-dimensional, as long as the processes participating in the communication are the same, the key is the same.
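To make the reply above concrete, the key depends only on the participating ids; the values below follow directly from the code:

// A 2x2 mesh flattened to {0, 1, 2, 3} and a 1-D mesh over the same four
// ranks produce the identical key, so they share one cached comm context.
std::vector<int64_t> flat_ids = {0, 1, 2, 3};
std::string key = GenUniqueCommKey(flat_ids);  // "ReshardGroup/0/1/2/3"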

int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids) {
int64_t cur_global_rank = GetCurGlobalRank();
auto iter =
std::find(process_ids.begin(), process_ids.end(), cur_global_rank);
return iter - process_ids.begin();
}

} // namespace

std::string GetMasterAddr() {
@@ -133,5 +153,41 @@ std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore() {
return store;
}

CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
const std::vector<int64_t>& process_ids) {
std::string unique_comm_key = GenUniqueCommKey(process_ids);

if (!CommContextManager::GetInstance().Has(unique_comm_key)) {
int64_t world_size = process_ids.size();
int64_t rank = GetLocalRankInParticipate(process_ids);
VLOG(3) << "local world size: " << world_size << " local rank: " << rank;

auto store = CreateOrGetGlobalTCPStore();
if (phi::CPUContext::classof(&dev_ctx)) {
#if defined(PADDLE_WITH_GLOO)
CommContextManager::CreateGlooCommContext(
store, unique_comm_key, rank, world_size);
#else
PADDLE_THROW(phi::errors::Unimplemented(
"Cannot use gloo on CPU, please turn PADDLE_WITH_GLOO flag on."));
#endif
} else {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (phi::GPUContext::classof(&dev_ctx)) {
CommContextManager::CreateNCCLCommContext(
store, unique_comm_key, rank, world_size);
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"CommContext is only supported on CPU and GPU for now, other devices "
"will be supported later."));
#endif
}
}

auto* comm_context = CommContextManager::GetInstance().Get(unique_comm_key);
return comm_context;
}
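This is the same pattern used at the call site in reshard_all_gather_functor.cc above: resolve the context for exactly the participating ranks, then install it on the device context before launching the collective kernel:

auto* comm_context = CreateOrGetCommContext(*dev_ctx, process_ids);
dev_ctx->SetCommContext(comm_context);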

} // namespace distributed
} // namespace phi
11 changes: 11 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.h
@@ -23,7 +23,11 @@
#include "paddle/phi/core/distributed/store/tcp_store.h"

namespace phi {
class DeviceContext;

namespace distributed {
class CommContext;

namespace auto_parallel {

class ProcessMesh;
@@ -48,6 +52,13 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping);

// Create a comm context for the input process_ids. Once a new comm context
// is created, it is cached in the global instance and fetched from that
// cache on later calls. If the input dev_ctx is GPU, an nccl comm context
// will be created; if the input dev_ctx is CPU, a gloo comm context will be
// created.
CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
const std::vector<int64_t>& process_ids);

int64_t GetCurGlobalRank();

std::string GetMasterAddr();
paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc (new file)
@@ -0,0 +1,73 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"

namespace phi {
namespace distributed {

bool SToRReshardFunction::IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
bool flag = true;
const auto& in_dist_attr = in.dist_attr();

const auto& in_dims_mapping = in_dist_attr->dims_mapping();
const auto& out_dims_mapping = out_dist_attr->dims_mapping();

flag &= IsDimsMappingShard(in_dims_mapping);
flag &= IsDimsMappingReplicated(out_dims_mapping);

const auto& in_process_mesh = in_dist_attr->process_mesh();
const auto& out_process_mesh = out_dist_attr->process_mesh();

flag &= (in_process_mesh.ndim() == 1);
flag &= (out_process_mesh.ndim() == 1);
flag &= (in_process_mesh == out_process_mesh);

return flag;
}
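As a concrete pair that passes all of these checks (assuming the usual convention that -1 in dims_mapping means replicated and a value k >= 0 means sharded along mesh axis k):

// in:  process_mesh = {0, 1}, dims_mapping = {0, -1}   -> sharded on axis 0
// out: process_mesh = {0, 1}, dims_mapping = {-1, -1}  -> fully replicated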

std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
// TODO(liyurui): Only transferring shard(0) to replicate is supported for
// now. Concat is needed when transferring shard(x) to replicate; it will be
// supported later.
const DenseTensor& in_physical_tensor_cur_rank = in.value();
const auto& in_dist_attr = in.dist_attr();
const auto& in_process_mesh = in_dist_attr->process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();

// Since the precondition ensures that out_process_ids equals
// in_process_ids, the participating process ids must equal either
// in_process_ids or out_process_ids.
DenseTensor out_all_gather = ReshardAllGatherFunctor(
dev_ctx, in_physical_tensor_cur_rank, in_process_ids);

return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_all_gather), out_dist_attr);
}
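For intuition, a hypothetical trace of Eval (shapes chosen for illustration only; dev_ctx, sharded_tensor, and replicated_dist_attr are assumed to exist):

// A global [6, 8] tensor sharded on axis 0 over the 1-D mesh {0, 1}:
// each rank's in.value() is a [3, 8] slice. ReshardAllGatherFunctor
// gathers the two slices, so every rank's result wraps the full [6, 8]
// DenseTensor, now tagged with the replicated out_dist_attr.
SToRReshardFunction func;
std::shared_ptr<DistTensor> replicated =
func.Eval(dev_ctx, sharded_tensor, replicated_dist_attr);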

} // namespace distributed
} // namespace phi