Commit
support reshard with different mesh
LiYuRio committed Sep 22, 2023
1 parent bcc3305 commit 9262541
Showing 17 changed files with 466 additions and 29 deletions.
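This commit lets a DistTensor be resharded onto a process mesh different from the one it currently lives on, provided the two meshes have the same shape and the dims mapping and partial status stay unchanged; the new SameStatusReshardFunction moves the data between paired ranks with the p_send/p_recv kernels. A minimal usage sketch follows; it is hedged because it assumes the DistAttr-based dygraph auto-parallel API of this period (ProcessMesh, DistAttr, shard_tensor, reshard), which is not shown in this diff and may differ in detail.

# Hedged sketch, not taken from the commit: API names are assumptions.
# Launch with two processes, e.g.:
#   python -m paddle.distributed.launch --devices=0,1 demo.py
import paddle
import paddle.distributed as dist

src_mesh = dist.ProcessMesh([0], dim_names=["x"])
dst_mesh = dist.ProcessMesh([1], dim_names=["x"])

x = paddle.randn([4, 8])
# Place the tensor on the source mesh, fully replicated.
d_x = dist.shard_tensor(
    x, dist_attr=dist.DistAttr(mesh=src_mesh, sharding_specs=[None, None]))

# Reshard onto a different mesh of the same shape while keeping the same
# placements; this is the case the new SameStatusReshardFunction handles,
# moving the data between peer ranks with p_send/p_recv.
d_y = dist.reshard(
    d_x, dist.DistAttr(mesh=dst_mesh, sharding_specs=[None, None]))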
5 changes: 5 additions & 0 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -45,6 +45,7 @@
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h"
#include "paddle/phi/core/enforce.h"

#ifdef PADDLE_WITH_DISTRIBUTE
@@ -200,6 +201,10 @@ void BindAutoParallel(py::module *m) {
*m, "SameNdMeshReshardFunction", ReshardFunction)
.def(py::init<>());

py::class_<phi::distributed::SameStatusReshardFunction>(
*m, "SameStatusReshardFunction", ReshardFunction)
.def(py::init<>());

py::class_<ProcessMesh>(*m, "ProcessMesh")
.def(py::init<>())
.def(py::init<const std::vector<int64_t> &,
3 changes: 2 additions & 1 deletion paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -17,4 +17,5 @@ collect_srcs(
r_to_p_reshard_function.cc
p_to_r_reshard_function.cc
s_to_s_reshard_function.cc
nd_mesh_reshard_function.cc)
nd_mesh_reshard_function.cc
same_status_reshard_function.cc)
34 changes: 21 additions & 13 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -18,6 +18,7 @@
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/store/store_utils.h"

namespace phi {
namespace distributed {
@@ -35,20 +36,27 @@ inline void check_defined(const DistTensor& dist_tensor,
DistTensor::DistTensor(const phi::DenseTensor& global_value,
const TensorDistAttr& dist_attr)
: dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {
if (value_.initialized() && !dist_attr.is_replicated()) {
// 1. create replicated global tensor
int64_t dims_size = global_value.dims().size();
std::vector<int64_t> dims_mapping(dims_size, -1);
dist_attr_.set_dims_mapping(dims_mapping);
if (dist_attr_.is_partial()) {
dist_attr_.clean_partial_status();
// TODO(liyurui): This is a temporary solution. We need to support running
// only infer meta when the input dense_tensor is empty, i.e. the value held
// by a DistTensor may carry only DenseTensor meta without actual data, so
// that its meta attributes can still be visited even when the data is
// undefined.
if (IsCurRankInMesh(dist_attr.process_mesh())) {
if (value_.initialized() && !dist_attr.is_replicated()) {
// 1. create replicated global tensor
int64_t dims_size = global_value.dims().size();
std::vector<int64_t> dims_mapping(dims_size, -1);
dist_attr_.set_dims_mapping(dims_mapping);
if (dist_attr_.is_partial()) {
dist_attr_.clean_partial_status();
}
dist_attr_.set_dims_mapping(dims_mapping);

// 2. reshard from replicated to other state
auto* func = ChooseProperReshardFunction(*this, dist_attr);
auto* dev_ctx = DeviceContextPool::Instance().Get(global_value.place());
func->Eval(dev_ctx, *this, dist_attr, this);
}
dist_attr_.set_dims_mapping(dims_mapping);

// 2. reshard from replicated to other state
auto* func = ChooseProperReshardFunction(*this, dist_attr);
auto* dev_ctx = DeviceContextPool::Instance().Get(global_value.place());
func->Eval(dev_ctx, *this, dist_attr, this);
}
}

28 changes: 21 additions & 7 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
@@ -24,13 +24,6 @@ namespace phi {
namespace distributed {

namespace {
int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids) {
int64_t cur_global_rank = GetCurGlobalRank();
auto iter =
std::find(process_ids.begin(), process_ids.end(), cur_global_rank);
return iter - process_ids.begin();
}

std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
std::string unique_comm_key = "ReshardGroup";
for (const auto& id : process_ids) {
@@ -40,6 +33,20 @@ std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
}
} // namespace

int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids,
int64_t global_rank) {
if (global_rank == -1) {
global_rank = GetCurGlobalRank();
}
auto iter = std::find(process_ids.begin(), process_ids.end(), global_rank);
PADDLE_ENFORCE_NE(
iter,
process_ids.end(),
phi::errors::NotFound("Global rank %lld cannot be found in process_mesh",
global_rank));
return iter - process_ids.begin();
}

std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) {
const auto& process_shape = process_mesh.shape();
const auto& process_ids = process_mesh.process_ids();
@@ -132,5 +139,12 @@ std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces) {
return result;
}

bool IsCurRankInMesh(const ProcessMesh& process_mesh) {
int64_t cur_global_rank = GetCurGlobalRank();
const auto& process_ids = process_mesh.process_ids();
return (std::find(process_ids.begin(), process_ids.end(), cur_global_rank) !=
process_ids.end());
}

} // namespace distributed
} // namespace phi
5 changes: 5 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.h
@@ -30,6 +30,11 @@ class DeviceContext;
namespace distributed {
class ProcessMesh;

bool IsCurRankInMesh(const ProcessMesh& process_mesh);

int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids,
int64_t global_rank = -1);

// Get the coordinate of the current rank in the process mesh. For example, if
// the process mesh is [[0, 1], [2, 3], [4, 5], [6, 7]] and the current rank is
// 4, this returns [2, 0]; if the current rank is 3, it returns [1, 1].
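As a side note on the comment above: GetCurRankCoordInMesh amounts to locating the current rank in the mesh's flattened process_ids and decomposing that flat index row-major over the mesh shape. A standalone sketch of the arithmetic (illustration only, not the Paddle implementation; CoordInMesh is a hypothetical helper):

// Decompose a rank's position in the flattened process_ids list into
// mesh coordinates, row-major (illustration only).
#include <cstdint>
#include <vector>

std::vector<int64_t> CoordInMesh(const std::vector<int64_t>& mesh_shape,
                                 int64_t flat_index) {
  std::vector<int64_t> coord(mesh_shape.size(), 0);
  for (int i = static_cast<int>(mesh_shape.size()) - 1; i >= 0; --i) {
    coord[i] = flat_index % mesh_shape[i];
    flat_index /= mesh_shape[i];
  }
  return coord;
}

int main() {
  // Shape {4, 2}, flat index 4 -> {2, 0}; flat index 3 -> {1, 1}.
  std::vector<int64_t> coord = CoordInMesh({4, 2}, 4);
  return (coord[0] == 2 && coord[1] == 0) ? 0 : 1;
}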
121 changes: 121 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc
@@ -0,0 +1,121 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h"

#include <algorithm>

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/store/store_utils.h"
#include "paddle/phi/kernels/p_recv_kernel.h"
#include "paddle/phi/kernels/p_send_kernel.h"

namespace phi {
namespace distributed {

namespace {

std::vector<int64_t> GetUnionProcessIds(std::vector<int64_t> in_process_ids,
std::vector<int64_t> out_process_ids) {
std::vector<int64_t> result;
std::sort(in_process_ids.begin(), in_process_ids.end());
std::sort(out_process_ids.begin(), out_process_ids.end());
std::set_union(in_process_ids.begin(),
in_process_ids.end(),
out_process_ids.begin(),
out_process_ids.end(),
std::back_inserter(result));
return result;
}

} // namespace

bool SameStatusReshardFunction::IsSuitable(
const DistTensor& in, const TensorDistAttr& out_dist_attr) {
bool flag = true;
const auto& in_dist_attr = in.dist_attr();

flag &= (in_dist_attr.dims_mapping() == out_dist_attr.dims_mapping());
flag &= (in_dist_attr.partial_dims() == out_dist_attr.partial_dims());

const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& out_process_mesh = out_dist_attr.process_mesh();
flag &= (in_process_mesh != out_process_mesh);
flag &= (in_process_mesh.shape() == out_process_mesh.shape());

return flag;
}

void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
const auto& in_dist_attr = in.dist_attr();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();
const auto& out_process_mesh = out_dist_attr.process_mesh();
const auto& out_process_ids = out_process_mesh.process_ids();
auto all_process_ids = GetUnionProcessIds(in_process_ids, out_process_ids);
auto dtype = in.dtype();
// TODO(liyurui): Using dynamic shape leads to poor performance, but we
// don't have a better idea for now, for the following reasons:
// 1. We cannot guarantee that the meta deduced by infermeta is correct.
// 2. The meta of some kernels cannot be decided at compile time.
// 3. A DenseTensor with an empty value only needs infermeta and skips the
// real kernel execution.
bool dynamic_shape = true;

std::vector<std::pair<int64_t, int64_t>> p2p_pair;
for (size_t i = 0; i < out_process_ids.size(); ++i) {
p2p_pair.emplace_back(
std::make_pair(in_process_ids[i], out_process_ids[i]));
}

int64_t cur_global_rank = GetCurGlobalRank();
for (const auto& iter : p2p_pair) {
int64_t src = iter.first;
int64_t dst = iter.second;
VLOG(3) << "Send/Recv from src " << src << " to dst " << dst;
if (src == cur_global_rank) {
int64_t dst_local_rank = GetLocalRankInParticipate(all_process_ids, dst);
// Since the send kernel only has inputs, there is no need to run
// infermeta here; just call the kernel directly.
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
PSendKernel,
dtype,
all_process_ids,
in.value(),
dst_local_rank,
dynamic_shape);
} else if (dst == cur_global_rank) {
int64_t src_local_rank = GetLocalRankInParticipate(all_process_ids, src);
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
PRecv,
dtype,
all_process_ids,
src_local_rank,
dynamic_shape,
GetMutableTensor(out));
}
}
SetDistProps(out, in.dims(), out_dist_attr);
}

REGISTER_RESHARD_FUNC(SameStatusReshardFunction);

} // namespace distributed
} // namespace phi
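To make the send/recv pairing above concrete: the function builds the union of both meshes' process ids as the communication group, pairs rank i of the source mesh with rank i of the destination mesh, and addresses each peer by its local rank inside that union group. A standalone sketch with plain STL (illustration only, not the Paddle helpers):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <iterator>
#include <vector>

int main() {
  std::vector<int64_t> in_ids{0, 1, 2, 3};   // source mesh, already sorted
  std::vector<int64_t> out_ids{4, 5, 6, 7};  // destination mesh, same shape

  // Union of both meshes: the group the send/recv kernels communicate on.
  std::vector<int64_t> all_ids;
  std::set_union(in_ids.begin(), in_ids.end(), out_ids.begin(), out_ids.end(),
                 std::back_inserter(all_ids));

  // Rank i of the source mesh sends to rank i of the destination mesh; each
  // kernel call addresses the peer by its local rank inside the union group.
  for (size_t i = 0; i < out_ids.size(); ++i) {
    int64_t src = in_ids[i];
    int64_t dst = out_ids[i];
    int64_t dst_local =
        std::find(all_ids.begin(), all_ids.end(), dst) - all_ids.begin();
    std::printf("send: global %lld -> global %lld (local rank %lld)\n",
                static_cast<long long>(src), static_cast<long long>(dst),
                static_cast<long long>(dst_local));
  }
  return 0;
}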
34 changes: 34 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h
@@ -0,0 +1,34 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"

namespace phi {
namespace distributed {

class SameStatusReshardFunction final : public ReshardFunction {
public:
bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;

void Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) override;
};

} // namespace distributed
} // namespace phi
14 changes: 6 additions & 8 deletions paddle/phi/core/distributed/store/store_utils.cc
@@ -48,19 +48,17 @@ std::string GetMasterEndpoint() {

int64_t GetCurGlobalRank() {
const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
cur_rank,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
if (cur_rank == nullptr) {
return 0;
}
return std::atoi(cur_rank);
}

int64_t GetGlobalWorldSize() {
const char* world_size = std::getenv("PADDLE_TRAINERS_NUM");
PADDLE_ENFORCE_NOT_NULL(
world_size,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINERS_NUM' cannot be found."));
if (world_size == nullptr) {
return 1;
}
return std::atoi(world_size);
}

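The store_utils change above relaxes both environment lookups: when PADDLE_TRAINER_ID or PADDLE_TRAINERS_NUM is unset (for example a plain single-process run without the distributed launcher), the rank now defaults to 0 and the world size to 1 instead of raising NotFound. A minimal sketch of the fallback pattern (illustration only; GetEnvAsInt64 is a hypothetical helper, not a Paddle function):

#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Read an integer env var with a fallback, mirroring the new
// default-to-single-process behavior.
int64_t GetEnvAsInt64(const char* name, int64_t default_value) {
  const char* value = std::getenv(name);
  return value == nullptr ? default_value : std::atoi(value);
}

int main() {
  // Without a distributed launcher these env vars are unset, so the process
  // behaves as rank 0 in a world of size 1.
  int64_t rank = GetEnvAsInt64("PADDLE_TRAINER_ID", 0);
  int64_t world_size = GetEnvAsInt64("PADDLE_TRAINERS_NUM", 1);
  std::printf("rank=%lld world_size=%lld\n", static_cast<long long>(rank),
              static_cast<long long>(world_size));
  return 0;
}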
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/p_recv_kernel.cc
@@ -54,6 +54,7 @@ PD_REGISTER_KERNEL(p_recv,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}

@@ -67,5 +68,6 @@ PD_REGISTER_KERNEL(p_recv_array,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/p_send_kernel.cc
@@ -53,6 +53,7 @@ PD_REGISTER_KERNEL(p_send,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}

@@ -66,5 +67,6 @@ PD_REGISTER_KERNEL(p_send_array,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/p_recv_kernel.cu
@@ -190,6 +190,7 @@ PD_REGISTER_KERNEL(p_recv,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
@@ -218,6 +219,7 @@ PD_REGISTER_KERNEL(p_recv,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}

2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/p_send_kernel.cu
@@ -178,6 +178,7 @@ PD_REGISTER_KERNEL(p_send,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
@@ -206,6 +207,7 @@ PD_REGISTER_KERNEL(p_send,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}
