Skip to content

Commit

Permalink
Implement reshard from s to r with same process_mesh
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio committed Aug 9, 2023
1 parent 0434b82 commit 504cf91
Show file tree
Hide file tree
Showing 19 changed files with 459 additions and 11 deletions.
42 changes: 41 additions & 1 deletion paddle/fluid/pybind/auto_parallel_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/pybind/auto_parallel_py.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/pybind.h"

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
#endif

namespace py = pybind11;
Expand Down Expand Up @@ -111,7 +115,43 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {

void BindAutoParallel(py::module *m) {
#ifdef PADDLE_WITH_DISTRIBUTE
py::class_<phi::distributed::RToSReshardFunction>(*m, "RToSReshardFunction")
auto ReshardFunction =
py::class_<phi::distributed::ReshardFunction>(*m, "ReshardFunction")
.def(
"is_suitable",
[](phi::distributed::ReshardFunction &self,
py::handle py_tensor,
const std::shared_ptr<phi::distributed::TensorDistAttr>
&dist_attr) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl());
return self.IsSuitable(*p_dist, dist_attr);
},
py::call_guard<py::gil_scoped_release>())
.def(
"eval",
[](phi::distributed::ReshardFunction &self,
phi::DeviceContext *dev_ctx,
py::handle py_tensor,
const std::shared_ptr<phi::distributed::TensorDistAttr>
&dist_attr) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl());
auto res_dist = self.Eval(dev_ctx, *p_dist, dist_attr);
return paddle::Tensor(res_dist);
},
py::call_guard<py::gil_scoped_release>());

py::class_<phi::distributed::RToSReshardFunction>(
*m, "RToSReshardFunction", ReshardFunction)
.def(py::init<>());

py::class_<phi::distributed::SToRReshardFunction>(
*m, "SToRReshardFunction", ReshardFunction)
.def(py::init<>());
#endif

Expand Down
11 changes: 9 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto)
set(DISTRIBUTED_SRCS "")

if(WITH_DISTRIBUTE)
list(APPEND DISTRIBUTED_SRCS dist_tensor.cc reshard_function.cc
reshard_split_functor.cc r_to_s_reshard_function.cc)
list(
APPEND
DISTRIBUTED_SRCS
dist_tensor.cc
reshard_function.cc
reshard_split_functor.cc
reshard_all_gather_functor.cc
r_to_s_reshard_function.cc
s_to_r_reshard_function.cc)
endif()

collect_srcs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ bool RToSReshardFunction::IsSuitable(
}

std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
const phi::DeviceContext& dev_ctx,
phi::DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
const auto& out_dims_mapping = out_dist_attr->dims_mapping();
Expand Down Expand Up @@ -85,7 +85,7 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
num_of_process, in.dims()[split_axis] / num_of_process));

std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);
*dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);

VLOG(3) << "The current process will remain the idx "
<< coord_in_mesh[mesh_axis] << " piece of tensor";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class RToSReshardFunction final : public ReshardFunction {
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;

std::shared_ptr<DistTensor> Eval(
const DeviceContext& dev_ctx,
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/all_gather_kernel.h"

namespace phi {
namespace distributed {

// Runs an all-gather of `input` across the processes listed in `process_ids`
// and returns the gathered tensor.
//
// The communication context for the process group is obtained through
// CreateOrGetCommContext and installed on `dev_ctx` before AllGather is
// invoked. The kernel is selected by the dtype of `input` (floating and
// integral types) and by the concrete device context: CPU always, GPU when
// built with CUDA or HIP. Any other device context raises Unimplemented.
DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
                                    const DenseTensor& input,
                                    const std::vector<int64_t>& process_ids) {
  // Install the communicator of this process group on the device context.
  dev_ctx->SetCommContext(CreateOrGetCommContext(*dev_ctx, process_ids));

  // One participant per entry in process_ids.
  const int64_t nranks = process_ids.size();
  DenseTensor gathered;

  const bool runs_on_cpu = phi::CPUContext::classof(dev_ctx);
  if (runs_on_cpu) {
    PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
        input.dtype(), "AllGather", ([&] {
          AllGather<data_t>(static_cast<const CPUContext&>(*dev_ctx),
                            input,
                            nranks,
                            &gathered);
        }));
    return gathered;
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  const bool runs_on_gpu = phi::GPUContext::classof(dev_ctx);
  if (runs_on_gpu) {
    PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
        input.dtype(), "AllGather", ([&] {
          AllGather<data_t>(static_cast<const GPUContext&>(*dev_ctx),
                            input,
                            nranks,
                            &gathered);
        }));
    return gathered;
  }
#endif

  // Neither branch above matched the device context.
  PADDLE_THROW(phi::errors::Unimplemented(
      "The all_gather in reshard only supported on CPU and GPU for now."));
}

} // namespace distributed
} // namespace phi
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <vector>

namespace phi {
class DenseTensor;
class DeviceContext;

namespace distributed {

// Runs an all-gather of `input` across the processes in `process_ids` and
// returns the gathered tensor.
//
// dev_ctx:     device context used to run the kernel; the communication
//              context of the process group is set on it, so it is mutated.
// input:       the tensor held by the current process.
// process_ids: the ids of all processes taking part in the all-gather.
//
// Raises an Unimplemented error when dev_ctx is neither a CPU nor a GPU
// context.
DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& process_ids);

} // namespace distributed
} // namespace phi
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ReshardFunction {
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;

virtual std::shared_ptr<DistTensor> Eval(
const DeviceContext& dev_ctx,
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;
};
Expand Down
56 changes: 56 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

#include <cstdlib>
#include "glog/logging.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"

namespace phi {
namespace distributed {
Expand Down Expand Up @@ -109,6 +114,21 @@ std::string GetMasterEndpoint() {
return master_endpoint;
}

// Builds the key under which the communication context of a process group is
// stored: the fixed prefix "ReshardGroup" followed by "/<id>" for every entry
// of `process_ids`, in the order given.
std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
  std::string key("ReshardGroup");
  for (const int64_t process_id : process_ids) {
    key += "/";
    key += std::to_string(process_id);
  }
  return key;
}

// Returns the position of the current process (as reported by
// GetCurGlobalRank) inside `process_ids`, i.e. its rank within that group.
//
// NOTE(review): when the current global rank is not listed in `process_ids`
// the search reaches end(), so the returned value equals process_ids.size().
// Callers presumably only pass groups that contain the current process --
// confirm.
int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids) {
  const int64_t self_rank = GetCurGlobalRank();
  const auto first = process_ids.begin();
  const auto found = std::find(first, process_ids.end(), self_rank);
  return found - first;
}

} // namespace

std::string GetMasterAddr() {
Expand All @@ -133,5 +153,41 @@ std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore() {
return store;
}

// Returns the communication context shared by the processes in `process_ids`,
// creating and registering it in CommContextManager on first use.
//
// The context is looked up by the key produced by GenUniqueCommKey. When no
// context is registered under that key yet, one is created on top of the
// global TCP store:
//   - CPU device context: a gloo context (requires PADDLE_WITH_GLOO).
//   - GPU device context: an NCCL context (requires PADDLE_WITH_NCCL or
//     PADDLE_WITH_RCCL).
// Every other combination raises an Unimplemented error.
CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
                                    const std::vector<int64_t>& process_ids) {
  std::string unique_comm_key = GenUniqueCommKey(process_ids);

  if (!CommContextManager::GetInstance().Has(unique_comm_key)) {
    int64_t world_size = process_ids.size();
    int64_t rank = GetLocalRankInParticipate(process_ids);
    VLOG(3) << "local world size: " << world_size << " local rank: " << rank;

    auto store = CreateOrGetGlobalTCPStore();
    if (phi::CPUContext::classof(&dev_ctx)) {
#if defined(PADDLE_WITH_GLOO)
      CommContextManager::CreateGlooCommContext(
          store, unique_comm_key, rank, world_size);
#else
      PADDLE_THROW(phi::errors::Unimplemented(
          "Cannot use gloo on CPU, please turn PADDLE_WITH_GLOO flag on."));
#endif
    } else {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
      if (phi::GPUContext::classof(&dev_ctx)) {
        CommContextManager::CreateNCCLCommContext(
            store, unique_comm_key, rank, world_size);
      } else {
        // A device context that is neither CPU nor GPU used to fall through
        // here without creating anything, leaving the lookup below to run on
        // a key that was never registered. Report the unsupported device
        // explicitly, exactly as the build without NCCL/RCCL does.
        PADDLE_THROW(phi::errors::Unimplemented(
            "CommContext is only supported on CPU and GPU for now, other "
            "devices will be supported later."));
      }
#else
      PADDLE_THROW(phi::errors::Unimplemented(
          "CommContext is only supported on CPU and GPU for now, other devices "
          "will be supported later."));
#endif
    }
  }

  auto* comm_context = CommContextManager::GetInstance().Get(unique_comm_key);
  return comm_context;
}

} // namespace distributed
} // namespace phi
11 changes: 11 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
#include "paddle/phi/core/distributed/store/tcp_store.h"

namespace phi {
class DeviceContext;

namespace distributed {
class CommContext;

namespace auto_parallel {

class ProcessMesh;
Expand All @@ -48,6 +52,13 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping);

// Create a comm context for the input process_ids. Once a new comm context has
// been created, it is cached in the global instance and fetched from that
// cache on later calls. If the input dev_ctx is GPU, an NCCL comm context is
// created. If the input dev_ctx is CPU, a gloo comm context is created.
CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
const std::vector<int64_t>& process_ids);

int64_t GetCurGlobalRank();

std::string GetMasterAddr();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"

namespace phi {
namespace distributed {

// Reports whether this function can reshard `in` into `out_dist_attr`.
//
// All of the following must hold:
//   - IsDimsMappingShard accepts the dims mapping of the input;
//   - IsDimsMappingReplicated accepts the dims mapping of the output;
//   - both process meshes are one-dimensional;
//   - the input and the output use the same process mesh.
bool SToRReshardFunction::IsSuitable(
    const DistTensor& in,
    const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
  const auto& src_attr = in.dist_attr();

  // Layout of the tensor dimensions on both sides.
  const auto& src_dims_mapping = src_attr->dims_mapping();
  const auto& dst_dims_mapping = out_dist_attr->dims_mapping();
  const bool src_is_sharded = IsDimsMappingShard(src_dims_mapping);
  const bool dst_is_replicated = IsDimsMappingReplicated(dst_dims_mapping);

  // Shape and identity of the process meshes on both sides.
  const auto& src_mesh = src_attr->process_mesh();
  const auto& dst_mesh = out_dist_attr->process_mesh();
  const bool src_mesh_is_1d = (src_mesh.ndim() == 1);
  const bool dst_mesh_is_1d = (dst_mesh.ndim() == 1);
  const bool same_mesh = (src_mesh == dst_mesh);

  return src_is_sharded && dst_is_replicated && src_mesh_is_1d &&
         dst_mesh_is_1d && same_mesh;
}

// Reshards `in` to the replicated layout described by `out_dist_attr`.
//
// The tensor held by the current process is all-gathered across the processes
// of the input process mesh, and the gathered tensor is wrapped in a new
// DistTensor that carries `out_dist_attr`.
std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
    DeviceContext* dev_ctx,
    const DistTensor& in,
    const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
  // TODO(liyurui): Only transferring shard(0) to replicate is supported for
  // now. Transferring shard(x) to replicate needs an extra concat and will be
  // supported later.
  const auto& src_attr = in.dist_attr();
  const auto& src_mesh = src_attr->process_mesh();
  const auto& participants = src_mesh.process_ids();

  // The precondition guarantees that the output process ids equal the input
  // process ids, so the participating process ids must equal both of them;
  // the input ones are used here.
  DenseTensor gathered =
      ReshardAllGatherFunctor(dev_ctx, in.value(), participants);

  return std::make_shared<DistTensor>(std::make_shared<DenseTensor>(gathered),
                                      out_dist_attr);
}

} // namespace distributed
} // namespace phi
Loading

0 comments on commit 504cf91

Please sign in to comment.